🚀 GoEmotions BERT分類器
GoEmotions BERT分類器是一個經過微調的模型,基於BERT-base-uncased在go_emotions數據集上進行多標籤分類(28種情緒)訓練得到。它能夠準確地對文本中的情緒進行分類,為情緒分析提供了強大的支持。
🚀 快速開始
若要使用優化閾值進行準確預測,請使用Gradio演示。
✨ 主要特性
- 基於BERT-base-uncased架構,擁有1.1億個參數,具有強大的特徵提取能力。
- 在GoEmotions數據集上進行訓練,該數據集包含5.8萬條Reddit評論,涵蓋28種情緒。
- 使用Focal Loss(gamma=2)作為損失函數,有效處理類別不平衡問題。
- 採用AdamW優化器(lr=2e-5,weight_decay=0.01)進行訓練。
- 經過5個訓練週期,在Kaggle T4 x2 GPUs硬件上完成訓練。
📦 安裝指南
文檔未提及安裝步驟,暫不展示。
💻 使用示例
基礎用法
該模型使用存儲在thresholds.json
中的優化閾值進行預測。以下是Python示例代碼:
from transformers import BertForSequenceClassification, BertTokenizer
import torch
import json
import requests
repo_id = "logasanjeev/goemotions-bert"
model = BertForSequenceClassification.from_pretrained(repo_id)
tokenizer = BertTokenizer.from_pretrained(repo_id)
thresholds_url = f"https://huggingface.co/{repo_id}/raw/main/thresholds.json"
thresholds_data = json.loads(requests.get(thresholds_url).text)
emotion_labels = thresholds_data["emotion_labels"]
thresholds = thresholds_data["thresholds"]
text = "I’m just chilling today."
encodings = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
with torch.no_grad():
logits = torch.sigmoid(model(**encodings).logits).numpy()[0]
predictions = [(emotion_labels[i], logit) for i, (logit, thresh) in enumerate(zip(logits, thresholds)) if logit >= thresh]
print(sorted(predictions, key=lambda x: x[1], reverse=True))
📚 詳細文檔
模型詳情
屬性 |
詳情 |
模型架構 |
BERT-base-uncased(1.1億個參數) |
訓練數據 |
GoEmotions(5.8萬條Reddit評論,28種情緒) |
損失函數 |
Focal Loss(gamma=2) |
優化器 |
AdamW(學習率=2e-5,權重衰減=0.01) |
訓練週期 |
5 |
硬件環境 |
Kaggle T4 x2 GPUs |
性能指標
- Micro F1:0.6025(優化閾值)
- Macro F1:0.5266
- 精確率:0.5425
- 召回率:0.6775
- 漢明損失:0.0372
- 平均正預測數:1.4564
各類別性能
以下表格展示了在測試集上使用優化閾值(見thresholds.json
)的各類別指標:
情緒 |
F1分數 |
精確率 |
召回率 |
樣本數 |
欽佩 |
0.7022 |
0.6980 |
0.7063 |
504 |
娛樂 |
0.8171 |
0.7692 |
0.8712 |
264 |
憤怒 |
0.5123 |
0.5000 |
0.5253 |
198 |
煩惱 |
0.3820 |
0.2908 |
0.5563 |
320 |
認可 |
0.4112 |
0.3485 |
0.5014 |
351 |
關心 |
0.4601 |
0.4045 |
0.5333 |
135 |
困惑 |
0.4488 |
0.4533 |
0.4444 |
153 |
好奇 |
0.5721 |
0.4402 |
0.8169 |
284 |
渴望 |
0.4068 |
0.6857 |
0.2892 |
83 |
失望 |
0.3476 |
0.3220 |
0.3775 |
151 |
不認可 |
0.4126 |
0.3433 |
0.5169 |
267 |
厭惡 |
0.4950 |
0.6329 |
0.4065 |
123 |
尷尬 |
0.5000 |
0.7368 |
0.3784 |
37 |
興奮 |
0.4084 |
0.4432 |
0.3786 |
103 |
恐懼 |
0.6311 |
0.5078 |
0.8333 |
78 |
感激 |
0.9173 |
0.9744 |
0.8665 |
352 |
悲痛 |
0.2500 |
0.5000 |
0.1667 |
6 |
喜悅 |
0.6246 |
0.5798 |
0.6770 |
161 |
愛 |
0.8110 |
0.7630 |
0.8655 |
238 |
緊張 |
0.3830 |
0.3750 |
0.3913 |
23 |
樂觀 |
0.5777 |
0.5856 |
0.5699 |
186 |
驕傲 |
0.4138 |
0.4615 |
0.3750 |
16 |
領悟 |
0.2421 |
0.5111 |
0.1586 |
145 |
寬慰 |
0.5385 |
0.4667 |
0.6364 |
11 |
懊悔 |
0.6797 |
0.5361 |
0.9286 |
56 |
悲傷 |
0.5391 |
0.6900 |
0.4423 |
156 |
驚訝 |
0.5724 |
0.5570 |
0.5887 |
141 |
中立 |
0.6895 |
0.5826 |
0.8444 |
1787 |
📄 許可證
本項目採用MIT許可證。