🚀 現代BERT-large情感分類模型
這是一個基於transformers
庫的多標籤情感分類模型,使用 ModernBERT-large 模型在 go_emotions 數據集上進行微調。該模型可用於從英文文本中提取所有情感或檢測特定情感。
🚀 快速開始
本模型使用Huggingface Transformers庫,使用起來非常簡單。不過,由於ModernBERT
架構在transformers
4.48.0及更高版本中得到支持,所以你需要安裝相應版本的庫:
pip install "transformers>=4.48.0"
✨ 主要特性
- 多標籤分類:能夠從英文文本中同時識別多種情感。
- 閾值優化:在驗證集上通過最大化所有標籤的F1宏分數來選擇最佳閾值。
- 推理加速:支持使用Flash Attention 2來加速推理過程。
📦 安裝指南
確保你已經安裝了transformers
庫的4.48.0及以上版本:
pip install "transformers>=4.48.0"
💻 使用示例
基礎用法
以下是如何從文本中提取情感的示例代碼:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
tokenizer = AutoTokenizer.from_pretrained('fyaronskiy/ModernBERT-large-go-emotions')
model = AutoModelForSequenceClassification.from_pretrained('fyaronskiy/ModernBERT-large-go-emotions')
best_thresholds = [0.5510204081632653, 0.26530612244897955, 0.14285714285714285, 0.12244897959183673, 0.44897959183673464, 0.22448979591836732, 0.2040816326530612, 0.4081632653061224, 0.5306122448979591, 0.22448979591836732, 0.2857142857142857, 0.3061224489795918, 0.2040816326530612, 0.14285714285714285, 0.1020408163265306, 0.4693877551020408, 0.24489795918367346, 0.3061224489795918, 0.2040816326530612, 0.36734693877551017, 0.2857142857142857, 0.04081632653061224, 0.3061224489795918, 0.16326530612244897, 0.26530612244897955, 0.32653061224489793, 0.12244897959183673, 0.2040816326530612]
LABELS = ['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']
ID2LABEL = dict(enumerate(LABELS))
def detect_emotions(text):
inputs = tokenizer(text, truncation=True, add_special_tokens=True, max_length=128, return_tensors='pt')
with torch.no_grad():
logits = model(**inputs).logits
probas = torch.sigmoid(logits).squeeze(dim=0)
class_binary_labels = (probas > torch.tensor(best_thresholds)).int()
return [ID2LABEL[label_id] for label_id, value in enumerate(class_binary_labels) if value == 1]
print(detect_emotions('You have excellent service and the best coffee in the city, I love your coffee shop!'))
高級用法
以下是如何獲取所有情感及其得分的示例代碼:
def predict(text):
inputs = tokenizer(text, truncation=True, add_special_tokens=True, max_length=128, return_tensors='pt')
with torch.no_grad():
logits = model(**inputs).logits
probas = torch.sigmoid(logits).squeeze(dim=0).tolist()
probas = [round(proba, 3) for proba in probas]
labels2probas = dict(zip(LABELS, probas))
probas_dict_sorted = dict(sorted(labels2probas.items(), key=lambda x: x[1], reverse=True))
return probas_dict_sorted
print(predict('You have excellent service and the best coffee in the city, I love your coffee shop!'))
📚 詳細文檔
評估結果
以下是模型在go-emotions
測試集上的評估結果:
類別 |
精確率 |
召回率 |
F1分數 |
支持樣本數 |
閾值 |
admiration |
0.68 |
0.72 |
0.7 |
504 |
0.55 |
amusement |
0.76 |
0.91 |
0.83 |
264 |
0.27 |
anger |
0.44 |
0.53 |
0.48 |
198 |
0.14 |
annoyance |
0.27 |
0.46 |
0.34 |
320 |
0.12 |
approval |
0.41 |
0.38 |
0.4 |
351 |
0.45 |
caring |
0.37 |
0.46 |
0.41 |
135 |
0.22 |
confusion |
0.36 |
0.51 |
0.42 |
153 |
0.2 |
curiosity |
0.45 |
0.77 |
0.57 |
284 |
0.41 |
desire |
0.66 |
0.46 |
0.54 |
83 |
0.53 |
disappointment |
0.41 |
0.26 |
0.32 |
151 |
0.22 |
disapproval |
0.39 |
0.54 |
0.45 |
267 |
0.29 |
disgust |
0.52 |
0.41 |
0.46 |
123 |
0.31 |
embarrassment |
0.52 |
0.41 |
0.45 |
37 |
0.2 |
excitement |
0.29 |
0.59 |
0.39 |
103 |
0.14 |
fear |
0.55 |
0.78 |
0.65 |
78 |
0.1 |
gratitude |
0.96 |
0.88 |
0.92 |
352 |
0.47 |
grief |
0.29 |
0.67 |
0.4 |
6 |
0.24 |
joy |
0.57 |
0.66 |
0.61 |
161 |
0.31 |
love |
0.74 |
0.87 |
0.8 |
238 |
0.2 |
nervousness |
0.37 |
0.43 |
0.4 |
23 |
0.37 |
optimism |
0.6 |
0.58 |
0.59 |
186 |
0.29 |
pride |
0.28 |
0.44 |
0.34 |
16 |
0.04 |
realization |
0.36 |
0.19 |
0.24 |
145 |
0.31 |
relief |
0.62 |
0.45 |
0.53 |
11 |
0.16 |
remorse |
0.51 |
0.84 |
0.63 |
56 |
0.27 |
sadness |
0.54 |
0.56 |
0.55 |
156 |
0.33 |
surprise |
0.47 |
0.63 |
0.54 |
141 |
0.12 |
neutral |
0.58 |
0.82 |
0.68 |
1787 |
0.2 |
micro avg |
0.54 |
0.67 |
0.6 |
6329 |
|
macro avg |
0.5 |
0.58 |
0.52 |
6329 |
|
weighted avg |
0.55 |
0.67 |
0.6 |
6329 |
|
samples avg |
0.59 |
0.69 |
0.61 |
6329 |
|
📄 許可證
本項目採用MIT許可證。