728x90
참고 : https://www.philschmid.de/tensorflow-sentence-transformers
HuggingFace 에는 Tensorflow 모델 형식인 h5 파일이 없는 상태
h5 모델이 없는 상태에서도 Tensorflow 모델로 불러올 수 있다.
클래스 구현
import tensorflow as tf
from typing import Union , List
from transformers import TFAutoModel
from transformers import AutoTokenizer
class TFSentenceTransformer(tf.keras.layers.Layer):
    """Keras layer wrapping a Hugging Face transformer to produce
    sentence embeddings, sentence-transformers style (masked mean pooling,
    optional L2 normalization).

    Loads the checkpoint with ``from_pt=True`` so a model that only ships
    PyTorch weights (no TF .h5 file) can still be used from TensorFlow.
    """

    def __init__(self, model_name_or_path):
        super().__init__()
        # Load the transformer, converting PyTorch weights to TF on the fly.
        self.model = TFAutoModel.from_pretrained(model_name_or_path, from_pt=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    def call(self, inputs, normalize=False):
        """Run the transformer on tokenized ``inputs`` and pool to one
        embedding per sequence. Set ``normalize=True`` for unit-length
        embeddings (cosine similarity == dot product)."""
        model_output = self.model(inputs)
        # Mean pooling over token embeddings, masked by attention_mask.
        embeddings = self.mean_pooling(model_output, inputs["attention_mask"])
        if normalize:
            embeddings = self.normalize(embeddings)
        return embeddings

    def encode(self, sentence: Union[str, List[str]], normalize=False):
        """Tokenize and embed a single sentence or a list of sentences.

        Returns a single embedding vector for a ``str`` input, or a batch
        of embeddings (one row per sentence) for a list input.
        """
        inputs = self.tokenizer(sentence, padding=True, truncation=True, return_tensors='tf')
        features = self.call(inputs, normalize=normalize)
        # isinstance (not `type(...) ==`) so str subclasses are handled too.
        if isinstance(sentence, str):
            return features[0]
        return features

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings, ignoring padding positions."""
        # First element of model_output contains all token embeddings.
        token_embeddings = model_output[0]
        # Expand the mask to the hidden dimension so padded tokens
        # contribute nothing to the sum.
        input_mask_expanded = tf.cast(
            tf.broadcast_to(tf.expand_dims(attention_mask, -1), tf.shape(token_embeddings)),
            tf.float32
        )
        # Masked mean; clip the denominator to avoid division by zero
        # for an all-padding row.
        return tf.math.reduce_sum(token_embeddings * input_mask_expanded, axis=1) / tf.clip_by_value(
            tf.math.reduce_sum(input_mask_expanded, axis=1), 1e-9, tf.float32.max
        )

    def normalize(self, embeddings):
        """L2-normalize each embedding row."""
        embeddings, _ = tf.linalg.normalize(embeddings, 2, axis=1)
        return embeddings
임베딩
# Instantiate the sentence encoder from the Hub and print the first
# five dimensions of an example embedding ("hello" in Korean).
model_id = 'j5ng/sentence-klue-roberta-base'
model = TFSentenceTransformer(model_id)
embedding = model.encode("안녕하세요")
print(embedding[:5])
tf.Tensor([-0.2527147 -0.05963629 0.16842306 0.15223232 0.5281706 ], shape=(5,), dtype=float32)
반응형
'Machine learning > NLP' 카테고리의 다른 글
[NLP]. SentenceTransformer Tokenize 멀티턴 형식으로 수정하기 (0) | 2022.12.22 |
---|---|
[NLP]. 텍스트 데이터 정제(이모지 , 특수문자, url , 한자 제거) (0) | 2022.12.21 |
[NLP]. Sentence-Transformer 모델 onnx 형식으로 변환하기 (0) | 2022.12.12 |
[NLP]. 오타 생성기 구현하기 : Text Noise Augmentation (1) | 2022.10.29 |
[NLP]. 챗봇 답변 Top-k sampling 구현 (0) | 2022.09.27 |