Machine learning/NLP
[NLP]. SentenceTransformer 모델 TensorFlow로 불러오기
Acdong
2022. 12. 12. 23:32
728x90
참고 : https://www.philschmid.de/tensorflow-sentence-transformers
HuggingFace에는 Tensorflow 모델 형식인 h5 파일이 없는 상태
h5 모델이 없는 상태에서도 Tensorflow 모델로 불러올 수 있다.
클래스 구현
import tensorflow as tf
from typing import Union , List
from transformers import TFAutoModel
from transformers import AutoTokenizer
class TFSentenceTransformer(tf.keras.layers.Layer):
    """Keras layer that wraps a Hugging Face transformer for sentence embeddings.

    Loads the checkpoint with ``from_pt=True`` so a model that only ships
    PyTorch weights (no TensorFlow h5 file) can still be used from
    TensorFlow, then mean-pools token embeddings into one fixed-size
    vector per input sentence.
    """

    def __init__(self, model_name_or_path: str):
        # Plain super() — the explicit (cls, self) form is Python 2 style.
        super().__init__()
        # from_pt=True converts the PyTorch weights on the fly, so no
        # TensorFlow (h5) checkpoint needs to exist on the Hub.
        self.model = TFAutoModel.from_pretrained(model_name_or_path, from_pt=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    def call(self, inputs, normalize: bool = False):
        """Run the transformer on tokenized *inputs* and mean-pool the output.

        Args:
            inputs: tokenizer output dict (must contain "attention_mask").
            normalize: if True, L2-normalize the pooled embeddings.

        Returns:
            A (batch, hidden) tensor of sentence embeddings.
        """
        model_output = self.model(inputs)
        # Mean pooling over tokens, weighted by the attention mask.
        embeddings = self.mean_pooling(model_output, inputs["attention_mask"])
        if normalize:
            embeddings = self.normalize(embeddings)
        return embeddings

    def encode(self, sentence: Union[str, List], normalize: bool = False):
        """Tokenize *sentence* (a str or list of str) and return embeddings.

        For a single string the leading batch dimension is dropped so the
        caller gets a 1-D embedding vector.
        """
        inputs = self.tokenizer(sentence, padding=True, truncation=True, return_tensors='tf')
        features = self.call(inputs, normalize=normalize)
        # isinstance instead of `type(sentence) == str` (PEP 8; also
        # accepts str subclasses).
        if isinstance(sentence, str):
            return features[0]
        return features

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings, ignoring padding via the attention mask."""
        token_embeddings = model_output[0]  # first element holds all token embeddings
        input_mask_expanded = tf.cast(
            tf.broadcast_to(tf.expand_dims(attention_mask, -1), tf.shape(token_embeddings)),
            tf.float32,
        )
        # Clip the denominator to avoid division by zero for all-padding rows.
        return tf.math.reduce_sum(token_embeddings * input_mask_expanded, axis=1) / tf.clip_by_value(
            tf.math.reduce_sum(input_mask_expanded, axis=1), 1e-9, tf.float32.max
        )

    def normalize(self, embeddings):
        """L2-normalize each embedding row."""
        embeddings, _ = tf.linalg.normalize(embeddings, 2, axis=1)
        return embeddings
임베딩
# Load a Korean (KLUE) RoBERTa sentence-embedding checkpoint from the Hub
# and print the first five dimensions of the embedding for a sample sentence.
model_id = 'j5ng/sentence-klue-roberta-base'
model = TFSentenceTransformer(model_id)
print(model.encode("안녕하세요")[:5])
tf.Tensor([-0.2527147 -0.05963629 0.16842306 0.15223232 0.5281706 ], shape=(5,), dtype=float32)
반응형