🚀 ModernBERT-large Emotion Classification Model
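This is a multi-label emotion classification model for the transformers library, fine-tuned from ModernBERT-large on the go_emotions dataset. It can extract all emotions present in an English text or detect specific emotions.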
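In a multi-label setup each label is scored independently. The model's 28 logits go through a sigmoid, as the usage examples do with `torch.sigmoid`, rather than a softmax. As a result several emotions can be likely at once, and the probabilities need not sum to 1. The minimal NumPy illustration below uses made-up logits, not real model output, to show the difference.

```python
import numpy as np

def sigmoid(logits: np.ndarray) -> np.ndarray:
    """Independent per-label probabilities (multi-label setting)."""
    return 1.0 / (1.0 + np.exp(-logits))

def softmax(logits: np.ndarray) -> np.ndarray:
    """Mutually exclusive class probabilities (single-label setting), shown for contrast."""
    shifted = logits - logits.max()  # subtract the max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum()

# Hypothetical logits for three labels, e.g. admiration, joy, neutral
logits = np.array([2.0, 1.5, -3.0])

multi = sigmoid(logits)   # several labels can be "on" at the same time
single = softmax(logits)  # probabilities compete and always sum to 1
print(multi.round(3), multi.sum().round(3))
print(single.round(3), single.sum().round(3))
```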
🚀 Quick Start
The model runs on the Hugging Face Transformers library and is straightforward to use. The ModernBERT architecture is only supported in transformers 4.48.0 and later, so install a matching version:
pip install "transformers>=4.48.0"
✨ Key Features
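- Multi-label classification: identifies several emotions in the same English text at once.
- Threshold tuning: a separate decision threshold for each label was selected on the validation set by maximizing the macro-averaged F1 score over all labels.
- Faster inference: supports Flash Attention 2 to speed up inference.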
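The card does not include the tuning code. The sketch below shows one way per-label threshold tuning could work, assuming NumPy matrices of validation-set gold labels and sigmoid outputs. Macro F1 is the unweighted mean of per-label F1 scores, and each label's F1 depends only on that label's threshold. Choosing each threshold independently therefore also maximizes macro F1. The default grid `np.linspace(0, 1, 50)` is an assumption. Every value in `best_thresholds` is a multiple of 1/49, which is consistent with such a grid, but the card does not confirm it. `f1_binary` and `tune_thresholds` are hypothetical helper names.

```python
import numpy as np

def f1_binary(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """F1 for a single label; returns 0.0 when there are no true or predicted positives."""
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    denom = 2 * tp + fp + fn
    return float(2 * tp / denom) if denom > 0 else 0.0

def tune_thresholds(y_true: np.ndarray, probas: np.ndarray, grid=None) -> list:
    """Pick, for each label independently, the grid threshold that maximizes that label's F1.

    y_true: (n_samples, n_labels) binary matrix of gold labels (validation set)
    probas: (n_samples, n_labels) sigmoid outputs of the model
    """
    if grid is None:
        grid = np.linspace(0.0, 1.0, 50)  # assumed search grid
    thresholds = []
    for j in range(y_true.shape[1]):
        # Strict ">" matches the comparison used in detect_emotions
        scores = [f1_binary(y_true[:, j], (probas[:, j] > t).astype(int)) for t in grid]
        # argmax returns the first (smallest) threshold among ties
        thresholds.append(float(grid[int(np.argmax(scores))]))
    return thresholds
```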
📦 Installation
Make sure transformers 4.48.0 or later is installed:
pip install "transformers>=4.48.0"
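A small check like the following can confirm the installed version before you load the model. It is a simplified sketch that uses only the standard library. `parse_release` and `supports_modernbert` are hypothetical helpers that keep only the leading numeric release segments, so a pre-release such as `4.48.0.dev0` is treated as 4.48.0. For full PEP 440 semantics, `packaging.version.Version` is the more robust choice. Comparing integer tuples avoids the string-comparison pitfall where "4.9" would sort after "4.48".

```python
from importlib.metadata import version, PackageNotFoundError

MIN_VERSION = (4, 48, 0)  # ModernBERT support starts at transformers 4.48.0

def parse_release(v: str) -> tuple:
    """Keep only the leading numeric release segments, e.g. '4.48.0.dev0' -> (4, 48, 0)."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:          # a purely non-numeric segment such as 'dev0' ends the release
            break
        parts.append(int(digits))
        if digits != piece:     # e.g. '0rc1': keep the numeric prefix, then stop
            break
    return tuple(parts)

def supports_modernbert(v: str) -> bool:
    release = parse_release(v)
    release = release + (0,) * (3 - len(release))  # pad so '5.0' compares as (5, 0, 0)
    return release[:3] >= MIN_VERSION

if __name__ == "__main__":
    try:
        installed = version("transformers")
        print(installed, "OK" if supports_modernbert(installed) else "too old, please upgrade")
    except PackageNotFoundError:
        print("transformers is not installed")
```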
💻 Usage Examples
Basic Usage
The following example extracts the emotions present in a text:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained('fyaronskiy/ModernBERT-large-go-emotions')
model = AutoModelForSequenceClassification.from_pretrained('fyaronskiy/ModernBERT-large-go-emotions')

# Per-label decision thresholds tuned on the validation set, in the same order as LABELS
best_thresholds = [0.5510204081632653, 0.26530612244897955, 0.14285714285714285, 0.12244897959183673, 0.44897959183673464, 0.22448979591836732, 0.2040816326530612, 0.4081632653061224, 0.5306122448979591, 0.22448979591836732, 0.2857142857142857, 0.3061224489795918, 0.2040816326530612, 0.14285714285714285, 0.1020408163265306, 0.4693877551020408, 0.24489795918367346, 0.3061224489795918, 0.2040816326530612, 0.36734693877551017, 0.2857142857142857, 0.04081632653061224, 0.3061224489795918, 0.16326530612244897, 0.26530612244897955, 0.32653061224489793, 0.12244897959183673, 0.2040816326530612]
LABELS = ['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']
ID2LABEL = dict(enumerate(LABELS))

def detect_emotions(text):
    # Tokenize a single text (truncated to 128 tokens) into PyTorch tensors
    inputs = tokenizer(text, truncation=True, add_special_tokens=True, max_length=128, return_tensors='pt')
    with torch.no_grad():
        logits = model(**inputs).logits
    # Multi-label: an independent sigmoid probability per label, shape (num_labels,)
    probas = torch.sigmoid(logits).squeeze(dim=0)
    # A label is predicted when its probability exceeds that label's own threshold
    class_binary_labels = (probas > torch.tensor(best_thresholds)).int()
    return [ID2LABEL[label_id] for label_id, value in enumerate(class_binary_labels) if value == 1]

print(detect_emotions('You have excellent service and the best coffee in the city, I love your coffee shop!'))
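`detect_emotions` handles one text at a time. For several texts, one option (an assumption, not shown in the original card) is to tokenize the whole list with `padding=True` and compute the sigmoid probabilities in a single forward pass. Each row is then thresholded separately. The thresholding step does not depend on the model, so the sketch below shows it with NumPy. `probas_to_emotions` is a hypothetical helper name.

```python
import numpy as np

def probas_to_emotions(probas, thresholds, labels):
    """Turn a (batch, n_labels) probability matrix into one list of label names per text.

    probas:     sigmoid outputs; assumed to come from something like
                torch.sigmoid(model(**tokenizer(texts, padding=True, truncation=True,
                max_length=128, return_tensors='pt')).logits).numpy() inside torch.no_grad()
    thresholds: per-label thresholds in the same order as labels (e.g. best_thresholds)
    labels:     label names (e.g. LABELS)
    """
    probas = np.atleast_2d(np.asarray(probas, dtype=float))  # a single row becomes shape (1, n_labels)
    # Broadcasting compares every row against the per-label threshold vector
    mask = probas > np.asarray(thresholds, dtype=float)
    return [[labels[j] for j in np.flatnonzero(row)] for row in mask]
```

For example, `probas_to_emotions(probas, best_thresholds, LABELS)` returns one list of emotion names per input text. A text gets an empty list when no label clears its threshold.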
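Advanced Usage
The following example returns the score of every emotion, sorted from most to least likely. It reuses tokenizer, model, and LABELS from the basic example: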
def predict(text):
    # Same tokenization settings as detect_emotions
    inputs = tokenizer(text, truncation=True, add_special_tokens=True, max_length=128, return_tensors='pt')
    with torch.no_grad():
        logits = model(**inputs).logits
    # Independent sigmoid probability for each label, as a plain Python list
    probas = torch.sigmoid(logits).squeeze(dim=0).tolist()
    probas = [round(proba, 3) for proba in probas]
    labels2probas = dict(zip(LABELS, probas))
    # Sort labels from most to least likely
    probas_dict_sorted = dict(sorted(labels2probas.items(), key=lambda x: x[1], reverse=True))
    return probas_dict_sorted

print(predict('You have excellent service and the best coffee in the city, I love your coffee shop!'))
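`predict` reports every label's score, while `detect_emotions` applies the tuned thresholds. Combining the two returns the detected emotions together with their scores, and also makes it easy to check whether one specific emotion is present. Below is a minimal sketch with a hypothetical helper, `detected_with_scores`. Note that `predict` rounds scores to three decimals. A score lying very close to its threshold may therefore be classified differently than by `detect_emotions`.

```python
def detected_with_scores(scores, thresholds, labels):
    """Keep only the labels whose score exceeds their tuned threshold.

    scores:     dict mapping label -> probability, e.g. the output of predict()
    thresholds: per-label thresholds aligned with labels (e.g. best_thresholds, LABELS)
    """
    label_to_threshold = dict(zip(labels, thresholds))
    # The comprehension preserves the descending order produced by predict()
    return {label: p for label, p in scores.items() if p > label_to_threshold[label]}
```

For example, `'love' in detected_with_scores(predict(text), best_thresholds, LABELS)` checks for a single emotion.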
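📚 Documentation
Evaluation Results
Results on the go_emotions test set. The Threshold column lists each label's tuned threshold (the best_thresholds values, rounded to two decimals):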
| Label | Precision | Recall | F1 | Support | Threshold |
|---|---|---|---|---|---|
| admiration | 0.68 | 0.72 | 0.7 | 504 | 0.55 |
| amusement | 0.76 | 0.91 | 0.83 | 264 | 0.27 |
| anger | 0.44 | 0.53 | 0.48 | 198 | 0.14 |
| annoyance | 0.27 | 0.46 | 0.34 | 320 | 0.12 |
| approval | 0.41 | 0.38 | 0.4 | 351 | 0.45 |
| caring | 0.37 | 0.46 | 0.41 | 135 | 0.22 |
| confusion | 0.36 | 0.51 | 0.42 | 153 | 0.2 |
| curiosity | 0.45 | 0.77 | 0.57 | 284 | 0.41 |
| desire | 0.66 | 0.46 | 0.54 | 83 | 0.53 |
| disappointment | 0.41 | 0.26 | 0.32 | 151 | 0.22 |
| disapproval | 0.39 | 0.54 | 0.45 | 267 | 0.29 |
| disgust | 0.52 | 0.41 | 0.46 | 123 | 0.31 |
| embarrassment | 0.52 | 0.41 | 0.45 | 37 | 0.2 |
| excitement | 0.29 | 0.59 | 0.39 | 103 | 0.14 |
| fear | 0.55 | 0.78 | 0.65 | 78 | 0.1 |
| gratitude | 0.96 | 0.88 | 0.92 | 352 | 0.47 |
| grief | 0.29 | 0.67 | 0.4 | 6 | 0.24 |
| joy | 0.57 | 0.66 | 0.61 | 161 | 0.31 |
| love | 0.74 | 0.87 | 0.8 | 238 | 0.2 |
| nervousness | 0.37 | 0.43 | 0.4 | 23 | 0.37 |
| optimism | 0.6 | 0.58 | 0.59 | 186 | 0.29 |
| pride | 0.28 | 0.44 | 0.34 | 16 | 0.04 |
| realization | 0.36 | 0.19 | 0.24 | 145 | 0.31 |
| relief | 0.62 | 0.45 | 0.53 | 11 | 0.16 |
| remorse | 0.51 | 0.84 | 0.63 | 56 | 0.27 |
| sadness | 0.54 | 0.56 | 0.55 | 156 | 0.33 |
| surprise | 0.47 | 0.63 | 0.54 | 141 | 0.12 |
| neutral | 0.58 | 0.82 | 0.68 | 1787 | 0.2 |
| micro avg | 0.54 | 0.67 | 0.6 | 6329 | |
| macro avg | 0.5 | 0.58 | 0.52 | 6329 | |
| weighted avg | 0.55 | 0.67 | 0.6 | 6329 | |
| samples avg | 0.59 | 0.69 | 0.61 | 6329 | |
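The four summary rows follow the usual multi-label averaging conventions. Their names match scikit-learn's `classification_report`, though the card does not say which tool produced the table. The four averages work as follows:

- Micro average pools true positives, false positives, and false negatives over all labels.
- Macro average is the unweighted mean of the per-label scores.
- Weighted average weights each label by its support.
- Samples average computes the score per text and then averages over texts.

As a consistency check, the 28 per-label F1 values above average to about 0.52, matching the macro avg row, and the per-label supports sum to 6329. The NumPy sketch below computes the four F1 averages from binary label matrices. `f1_averages` is a hypothetical helper, and it scores 0 for any sample or label with no true and no predicted positives.

```python
import numpy as np

def _f1(tp, fp, fn):
    """Vectorized F1 = 2TP / (2TP + FP + FN), defined as 0 when the denominator is 0."""
    denom = 2 * tp + fp + fn
    return np.where(denom > 0, 2 * tp / np.maximum(denom, 1), 0.0)

def f1_averages(y_true, y_pred):
    """Micro, macro, weighted and samples F1 for (n_samples, n_labels) 0/1 matrices."""
    y_true = np.asarray(y_true, dtype=int)
    y_pred = np.asarray(y_pred, dtype=int)
    tp = y_true & y_pred          # elementwise indicator matrices
    fp = (1 - y_true) & y_pred
    fn = y_true & (1 - y_pred)
    per_label = _f1(tp.sum(axis=0), fp.sum(axis=0), fn.sum(axis=0))
    support = y_true.sum(axis=0)  # number of true positives per label
    return {
        "micro": float(_f1(tp.sum(), fp.sum(), fn.sum())),
        "macro": float(per_label.mean()),
        "weighted": float((per_label * support).sum() / support.sum()),
        "samples": float(_f1(tp.sum(axis=1), fp.sum(axis=1), fn.sum(axis=1)).mean()),
    }
```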
📄 License
This project is released under the MIT License.