🚀 GoEmotions BERT 分類器
このモデルは、BERT-base-uncased を go_emotions データセットでファインチューニングし、28種類の感情を対象としたマルチラベル分類を行うことができます。
✨ 主な機能
- 28種類の感情を識別するマルチラベル分類が可能です。
- 最適化された閾値を使用することで、高精度な予測が可能です。
📦 インストール
このモデルを使用するには、transformers
ライブラリが必要です。以下のコマンドでインストールできます。
pip install transformers
💻 使用例
基本的な使用法
from transformers import BertForSequenceClassification, BertTokenizer
import torch
import json
import requests
repo_id = "logasanjeev/goemotions-bert"
model = BertForSequenceClassification.from_pretrained(repo_id)
tokenizer = BertTokenizer.from_pretrained(repo_id)
thresholds_url = f"https://huggingface.co/{repo_id}/raw/main/thresholds.json"
thresholds_data = json.loads(requests.get(thresholds_url).text)
emotion_labels = thresholds_data["emotion_labels"]
thresholds = thresholds_data["thresholds"]
text = "I’m just chilling today."
encodings = tokenizer(text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
with torch.no_grad():
logits = torch.sigmoid(model(**encodings).logits).numpy()[0]
predictions = [(emotion_labels[i], logit) for i, (logit, thresh) in enumerate(zip(logits, thresholds)) if logit >= thresh]
print(sorted(predictions, key=lambda x: x[1], reverse=True))
📚 ドキュメント
モデルの詳細
プロパティ |
詳細 |
モデルタイプ |
BERT-base-uncased (110Mパラメータ) |
学習データ |
GoEmotions (58k件のRedditコメント、28種類の感情) |
損失関数 |
Focal Loss (gamma=2) |
オプティマイザー |
AdamW (lr=2e-5, weight_decay=0.01) |
エポック数 |
5 |
ハードウェア |
Kaggle T4 x2 GPUs |
試してみる
最適化された閾値を使用して正確な予測を行うには、Gradioデモ を使用してください。
性能
- Micro F1: 0.6025 (最適化された閾値)
- Macro F1: 0.5266
- Precision: 0.5425
- Recall: 0.6775
- Hamming Loss: 0.0372
- 平均正の予測数: 1.4564
クラス別の性能
以下の表は、最適化された閾値(thresholds.json
を参照)を使用したテストセットのクラスごとのメトリクスを示しています。
感情 |
F1スコア |
精度 |
再現率 |
サポート |
admiration |
0.7022 |
0.6980 |
0.7063 |
504 |
amusement |
0.8171 |
0.7692 |
0.8712 |
264 |
anger |
0.5123 |
0.5000 |
0.5253 |
198 |
annoyance |
0.3820 |
0.2908 |
0.5563 |
320 |
approval |
0.4112 |
0.3485 |
0.5014 |
351 |
caring |
0.4601 |
0.4045 |
0.5333 |
135 |
confusion |
0.4488 |
0.4533 |
0.4444 |
153 |
curiosity |
0.5721 |
0.4402 |
0.8169 |
284 |
desire |
0.4068 |
0.6857 |
0.2892 |
83 |
disappointment |
0.3476 |
0.3220 |
0.3775 |
151 |
disapproval |
0.4126 |
0.3433 |
0.5169 |
267 |
disgust |
0.4950 |
0.6329 |
0.4065 |
123 |
embarrassment |
0.5000 |
0.7368 |
0.3784 |
37 |
excitement |
0.4084 |
0.4432 |
0.3786 |
103 |
fear |
0.6311 |
0.5078 |
0.8333 |
78 |
gratitude |
0.9173 |
0.9744 |
0.8665 |
352 |
grief |
0.2500 |
0.5000 |
0.1667 |
6 |
joy |
0.6246 |
0.5798 |
0.6770 |
161 |
love |
0.8110 |
0.7630 |
0.8655 |
238 |
nervousness |
0.3830 |
0.3750 |
0.3913 |
23 |
optimism |
0.5777 |
0.5856 |
0.5699 |
186 |
pride |
0.4138 |
0.4615 |
0.3750 |
16 |
realization |
0.2421 |
0.5111 |
0.1586 |
145 |
relief |
0.5385 |
0.4667 |
0.6364 |
11 |
remorse |
0.6797 |
0.5361 |
0.9286 |
56 |
sadness |
0.5391 |
0.6900 |
0.4423 |
156 |
surprise |
0.5724 |
0.5570 |
0.5887 |
141 |
neutral |
0.6895 |
0.5826 |
0.8444 |
1787 |
📄 ライセンス
このモデルは MIT ライセンスの下で提供されています。