Machine learning/NLP

[NLP]. SentenceTransformer Tokenize 멀티턴 형식으로 수정하기

Acdong 2022. 12. 22. 17:40
728x90

 

def origin_tokenize(self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]]):
    """Delegate tokenization to the model's first module, unchanged.

    This preserves the stock SentenceTransformer behavior so the custom
    multi-turn ``tokenize`` below can post-process its output.
    """
    first_module = self._first_module()
    return first_module.tokenize(texts)

def tokenize(self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]],
             sep_token_id: int = 3):
    """Tokenize *texts* with multi-turn utterance segment ids.

    Example input: ``안녕 [SEP] 뭐해 ㅋㅋㅋㅋㅋ [SEP] 나 집에서 넷플릭스 보고있지``

    When the encoded sequence contains exactly two *internal* separator
    tokens (the trailing sentence-final [SEP] is excluded), the three turns
    are re-labelled with ``token_type_ids`` 0 / 1 / 0 so the model can
    distinguish them; otherwise the encoding is returned unchanged.

    Args:
        texts: passed straight through to ``origin_tokenize``.
        sep_token_id: id of the [SEP] token. Defaults to 3, the value that
            was previously hard-coded (now parameterized for other vocabs).

    Returns:
        The dict from ``origin_tokenize``; ``token_type_ids`` is replaced
        with a ``(1, seq_len)`` tensor only in the two-separator case.

    NOTE(review): only the first sequence of the batch is inspected
    (``input_ids[0]``) — assumes batch size 1; confirm with callers.
    """
    encoded_dict = self.origin_tokenize(texts)
    input_ids = encoded_dict['input_ids'][0].tolist()
    last_index = len(input_ids) - 1

    # Positions of [SEP] tokens, excluding the final sentence-closing one.
    sep_positions = [i for i, token_id in enumerate(input_ids)
                     if token_id == sep_token_id and i != last_index]

    if len(sep_positions) == 2:
        first_sep, second_sep = sep_positions
        # Segment layout: turn 1 -> 0, turn 2 -> 1, turn 3 (+ padding) -> 0.
        token_type_id = [
            0 if i <= first_sep else (1 if i <= second_sep else 0)
            for i in range(len(input_ids))
        ]
        encoded_dict['token_type_ids'] = torch.unsqueeze(
            torch.tensor(token_type_id), 0)
    return encoded_dict

 

 

def custom_tokenizer(sent, MAX_LEN):
    """Encode *sent* and rebuild ``token_type_ids`` for a multi-turn input.

    Mirrors the multi-turn ``tokenize`` logic: when exactly two internal
    [SEP] tokens (id 3) appear before the final sentence-closing [SEP],
    segments are re-labelled 0 / 1 / 0; otherwise the encoding is returned
    as produced by the tokenizer.

    Args:
        sent: raw text to encode.
        MAX_LEN: maximum sequence length to pad/truncate to.

    Returns:
        The ``encode_plus`` dict, with ``token_type_ids`` replaced by a
        ``(1, seq_len)`` tensor in the two-separator case.

    NOTE(review): depends on a module-level ``tokenizer`` defined elsewhere,
    and on ``encoded_dict['input_ids'][0]`` supporting ``.tolist()`` —
    i.e. the tokenizer presumably returns tensors; confirm its config.
    """
    encoded_dict = tokenizer.encode_plus(
      text = sent,
      add_special_tokens = True, # adds CLS at the start and SEP at the end
      max_length = MAX_LEN,
      padding = 'max_length',    # fix: 'pad_to_max_length=True' is deprecated in transformers
      return_attention_mask = True,
      truncation=True
    )

    input_ids = encoded_dict['input_ids'][0].tolist()
    last_index = len(input_ids) - 1

    # Positions of [SEP] (id 3), excluding the final sentence-closing one.
    sep_positions = [i for i, token_id in enumerate(input_ids)
                     if token_id == 3 and i != last_index]

    if len(sep_positions) == 2:
        first_sep, second_sep = sep_positions
        # Segment layout: turn 1 -> 0, turn 2 -> 1, turn 3 (+ padding) -> 0.
        token_type_id = [
            0 if i <= first_sep else (1 if i <= second_sep else 0)
            for i in range(len(input_ids))
        ]
        encoded_dict['token_type_ids'] = torch.unsqueeze(
            torch.tensor(token_type_id), 0)
    return encoded_dict
반응형