BackGround
우리 회사는 SentenceTransformer를 기반으로 파인튜닝한 문장 임베딩 모델을 사용하고 있다.
하지만 모델의 크기가 커질 수록 임베딩 시간은 늘어나고 많은 연산량을 요구하게된다.
그렇다고 모델의 크기를 줄이면 정확도가 떨어진다.
하지만 ONNX는 Inferance 속도를 최대한으로 높히면서 정확도 손실을 최소화하는 여러 가지 기능을 가지고있다.
Sentence-Transfomer 모델을 ONNX Runtime으로 변환하면서 얻었던 장점들을 정리해보고자 한다.
What is ONNX ?
ONNX는 Open Neural Network Exchange의 줄인 말로서
이름과 같이 다른 DNN 프레임워크 환경(ex Tensorflow, PyTorch, etc..)에서 만들어진
모델들을 서로 호환되게 사용할 수 있도록 만들어진 공유 플랫폼이다.
ps. ONNX 또한 DNN 프레임워크라고 부른다.
ONNX의 장점으로는 크게 두 가지가 있다.
1.상호 운용성
: onnx는 여러가지 딥러닝 프레임워크를 변환하여 추론 엔진으로 사용할 수있다.
2. 하드웨어 엑세스
: onnx runtime을 사용하면 하드웨어 최적화에 더 쉽게 접근이 가능하다.
모델링은 익숙한 pytorch , keras로 하고 서빙에 최적화된 ONNX로 변환하여 서빙
How to use Onnx Runtime?
Environment
Python3.9
Ubuntu20.04
1. Pytorch(Sbert)모델을 ONNX 모델로 Export하기
from pathlib import Path
import transformers
from transformers.convert_graph_to_onnx import convert
convert(framework="pt", model="reppley/sentence-roberta-base",
output=Path("onnx_models/sentence-roberta-base.onnx"), opset=11)
* 여기서 "reppley/sentence-roberta-base"는 SentenceTransformer 모델이다.
실제로는 모델의 경로를 넣어주면되고 위 코드처럼 텍스트만 넣어주면 HuggingFace에서 모델을 찾아 로드해준다.
2. 양자화(Quantization) - 옵션
from onnxruntime.quantization import quantize_dynamic, QuantType
quantize_dynamic("onnx_models/sentence-roberta-base.onnx", "onnx_models/sentence-roberta-base_uint8.onnx",
weight_type=QuantType.QUInt8)
ONNX는 모델의 가중치를 Uint8 형식으로 양자화 할수있다. 양자화를 진행하면 모델 크기를 1/4로 줄일 수 있고
추론 속도도 훨씬 빨라진다. ( But 정확도 손실이 있다.)
*Weight 값을 FP32 에서 UINT8(0~255)로 맵핑하는 기법
3. 추론하기
3-1. 모델 불러오기
from transformers import AutoTokenizer
from onnxruntime import InferenceSession
from sentence_transformers import SentenceTransformer
import torch
model = SentenceTransformer('reppley/sentence-roberta-base')
sess = InferenceSession("onnx_models/sentence-roberta-base.onnx",
providers=["CPUExecutionProvider"])
sess_uint8 = InferenceSession("onnx_models/sentence-roberta-base_uint8.onnx",
providers=["CPUExecutionProvider"])
비교를 위해 기존 Sbert 모델, ONNX모델 , ONNX 양자화 모델 3개를 불러왔다.
3-2. 풀링 함수
def mean_pooling(model_output, attention_mask):
model_output = torch.from_numpy(model_output[0])
# First element of model_output contains all token embeddings
token_embeddings = model_output
attention_mask = torch.from_numpy(attention_mask)
input_mask_expanded = attention_mask.unsqueeze(
-1).expand(token_embeddings.size())
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
return sum_embeddings / sum_mask, input_mask_expanded, sum_mask
SentenceTransfomer 모델은 자동으로 풀링을해서 임베딩해주지만 ONNX 모델은 BERT모델과 비슷하기때문에
둘의 임베딩 값을 맞추려면 ONNX 모델 결과의 풀링을 거쳐야한다.
Model Test
기존 SentenceTransformer 모델
%%timeit
query = "안녕하세요"
model.encode(query, convert_to_tensor=True, device='cpu')[:5]
12.6 ms ± 476 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
tensor([ 0.2339, -0.1364, 0.6490, -0.2782, -0.1374])
SentenceTransformer - 임베딩 시간 평균(100개) : 0.12 초
ONNX 변환 모델
%%timeit
query = "안녕하세요"
model_inputs = tokenizer(query, return_tensors="pt")
inputs_onnx = {k: v.cpu().detach().numpy() for k, v in model_inputs.items()}
mean_pooling(sess.run(None, inputs_onnx),inputs_onnx['attention_mask'])[0][0][:5]
7.97 ms ± 84.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
tensor([ 0.2339, -0.1364, 0.6490, -0.2782, -0.1374])
ONNX - 임베딩 시간 평균(100걔) : 0.08초
%%timeit
query = "안녕하세요"
model_inputs = tokenizer(query, return_tensors="pt")
inputs_onnx = {k: v.cpu().detach().numpy() for k, v in model_inputs.items()}
mean_pooling(sess_uint8.run(None, inputs_onnx),inputs_onnx['attention_mask'])[0][0][:5]
2.66 ms ± 128 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
tensor([ 0.2372, -0.1313, 0.6409, -0.2680, -0.1305])
ONNX Uint8 - 임베딩 시간 평균(100개) : 0.02초
Conclusion
ONNX Runtime을 사용해서 추론 속도를 엄청나게(최대 6배) 감소시켰다. 하지만 양자화를 사용할 경우 결과값이 달라지는 것을 볼수있다. 정확도 감소가 조금 있을 것으로 예상되지만 그래도 속도 차이가 많이 나서 양자화 모델을 최종 서빙모델로 선정했다. 속도가 크게 중요하지 않은 Task에서는 ONNX로 모델만 변환해서 사용해도 좋을 것이다.
Reference
https://learn.microsoft.com/ko-kr/windows/ai/windows-ml/tutorials/pytorch-convert-model
https://beeny-ds.tistory.com/22
https://onnx.ai/
https://www.youtube.com/watch?v=MCafgeqWMhQ