🚀 Wav2Vec 2.0をファインチューニングした音声感情認識モデル
このモデルは、音声感情認識(SER)タスク向けにjonatasgrosman/wav2vec2-large-xlsr-53-english をファインチューニングしたものです。元のモデルをファインチューニングするために、いくつかのデータセットが使用されました。
- Surrey Audio-Visual Expressed Emotion (SAVEE) - 4人の男性アクターによる480個の音声ファイル
- Ryerson Audio-Visual Database of Emotional Speech and Song (RAVDESS) - 24人のプロアクター(女性12人、男性12人)による1440個の音声ファイル
- Toronto emotional speech set (TESS) - 2人の女性アクターによる2800個の音声ファイル
分類ラベルとして7つのラベル/感情が使用されています。
emotions = ['angry' 'disgust' 'fear' 'happy' 'neutral' 'sad' 'surprise']
評価セットでは、以下の結果が得られています。
- 損失: 0.104075
- 正解率: 0.97463
🚀 クイックスタート
📦 インストール
pip install transformers librosa torch
💻 使用例
基本的な使用法
from transformers import *
import librosa
import torch
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("r-f/wav2vec-english-speech-emotion-recognition")
model = Wav2Vec2ForCTC.from_pretrained("r-f/wav2vec-english-speech-emotion-recognition")
def predict_emotion(audio_path):
audio, rate = librosa.load(audio_path, sr=16000)
inputs = feature_extractor(audio, sampling_rate=rate, return_tensors="pt", padding=True)
with torch.no_grad():
outputs = model(inputs.input_values)
predictions = torch.nn.functional.softmax(outputs.logits.mean(dim=1), dim=-1)
predicted_label = torch.argmax(predictions, dim=-1)
emotion = model.config.id2label[predicted_label.item()]
return emotion
emotion = predict_emotion("example_audio.wav")
print(f"Predicted emotion: {emotion}")
>> Predicted emotion: angry
📚 ドキュメント
🔧 技術詳細
トレーニング手順
トレーニングハイパーパラメータ
トレーニング中に以下のハイパーパラメータが使用されました。
- 学習率: 0.0001
- トレーニングバッチサイズ: 4
- 評価バッチサイズ: 4
- 評価ステップ: 500
- シード: 42
- 勾配累積ステップ: 2
- オプティマイザ: Adam(betas=(0.9,0.999)、epsilon=1e-08)
- エポック数: 4
- 最大ステップ数: 7500
- 保存ステップ: 1500
トレーニング結果
ステップ |
トレーニング損失 |
検証損失 |
正解率 |
500 |
1.8124 |
1.365212 |
0.486258 |
1000 |
0.8872 |
0.773145 |
0.79704 |
1500 |
0.7035 |
0.574954 |
0.852008 |
2000 |
0.6879 |
1.286738 |
0.775899 |
2500 |
0.6498 |
0.697455 |
0.832981 |
3000 |
0.5696 |
0.33724 |
0.892178 |
3500 |
0.4218 |
0.307072 |
0.911205 |
4000 |
0.3088 |
0.374443 |
0.930233 |
4500 |
0.2688 |
0.260444 |
0.936575 |
5000 |
0.2973 |
0.302985 |
0.92389 |
5500 |
0.1765 |
0.165439 |
0.961945 |
6000 |
0.1475 |
0.170199 |
0.961945 |
6500 |
0.1274 |
0.15531 |
0.966173 |
7000 |
0.0699 |
0.103882 |
0.976744 |
7500 |
0.083 |
0.104075 |
0.97463 |
📄 ライセンス
このモデルはApache-2.0ライセンスの下で提供されています。