🚀 GoEmotions BERT分类器
GoEmotions BERT分类器是一个经过微调的模型,基于BERT-base-uncased在go_emotions数据集上进行多标签分类(28种情绪)训练得到。它能够准确地对文本中的情绪进行分类,为情绪分析提供了强大的支持。
🚀 快速开始
若要使用优化阈值进行准确预测,请使用Gradio演示。
✨ 主要特性
- 基于BERT-base-uncased架构,拥有1.1亿个参数,具有强大的特征提取能力。
- 在GoEmotions数据集上进行训练,该数据集包含5.8万条Reddit评论,涵盖28种情绪。
- 使用Focal Loss(gamma=2)作为损失函数,有效处理类别不平衡问题。
- 采用AdamW优化器(lr=2e-5,weight_decay=0.01)进行训练。
- 经过5个训练周期,在Kaggle T4 x2 GPUs硬件上完成训练。
📦 安装指南
文档未提及安装步骤,暂不展示。
💻 使用示例
基础用法
该模型使用存储在thresholds.json
中的优化阈值进行预测。以下是Python示例代码:
from transformers import BertForSequenceClassification, BertTokenizer
import torch
import json
import requests
repo_id = "logasanjeev/goemotions-bert"
model = BertForSequenceClassification.from_pretrained(repo_id)
tokenizer = BertTokenizer.from_pretrained(repo_id)
thresholds_url = f"https://huggingface.co/{repo_id}/raw/main/thresholds.json"
thresholds_data = json.loads(requests.get(thresholds_url).text)
emotion_labels = thresholds_data["emotion_labels"]
thresholds = thresholds_data["thresholds"]
text = "I’m just chilling today."
encodings = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
with torch.no_grad():
logits = torch.sigmoid(model(**encodings).logits).numpy()[0]
predictions = [(emotion_labels[i], logit) for i, (logit, thresh) in enumerate(zip(logits, thresholds)) if logit >= thresh]
print(sorted(predictions, key=lambda x: x[1], reverse=True))
📚 详细文档
模型详情
属性 |
详情 |
模型架构 |
BERT-base-uncased(1.1亿个参数) |
训练数据 |
GoEmotions(5.8万条Reddit评论,28种情绪) |
损失函数 |
Focal Loss(gamma=2) |
优化器 |
AdamW(学习率=2e-5,权重衰减=0.01) |
训练周期 |
5 |
硬件环境 |
Kaggle T4 x2 GPUs |
性能指标
- Micro F1:0.6025(优化阈值)
- Macro F1:0.5266
- 精确率:0.5425
- 召回率:0.6775
- 汉明损失:0.0372
- 平均正预测数:1.4564
各类别性能
以下表格展示了在测试集上使用优化阈值(见thresholds.json
)的各类别指标:
情绪 |
F1分数 |
精确率 |
召回率 |
样本数 |
钦佩 |
0.7022 |
0.6980 |
0.7063 |
504 |
娱乐 |
0.8171 |
0.7692 |
0.8712 |
264 |
愤怒 |
0.5123 |
0.5000 |
0.5253 |
198 |
烦恼 |
0.3820 |
0.2908 |
0.5563 |
320 |
认可 |
0.4112 |
0.3485 |
0.5014 |
351 |
关心 |
0.4601 |
0.4045 |
0.5333 |
135 |
困惑 |
0.4488 |
0.4533 |
0.4444 |
153 |
好奇 |
0.5721 |
0.4402 |
0.8169 |
284 |
渴望 |
0.4068 |
0.6857 |
0.2892 |
83 |
失望 |
0.3476 |
0.3220 |
0.3775 |
151 |
不认可 |
0.4126 |
0.3433 |
0.5169 |
267 |
厌恶 |
0.4950 |
0.6329 |
0.4065 |
123 |
尴尬 |
0.5000 |
0.7368 |
0.3784 |
37 |
兴奋 |
0.4084 |
0.4432 |
0.3786 |
103 |
恐惧 |
0.6311 |
0.5078 |
0.8333 |
78 |
感激 |
0.9173 |
0.9744 |
0.8665 |
352 |
悲痛 |
0.2500 |
0.5000 |
0.1667 |
6 |
喜悦 |
0.6246 |
0.5798 |
0.6770 |
161 |
爱 |
0.8110 |
0.7630 |
0.8655 |
238 |
紧张 |
0.3830 |
0.3750 |
0.3913 |
23 |
乐观 |
0.5777 |
0.5856 |
0.5699 |
186 |
骄傲 |
0.4138 |
0.4615 |
0.3750 |
16 |
领悟 |
0.2421 |
0.5111 |
0.1586 |
145 |
宽慰 |
0.5385 |
0.4667 |
0.6364 |
11 |
懊悔 |
0.6797 |
0.5361 |
0.9286 |
56 |
悲伤 |
0.5391 |
0.6900 |
0.4423 |
156 |
惊讶 |
0.5724 |
0.5570 |
0.5887 |
141 |
中立 |
0.6895 |
0.5826 |
0.8444 |
1787 |
📄 许可证
本项目采用MIT许可证。