模型概述
模型特點
模型能力
使用案例
🚀 RoBERTa-base情感多標籤分類模型
本項目基於roberta-base模型,在go_emotions數據集上進行訓練,實現多標籤文本分類,可對文本中的多種情感進行識別和分類。
✨ 主要特性
- 多標籤分類:能夠對輸入文本同時預測多個情感標籤。
- ONNX版本支持:提供ONNX格式的模型,包括INT8量化版本,推理速度更快,依賴更小,跨平臺性更好。
- 多方式推理:支持在Huggingface Transformers中通過多種方式使用,如使用pipeline進行推理。
📚 詳細文檔
模型概述
該模型基於roberta-base,在go_emotions數據集上進行多標籤分類訓練。
ONNX版本
該模型的ONNX格式版本(包括INT8量化的ONNX版本)可在https://huggingface.co/SamLowe/roberta-base-go_emotions-onnx獲取。這些版本在推理時更快,尤其是對於小批量數據,能大幅減少推理所需依賴的大小,使模型推理更具跨平臺性。對於量化版本,如果僅需要推理,在保留幾乎所有準確性的同時,可將模型文件/下載大小減少75%。
模型使用的數據集
go_emotions基於Reddit數據,有28個標籤。這是一個多標籤數據集,對於任何給定的輸入文本,可能適用一個或多個標籤。因此,該模型是一個多標籤分類模型,對於任何給定的輸入文本,會輸出28個“概率”浮點數。通常,會對每個標籤的概率應用0.5的閾值進行預測。
模型創建方式
模型使用AutoModelForSequenceClassification.from_pretrained
,設置problem_type="multi_label_classification"
進行訓練,訓練3個epoch,學習率為2e - 5,權重衰減為0.01。
推理
在Huggingface Transformers中有多種使用該模型的方法,最簡單的可能是使用pipeline:
from transformers import pipeline
classifier = pipeline(task="text-classification", model="SamLowe/roberta-base-go_emotions", top_k=None)
sentences = ["I am not having a great day"]
model_outputs = classifier(sentences)
print(model_outputs[0])
# 為每個標籤生成一個字典列表
評估 / 指標
模型的評估信息可在以下鏈接獲取:
- https://github.com/samlowe/go_emotions-dataset/blob/main/eval-roberta-base-go_emotions.ipynb
總結
如上述筆記本中所示,使用數據集的測試集對多標籤輸出(通過0.5的閾值將每個輸出二值化)進行評估,結果如下:
- 準確率:0.474
- 精確率:0.575
- 召回率:0.396
- F1值:0.450
但考慮到多標籤的性質(每個標籤實際上是一個獨立的二分類)以及數據集中標籤的表示差異很大,按每個標籤進行指標測量更有意義。
對模型輸出應用0.5的閾值進行二值化後,按每個標籤的指標如下:
情感標籤 | 準確率 | 精確率 | 召回率 | F1值 | MCC值 | 樣本數 | 閾值 |
---|---|---|---|---|---|---|---|
admiration | 0.946 | 0.725 | 0.675 | 0.699 | 0.670 | 504 | 0.5 |
amusement | 0.982 | 0.790 | 0.871 | 0.829 | 0.821 | 264 | 0.5 |
anger | 0.970 | 0.652 | 0.379 | 0.479 | 0.483 | 198 | 0.5 |
annoyance | 0.940 | 0.472 | 0.159 | 0.238 | 0.250 | 320 | 0.5 |
approval | 0.942 | 0.609 | 0.302 | 0.404 | 0.403 | 351 | 0.5 |
caring | 0.973 | 0.448 | 0.319 | 0.372 | 0.364 | 135 | 0.5 |
confusion | 0.972 | 0.500 | 0.431 | 0.463 | 0.450 | 153 | 0.5 |
curiosity | 0.950 | 0.537 | 0.356 | 0.428 | 0.412 | 284 | 0.5 |
desire | 0.987 | 0.630 | 0.410 | 0.496 | 0.502 | 83 | 0.5 |
disappointment | 0.974 | 0.625 | 0.199 | 0.302 | 0.343 | 151 | 0.5 |
disapproval | 0.950 | 0.494 | 0.307 | 0.379 | 0.365 | 267 | 0.5 |
disgust | 0.982 | 0.707 | 0.333 | 0.453 | 0.478 | 123 | 0.5 |
embarrassment | 0.994 | 0.750 | 0.243 | 0.367 | 0.425 | 37 | 0.5 |
excitement | 0.983 | 0.603 | 0.340 | 0.435 | 0.445 | 103 | 0.5 |
fear | 0.992 | 0.758 | 0.603 | 0.671 | 0.672 | 78 | 0.5 |
gratitude | 0.990 | 0.960 | 0.881 | 0.919 | 0.914 | 352 | 0.5 |
grief | 0.999 | 0.000 | 0.000 | 0.000 | 0.000 | 6 | 0.5 |
joy | 0.978 | 0.647 | 0.559 | 0.600 | 0.590 | 161 | 0.5 |
love | 0.982 | 0.773 | 0.832 | 0.802 | 0.793 | 238 | 0.5 |
nervousness | 0.996 | 0.600 | 0.130 | 0.214 | 0.278 | 23 | 0.5 |
optimism | 0.972 | 0.667 | 0.376 | 0.481 | 0.488 | 186 | 0.5 |
pride | 0.997 | 0.000 | 0.000 | 0.000 | 0.000 | 16 | 0.5 |
realization | 0.974 | 0.541 | 0.138 | 0.220 | 0.264 | 145 | 0.5 |
relief | 0.998 | 0.000 | 0.000 | 0.000 | 0.000 | 11 | 0.5 |
remorse | 0.991 | 0.553 | 0.750 | 0.636 | 0.640 | 56 | 0.5 |
sadness | 0.977 | 0.621 | 0.494 | 0.550 | 0.542 | 156 | 0.5 |
surprise | 0.981 | 0.750 | 0.404 | 0.525 | 0.542 | 141 | 0.5 |
neutral | 0.782 | 0.694 | 0.604 | 0.646 | 0.492 | 1787 | 0.5 |
為每個標籤優化閾值以獲得最佳F1指標,會得到稍好的指標 - 犧牲一些精確率以獲得更高的召回率,從而有利於F1值(上述筆記本中展示了具體做法):
情感標籤 | 準確率 | 精確率 | 召回率 | F1值 | MCC值 | 樣本數 | 閾值 |
---|---|---|---|---|---|---|---|
admiration | 0.940 | 0.651 | 0.776 | 0.708 | 0.678 | 504 | 0.25 |
amusement | 0.982 | 0.781 | 0.890 | 0.832 | 0.825 | 264 | 0.45 |
anger | 0.959 | 0.454 | 0.601 | 0.517 | 0.502 | 198 | 0.15 |
annoyance | 0.864 | 0.243 | 0.619 | 0.349 | 0.328 | 320 | 0.10 |
approval | 0.926 | 0.432 | 0.442 | 0.437 | 0.397 | 351 | 0.30 |
caring | 0.972 | 0.426 | 0.385 | 0.405 | 0.391 | 135 | 0.40 |
confusion | 0.974 | 0.548 | 0.412 | 0.470 | 0.462 | 153 | 0.55 |
curiosity | 0.943 | 0.473 | 0.711 | 0.568 | 0.552 | 284 | 0.25 |
desire | 0.985 | 0.518 | 0.530 | 0.524 | 0.516 | 83 | 0.25 |
disappointment | 0.974 | 0.562 | 0.298 | 0.390 | 0.398 | 151 | 0.40 |
disapproval | 0.941 | 0.414 | 0.468 | 0.439 | 0.409 | 267 | 0.30 |
disgust | 0.978 | 0.523 | 0.463 | 0.491 | 0.481 | 123 | 0.20 |
embarrassment | 0.994 | 0.567 | 0.459 | 0.507 | 0.507 | 37 | 0.10 |
excitement | 0.981 | 0.500 | 0.417 | 0.455 | 0.447 | 103 | 0.35 |
fear | 0.991 | 0.712 | 0.667 | 0.689 | 0.685 | 78 | 0.40 |
gratitude | 0.990 | 0.957 | 0.889 | 0.922 | 0.917 | 352 | 0.45 |
grief | 0.999 | 0.333 | 0.333 | 0.333 | 0.333 | 6 | 0.05 |
joy | 0.978 | 0.623 | 0.646 | 0.634 | 0.623 | 161 | 0.40 |
love | 0.982 | 0.740 | 0.899 | 0.812 | 0.807 | 238 | 0.25 |
nervousness | 0.996 | 0.571 | 0.348 | 0.432 | 0.444 | 23 | 0.25 |
optimism | 0.971 | 0.580 | 0.565 | 0.572 | 0.557 | 186 | 0.20 |
pride | 0.998 | 0.875 | 0.438 | 0.583 | 0.618 | 16 | 0.10 |
realization | 0.961 | 0.270 | 0.262 | 0.266 | 0.246 | 145 | 0.15 |
relief | 0.992 | 0.152 | 0.636 | 0.246 | 0.309 | 11 | 0.05 |
remorse | 0.991 | 0.541 | 0.946 | 0.688 | 0.712 | 56 | 0.10 |
sadness | 0.977 | 0.599 | 0.583 | 0.591 | 0.579 | 156 | 0.40 |
surprise | 0.977 | 0.543 | 0.674 | 0.601 | 0.593 | 141 | 0.15 |
neutral | 0.758 | 0.598 | 0.810 | 0.688 | 0.513 | 1787 | 0.25 |
這會提高整體指標:
- 精確率:0.542
- 召回率:0.577
- F1值:0.541
或者如果按每個標籤的樣本數相對大小進行加權計算:
- 精確率:0.572
- 召回率:0.677
- F1值:0.611
數據集評論
一些標籤(如感激)在單獨考慮時表現非常好,F1值超過0.9,而其他標籤(如解脫)表現非常差。
這是一個具有挑戰性的數據集。像解脫這樣的標籤在訓練數據中的示例要少得多(在40k +的數據中少於100個,在測試集中只有11個)。
但在go_emotions的訓練數據中也可以看到一些模糊性和/或標註錯誤,這可能會限制模型的性能。對數據集進行數據清理,減少標註中的一些錯誤、模糊性、衝突和重複,將產生性能更高的模型。
📄 許可證
本項目採用MIT許可證。








