# 🚀 GoEmotions BERT Classifier

Fine-tuned BERT model for multi-label emotion classification.

This project provides a BERT-base-uncased model fine-tuned on the go_emotions dataset for multi-label classification across 28 emotion categories.
## ✨ Features

- **Multi-label Classification**: Classifies text into 28 emotion categories, with multiple labels possible per input.
- **Optimized Thresholds**: Per-class decision thresholds (see thresholds.json) yield more accurate predictions than a single fixed cutoff.
- **High Performance**: Achieves 0.6025 micro F1 on the test set with optimized thresholds.
## 📦 Installation

The model is distributed as a standard Hugging Face Hub repository, so no project-specific installation is needed beyond the Python dependencies used in the examples below.
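A minimal setup, assuming pip and the usual PyPI package names for the libraries imported below:

```bash
pip install transformers torch requests
```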
## 🚀 Quick Start

For accurate predictions that apply the optimized per-class thresholds, use the Gradio demo or the full example under Usage Examples below.
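For a quick smoke test without manual thresholding, the generic transformers text-classification pipeline can score all 28 labels. This is a minimal sketch; note that it returns raw per-label sigmoid scores and does not apply the tuned thresholds from thresholds.json:

```python
from transformers import pipeline

# Generic pipeline around the fine-tuned model; no per-class thresholds applied
classifier = pipeline("text-classification", model="logasanjeev/goemotions-bert")

# top_k=None returns scores for every label; sigmoid treats labels independently
scores = classifier("Thank you so much, this made my day!", top_k=None, function_to_apply="sigmoid")
print(scores)
```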
## 📚 Documentation

### Model Details
| Property | Details |
|---|---|
| Architecture | BERT-base-uncased (110M parameters) |
| Training Data | GoEmotions (58k Reddit comments, 28 emotions) |
| Loss Function | Focal Loss (gamma=2) |
| Optimizer | AdamW (lr=2e-5, weight_decay=0.01) |
| Epochs | 5 |
| Hardware | Kaggle T4 x2 GPUs |
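The exact training script is not part of this card, but a multi-label focal loss with gamma=2, as listed above, typically looks like the following PyTorch sketch (the function name and mean reduction are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0):
    """Multi-label focal loss: down-weights easy examples by (1 - p_t)^gamma."""
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)  # model's probability for the true label (bce = -log(p_t))
    return ((1.0 - p_t) ** gamma * bce).mean()
```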
### Performance

- Micro F1: 0.6025 (optimized thresholds)
- Macro F1: 0.5266
- Precision: 0.5425
- Recall: 0.6775
- Hamming Loss: 0.0372
- Average positive predictions per sample: 1.4564
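For reference, aggregate metrics of this kind can be computed with scikit-learn from binarized predictions. A sketch, where y_true and y_pred are hypothetical (n_samples, 28) 0/1 matrices obtained after thresholding:

```python
import numpy as np
from sklearn.metrics import f1_score, hamming_loss, precision_score, recall_score

# Hypothetical binary label/prediction matrices, shape (n_samples, 28)
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=(100, 28))
y_pred = rng.integers(0, 2, size=(100, 28))

print("Micro F1:", f1_score(y_true, y_pred, average="micro"))
print("Macro F1:", f1_score(y_true, y_pred, average="macro"))
print("Precision:", precision_score(y_true, y_pred, average="micro", zero_division=0))
print("Recall:", recall_score(y_true, y_pred, average="micro"))
print("Hamming loss:", hamming_loss(y_true, y_pred))
print("Avg positive predictions:", y_pred.sum(axis=1).mean())
```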
### Class-Wise Performance

The following table shows per-class metrics on the test set using optimized thresholds (see thresholds.json):
| Emotion | F1 Score | Precision | Recall | Support |
|---|---|---|---|---|
| admiration | 0.7022 | 0.6980 | 0.7063 | 504 |
| amusement | 0.8171 | 0.7692 | 0.8712 | 264 |
| anger | 0.5123 | 0.5000 | 0.5253 | 198 |
| annoyance | 0.3820 | 0.2908 | 0.5563 | 320 |
| approval | 0.4112 | 0.3485 | 0.5014 | 351 |
| caring | 0.4601 | 0.4045 | 0.5333 | 135 |
| confusion | 0.4488 | 0.4533 | 0.4444 | 153 |
| curiosity | 0.5721 | 0.4402 | 0.8169 | 284 |
| desire | 0.4068 | 0.6857 | 0.2892 | 83 |
| disappointment | 0.3476 | 0.3220 | 0.3775 | 151 |
| disapproval | 0.4126 | 0.3433 | 0.5169 | 267 |
| disgust | 0.4950 | 0.6329 | 0.4065 | 123 |
| embarrassment | 0.5000 | 0.7368 | 0.3784 | 37 |
| excitement | 0.4084 | 0.4432 | 0.3786 | 103 |
| fear | 0.6311 | 0.5078 | 0.8333 | 78 |
| gratitude | 0.9173 | 0.9744 | 0.8665 | 352 |
| grief | 0.2500 | 0.5000 | 0.1667 | 6 |
| joy | 0.6246 | 0.5798 | 0.6770 | 161 |
| love | 0.8110 | 0.7630 | 0.8655 | 238 |
| nervousness | 0.3830 | 0.3750 | 0.3913 | 23 |
| optimism | 0.5777 | 0.5856 | 0.5699 | 186 |
| pride | 0.4138 | 0.4615 | 0.3750 | 16 |
| realization | 0.2421 | 0.5111 | 0.1586 | 145 |
| relief | 0.5385 | 0.4667 | 0.6364 | 11 |
| remorse | 0.6797 | 0.5361 | 0.9286 | 56 |
| sadness | 0.5391 | 0.6900 | 0.4423 | 156 |
| surprise | 0.5724 | 0.5570 | 0.5887 | 141 |
| neutral | 0.6895 | 0.5826 | 0.8444 | 1787 |
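The thresholds in thresholds.json ship ready-made with the model. For context, a common recipe for deriving such per-class thresholds is a grid search that maximizes F1 per class on held-out validation probabilities; the sketch below assumes hypothetical val_probs and val_labels arrays of shape (n_samples, 28):

```python
import numpy as np
from sklearn.metrics import f1_score

def optimize_thresholds(val_probs, val_labels, grid=np.arange(0.05, 0.95, 0.01)):
    """For each class, pick the threshold in `grid` that maximizes validation F1."""
    thresholds = []
    for c in range(val_probs.shape[1]):
        scores = [f1_score(val_labels[:, c], val_probs[:, c] >= t, zero_division=0) for t in grid]
        thresholds.append(float(grid[int(np.argmax(scores))]))
    return thresholds
```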
## 💻 Usage Examples

### Basic Usage
```python
from transformers import BertForSequenceClassification, BertTokenizer
import torch
import json
import requests

# Load the fine-tuned model and tokenizer from the Hugging Face Hub
repo_id = "logasanjeev/goemotions-bert"
model = BertForSequenceClassification.from_pretrained(repo_id)
tokenizer = BertTokenizer.from_pretrained(repo_id)

# Fetch the optimized per-class thresholds shipped with the model
thresholds_url = f"https://huggingface.co/{repo_id}/raw/main/thresholds.json"
thresholds_data = json.loads(requests.get(thresholds_url).text)
emotion_labels = thresholds_data["emotion_labels"]
thresholds = thresholds_data["thresholds"]

# Tokenize the input and compute per-label sigmoid scores
text = "I'm just chilling today."
encodings = tokenizer(text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
with torch.no_grad():
    logits = torch.sigmoid(model(**encodings).logits).numpy()[0]

# Keep only the emotions whose score clears their class-specific threshold
predictions = [
    (emotion_labels[i], score)
    for i, (score, thresh) in enumerate(zip(logits, thresholds))
    if score >= thresh
]
print(sorted(predictions, key=lambda x: x[1], reverse=True))
```
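Each printed tuple pairs an emotion label with its sigmoid score; only labels whose score meets or exceeds their per-class threshold appear, sorted from highest to lowest.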
## 📄 License

This project is licensed under the MIT License.