🚀 文本分類ONNX模型
本項目提供了一個基於ONNX的文本分類模型,可用於多類別和多標籤的情感分類任務。該模型基於RoBERTa架構,在go_emotions數據集上進行訓練,具有高精度和快速推理的特點。
✨ 主要特性
- ONNX版本:提供全精度和INT8量化兩種ONNX版本,可根據需求選擇。
- 高精度:與原始Transformers模型具有相同的準確率和指標。
- 快速推理:在推理速度上比普通Transformers模型更快,尤其是對於小批量數據。
- 小模型尺寸:量化版本的模型尺寸僅為全精度模型的四分之一。
📦 安裝指南
本項目未提及具體安裝步驟,可根據使用的庫(如transformers
、optimum
、onnxruntime
等)進行安裝。例如,使用以下命令安裝所需庫:
pip install transformers optimum onnxruntime
💻 使用示例
基礎用法
使用Optimum庫的ONNX類進行文本分類:
sentences = ["ONNX is seriously fast for small batches. Impressive"]
from transformers import AutoTokenizer, pipeline
from optimum.onnxruntime import ORTModelForSequenceClassification
model_id = "SamLowe/roberta-base-go_emotions-onnx"
file_name = "onnx/model_quantized.onnx"
model = ORTModelForSequenceClassification.from_pretrained(model_id, file_name=file_name)
tokenizer = AutoTokenizer.from_pretrained(model_id)
onnx_classifier = pipeline(
task="text-classification",
model=model,
tokenizer=tokenizer,
top_k=None,
function_to_apply="sigmoid",
)
model_outputs = onnx_classifier(sentences)
print(model_outputs)
高級用法
使用ONNXRuntime進行文本分類:
from tokenizers import Tokenizer
import onnxruntime as ort
from os import cpu_count
import numpy as np
sentences = ["hello world"]
labels = ['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']
tokenizer = Tokenizer.from_pretrained("SamLowe/roberta-base-go_emotions")
params = {**tokenizer.padding, "length": None}
tokenizer.enable_padding(**params)
tokens_obj = tokenizer.encode_batch(sentences)
def load_onnx_model(model_filepath):
_options = ort.SessionOptions()
_options.inter_op_num_threads, _options.intra_op_num_threads = cpu_count(), cpu_count()
_providers = ["CPUExecutionProvider"]
return ort.InferenceSession(path_or_bytes=model_filepath, sess_options=_options, providers=_providers)
model = load_onnx_model("path_to_model_dot_onnx_or_model_quantized_dot_onnx")
output_names = [model.get_outputs()[0].name]
input_feed_dict = {
"input_ids": [t.ids for t in tokens_obj],
"attention_mask": [t.attention_mask for t in tokens_obj]
}
logits = model.run(output_names=output_names, input_feed=input_feed_dict)[0]
def sigmoid(x):
return 1.0 / (1.0 + np.exp(-x))
model_outputs = sigmoid(logits)
for probas in model_outputs:
top_result_index = np.argmax(probas)
print(labels[top_result_index], "with score:", probas[top_result_index])
📚 詳細文檔
模型版本
- 全精度ONNX版本:
onnx/model.onnx
,與原始Transformers模型具有相同的準確率和指標,模型大小為499MB,在推理速度上比普通Transformers模型更快,尤其是對於小批量數據。
- 量化(INT8)ONNX版本:
onnx/model_quantized.onnx
,模型尺寸為125MB,幾乎保留了全精度模型的所有準確率,推理速度比全精度ONNX版本和普通Transformers模型更快。
指標
使用固定閾值0.5將分數轉換為每個標籤的二進制預測:
模型版本 |
準確率 |
精確率 |
召回率 |
F1值 |
全精度ONNX版本 |
0.474 |
0.575 |
0.396 |
0.450 |
量化(INT8)ONNX版本 |
0.475 |
0.582 |
0.398 |
0.447 |
示例筆記本
後續將提供包含更多使用細節、準確率和性能展示的筆記本。
📄 許可證
本項目採用MIT許可證。