🚀 トランスフォーマーモデル
このモデルは、多ラベル分類のためにgo_emotionsデータセットでファインチューニングされたModernBERT-largeモデルです。英語のテキストからすべての感情を抽出したり、特定の感情を検出したりするために使用できます。閾値は、すべてのラベルに対するF1マクロを最大化することで検証セット上で選択されています。また、Flash Attention 2を使用すると推論を高速化できます。
モデルの品質はすべてのクラスにわたって大きく異なります(下のメトリクスの表を参照)。賞賛、喜び、楽観主義、恐怖、悔いなど、モデルが高い認識品質を示すクラスもあれば、モデルにとって難しいクラスもあります。例えば、失望や気付きなどは、学習データに含まれる例が非常に少ないためです。
🚀 クイックスタート
このモデルはHuggingface Transformersを使って簡単に利用できます。
ModernBERTアーキテクチャはTransformersバージョン4.48.0以降でサポートされているため、以下のコマンドでインストールする必要があります。
pip install "transformers>=4.48.0"
✨ 主な機能
- 英語のテキストから多ラベルの感情を抽出することができます。
- 特定の感情を検出することも可能です。
- Flash Attention 2を使用して推論を高速化できます。
📦 インストール
pip install "transformers>=4.48.0"
💻 使用例
基本的な使用法
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
tokenizer = AutoTokenizer.from_pretrained('fyaronskiy/ModernBERT-large-go-emotions')
model = AutoModelForSequenceClassification.from_pretrained('fyaronskiy/ModernBERT-large-go-emotions')
best_thresholds = [0.5510204081632653, 0.26530612244897955, 0.14285714285714285, 0.12244897959183673, 0.44897959183673464, 0.22448979591836732, 0.2040816326530612, 0.4081632653061224, 0.5306122448979591, 0.22448979591836732, 0.2857142857142857, 0.3061224489795918, 0.2040816326530612, 0.14285714285714285, 0.1020408163265306, 0.4693877551020408, 0.24489795918367346, 0.3061224489795918, 0.2040816326530612, 0.36734693877551017, 0.2857142857142857, 0.04081632653061224, 0.3061224489795918, 0.16326530612244897, 0.26530612244897955, 0.32653061224489793, 0.12244897959183673, 0.2040816326530612]
LABELS = ['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']
ID2LABEL = dict(enumerate(LABELS))
高度な使用法
テキストから感情を抽出する方法
def detect_emotions(text):
inputs = tokenizer(text, truncation=True, add_special_tokens=True, max_length=128, return_tensors='pt')
with torch.no_grad():
logits = model(**inputs).logits
probas = torch.sigmoid(logits).squeeze(dim=0)
class_binary_labels = (probas > torch.tensor(best_thresholds)).int()
return [ID2LABEL[label_id] for label_id, value in enumerate(class_binary_labels) if value == 1]
print(detect_emotions('You have excellent service and the best coffee in the city, I love your coffee shop!'))
すべての感情とそのスコアを取得する方法
def predict(text):
inputs = tokenizer(text, truncation=True, add_special_tokens=True, max_length=128, return_tensors='pt')
with torch.no_grad():
logits = model(**inputs).logits
probas = torch.sigmoid(logits).squeeze(dim=0).tolist()
probas = [round(proba, 3) for proba in probas]
labels2probas = dict(zip(LABELS, probas))
probas_dict_sorted = dict(sorted(labels2probas.items(), key=lambda x: x[1], reverse=True))
return probas_dict_sorted
print(predict('You have excellent service and the best coffee in the city, I love your coffee shop!'))
📚 ドキュメント
go-emotionsのテスト分割における評価結果
属性 |
詳細 |
モデルタイプ |
ModernBERT-largeをファインチューニングしたモデル |
学習データ |
go_emotionsデータセット |
クラス |
精度 |
再現率 |
F1スコア |
サポート |
閾値 |
admiration |
0.68 |
0.72 |
0.7 |
504 |
0.55 |
amusement |
0.76 |
0.91 |
0.83 |
264 |
0.27 |
anger |
0.44 |
0.53 |
0.48 |
198 |
0.14 |
annoyance |
0.27 |
0.46 |
0.34 |
320 |
0.12 |
approval |
0.41 |
0.38 |
0.4 |
351 |
0.45 |
caring |
0.37 |
0.46 |
0.41 |
135 |
0.22 |
confusion |
0.36 |
0.51 |
0.42 |
153 |
0.2 |
curiosity |
0.45 |
0.77 |
0.57 |
284 |
0.41 |
desire |
0.66 |
0.46 |
0.54 |
83 |
0.53 |
disappointment |
0.41 |
0.26 |
0.32 |
151 |
0.22 |
disapproval |
0.39 |
0.54 |
0.45 |
267 |
0.29 |
disgust |
0.52 |
0.41 |
0.46 |
123 |
0.31 |
embarrassment |
0.52 |
0.41 |
0.45 |
37 |
0.2 |
excitement |
0.29 |
0.59 |
0.39 |
103 |
0.14 |
fear |
0.55 |
0.78 |
0.65 |
78 |
0.1 |
gratitude |
0.96 |
0.88 |
0.92 |
352 |
0.47 |
grief |
0.29 |
0.67 |
0.4 |
6 |
0.24 |
joy |
0.57 |
0.66 |
0.61 |
161 |
0.31 |
love |
0.74 |
0.87 |
0.8 |
238 |
0.2 |
nervousness |
0.37 |
0.43 |
0.4 |
23 |
0.37 |
optimism |
0.6 |
0.58 |
0.59 |
186 |
0.29 |
pride |
0.28 |
0.44 |
0.34 |
16 |
0.04 |
realization |
0.36 |
0.19 |
0.24 |
145 |
0.31 |
relief |
0.62 |
0.45 |
0.53 |
11 |
0.16 |
remorse |
0.51 |
0.84 |
0.63 |
56 |
0.27 |
sadness |
0.54 |
0.56 |
0.55 |
156 |
0.33 |
surprise |
0.47 |
0.63 |
0.54 |
141 |
0.12 |
neutral |
0.58 |
0.82 |
0.68 |
1787 |
0.2 |
micro avg |
0.54 |
0.67 |
0.6 |
6329 |
|
macro avg |
0.5 |
0.58 |
0.52 |
6329 |
|
weighted avg |
0.55 |
0.67 |
0.6 |
6329 |
|
samples avg |
0.59 |
0.69 |
0.61 |
6329 |
|
📄 ライセンス
このプロジェクトはMITライセンスの下で公開されています。