模型简介
模型特点
模型能力
使用案例
🚀 RoBERTa-base情感多标签分类模型
本项目基于roberta-base模型,在go_emotions数据集上进行训练,实现多标签文本分类,可对文本中的多种情感进行识别和分类。
✨ 主要特性
- 多标签分类:能够对输入文本同时预测多个情感标签。
- ONNX版本支持:提供ONNX格式的模型,包括INT8量化版本,推理速度更快,依赖更小,跨平台性更好。
- 多方式推理:支持在Huggingface Transformers中通过多种方式使用,如使用pipeline进行推理。
📚 详细文档
模型概述
该模型基于roberta-base,在go_emotions数据集上进行多标签分类训练。
ONNX版本
该模型的ONNX格式版本(包括INT8量化的ONNX版本)可在https://huggingface.co/SamLowe/roberta-base-go_emotions-onnx获取。这些版本在推理时更快,尤其是对于小批量数据,能大幅减少推理所需依赖的大小,使模型推理更具跨平台性。对于量化版本,如果仅需要推理,在保留几乎所有准确性的同时,可将模型文件/下载大小减少75%。
模型使用的数据集
go_emotions基于Reddit数据,有28个标签。这是一个多标签数据集,对于任何给定的输入文本,可能适用一个或多个标签。因此,该模型是一个多标签分类模型,对于任何给定的输入文本,会输出28个“概率”浮点数。通常,会对每个标签的概率应用0.5的阈值进行预测。
模型创建方式
模型使用AutoModelForSequenceClassification.from_pretrained
,设置problem_type="multi_label_classification"
进行训练,训练3个epoch,学习率为2e - 5,权重衰减为0.01。
推理
在Huggingface Transformers中有多种使用该模型的方法,最简单的可能是使用pipeline:
from transformers import pipeline
classifier = pipeline(task="text-classification", model="SamLowe/roberta-base-go_emotions", top_k=None)
sentences = ["I am not having a great day"]
model_outputs = classifier(sentences)
print(model_outputs[0])
# 为每个标签生成一个字典列表
评估 / 指标
模型的评估信息可在以下链接获取:
- https://github.com/samlowe/go_emotions-dataset/blob/main/eval-roberta-base-go_emotions.ipynb
总结
如上述笔记本中所示,使用数据集的测试集对多标签输出(通过0.5的阈值将每个输出二值化)进行评估,结果如下:
- 准确率:0.474
- 精确率:0.575
- 召回率:0.396
- F1值:0.450
但考虑到多标签的性质(每个标签实际上是一个独立的二分类)以及数据集中标签的表示差异很大,按每个标签进行指标测量更有意义。
对模型输出应用0.5的阈值进行二值化后,按每个标签的指标如下:
情感标签 | 准确率 | 精确率 | 召回率 | F1值 | MCC值 | 样本数 | 阈值 |
---|---|---|---|---|---|---|---|
admiration | 0.946 | 0.725 | 0.675 | 0.699 | 0.670 | 504 | 0.5 |
amusement | 0.982 | 0.790 | 0.871 | 0.829 | 0.821 | 264 | 0.5 |
anger | 0.970 | 0.652 | 0.379 | 0.479 | 0.483 | 198 | 0.5 |
annoyance | 0.940 | 0.472 | 0.159 | 0.238 | 0.250 | 320 | 0.5 |
approval | 0.942 | 0.609 | 0.302 | 0.404 | 0.403 | 351 | 0.5 |
caring | 0.973 | 0.448 | 0.319 | 0.372 | 0.364 | 135 | 0.5 |
confusion | 0.972 | 0.500 | 0.431 | 0.463 | 0.450 | 153 | 0.5 |
curiosity | 0.950 | 0.537 | 0.356 | 0.428 | 0.412 | 284 | 0.5 |
desire | 0.987 | 0.630 | 0.410 | 0.496 | 0.502 | 83 | 0.5 |
disappointment | 0.974 | 0.625 | 0.199 | 0.302 | 0.343 | 151 | 0.5 |
disapproval | 0.950 | 0.494 | 0.307 | 0.379 | 0.365 | 267 | 0.5 |
disgust | 0.982 | 0.707 | 0.333 | 0.453 | 0.478 | 123 | 0.5 |
embarrassment | 0.994 | 0.750 | 0.243 | 0.367 | 0.425 | 37 | 0.5 |
excitement | 0.983 | 0.603 | 0.340 | 0.435 | 0.445 | 103 | 0.5 |
fear | 0.992 | 0.758 | 0.603 | 0.671 | 0.672 | 78 | 0.5 |
gratitude | 0.990 | 0.960 | 0.881 | 0.919 | 0.914 | 352 | 0.5 |
grief | 0.999 | 0.000 | 0.000 | 0.000 | 0.000 | 6 | 0.5 |
joy | 0.978 | 0.647 | 0.559 | 0.600 | 0.590 | 161 | 0.5 |
love | 0.982 | 0.773 | 0.832 | 0.802 | 0.793 | 238 | 0.5 |
nervousness | 0.996 | 0.600 | 0.130 | 0.214 | 0.278 | 23 | 0.5 |
optimism | 0.972 | 0.667 | 0.376 | 0.481 | 0.488 | 186 | 0.5 |
pride | 0.997 | 0.000 | 0.000 | 0.000 | 0.000 | 16 | 0.5 |
realization | 0.974 | 0.541 | 0.138 | 0.220 | 0.264 | 145 | 0.5 |
relief | 0.998 | 0.000 | 0.000 | 0.000 | 0.000 | 11 | 0.5 |
remorse | 0.991 | 0.553 | 0.750 | 0.636 | 0.640 | 56 | 0.5 |
sadness | 0.977 | 0.621 | 0.494 | 0.550 | 0.542 | 156 | 0.5 |
surprise | 0.981 | 0.750 | 0.404 | 0.525 | 0.542 | 141 | 0.5 |
neutral | 0.782 | 0.694 | 0.604 | 0.646 | 0.492 | 1787 | 0.5 |
为每个标签优化阈值以获得最佳F1指标,会得到稍好的指标 - 牺牲一些精确率以获得更高的召回率,从而有利于F1值(上述笔记本中展示了具体做法):
情感标签 | 准确率 | 精确率 | 召回率 | F1值 | MCC值 | 样本数 | 阈值 |
---|---|---|---|---|---|---|---|
admiration | 0.940 | 0.651 | 0.776 | 0.708 | 0.678 | 504 | 0.25 |
amusement | 0.982 | 0.781 | 0.890 | 0.832 | 0.825 | 264 | 0.45 |
anger | 0.959 | 0.454 | 0.601 | 0.517 | 0.502 | 198 | 0.15 |
annoyance | 0.864 | 0.243 | 0.619 | 0.349 | 0.328 | 320 | 0.10 |
approval | 0.926 | 0.432 | 0.442 | 0.437 | 0.397 | 351 | 0.30 |
caring | 0.972 | 0.426 | 0.385 | 0.405 | 0.391 | 135 | 0.40 |
confusion | 0.974 | 0.548 | 0.412 | 0.470 | 0.462 | 153 | 0.55 |
curiosity | 0.943 | 0.473 | 0.711 | 0.568 | 0.552 | 284 | 0.25 |
desire | 0.985 | 0.518 | 0.530 | 0.524 | 0.516 | 83 | 0.25 |
disappointment | 0.974 | 0.562 | 0.298 | 0.390 | 0.398 | 151 | 0.40 |
disapproval | 0.941 | 0.414 | 0.468 | 0.439 | 0.409 | 267 | 0.30 |
disgust | 0.978 | 0.523 | 0.463 | 0.491 | 0.481 | 123 | 0.20 |
embarrassment | 0.994 | 0.567 | 0.459 | 0.507 | 0.507 | 37 | 0.10 |
excitement | 0.981 | 0.500 | 0.417 | 0.455 | 0.447 | 103 | 0.35 |
fear | 0.991 | 0.712 | 0.667 | 0.689 | 0.685 | 78 | 0.40 |
gratitude | 0.990 | 0.957 | 0.889 | 0.922 | 0.917 | 352 | 0.45 |
grief | 0.999 | 0.333 | 0.333 | 0.333 | 0.333 | 6 | 0.05 |
joy | 0.978 | 0.623 | 0.646 | 0.634 | 0.623 | 161 | 0.40 |
love | 0.982 | 0.740 | 0.899 | 0.812 | 0.807 | 238 | 0.25 |
nervousness | 0.996 | 0.571 | 0.348 | 0.432 | 0.444 | 23 | 0.25 |
optimism | 0.971 | 0.580 | 0.565 | 0.572 | 0.557 | 186 | 0.20 |
pride | 0.998 | 0.875 | 0.438 | 0.583 | 0.618 | 16 | 0.10 |
realization | 0.961 | 0.270 | 0.262 | 0.266 | 0.246 | 145 | 0.15 |
relief | 0.992 | 0.152 | 0.636 | 0.246 | 0.309 | 11 | 0.05 |
remorse | 0.991 | 0.541 | 0.946 | 0.688 | 0.712 | 56 | 0.10 |
sadness | 0.977 | 0.599 | 0.583 | 0.591 | 0.579 | 156 | 0.40 |
surprise | 0.977 | 0.543 | 0.674 | 0.601 | 0.593 | 141 | 0.15 |
neutral | 0.758 | 0.598 | 0.810 | 0.688 | 0.513 | 1787 | 0.25 |
这会提高整体指标:
- 精确率:0.542
- 召回率:0.577
- F1值:0.541
或者如果按每个标签的样本数相对大小进行加权计算:
- 精确率:0.572
- 召回率:0.677
- F1值:0.611
数据集评论
一些标签(如感激)在单独考虑时表现非常好,F1值超过0.9,而其他标签(如解脱)表现非常差。
这是一个具有挑战性的数据集。像解脱这样的标签在训练数据中的示例要少得多(在40k +的数据中少于100个,在测试集中只有11个)。
但在go_emotions的训练数据中也可以看到一些模糊性和/或标注错误,这可能会限制模型的性能。对数据集进行数据清理,减少标注中的一些错误、模糊性、冲突和重复,将产生性能更高的模型。
📄 许可证
本项目采用MIT许可证。








