🚀 🎧 音声感情認識 with Wav2Vec2
このプロジェクトは、Wav2Vec2 モデルを活用して音声の感情を認識します。目的は、音声録音を Happy、Sad、Surprised などの異なる感情カテゴリに分類することです。
🚀 クイックスタート
この音声感情認識モデルを使用するには、以下の手順に従ってください。まず、必要なライブラリをインポートし、モデルと特徴抽出器を読み込みます。その後、音声ファイルを前処理し、感情を予測します。
from transformers import AutoModelForAudioClassification, AutoFeatureExtractor
import librosa
import torch
import numpy as np
model_id = "firdhokk/speech-emotion-recognition-with-facebook-wav2vec2-large-xlsr-53"
model = AutoModelForAudioClassification.from_pretrained(model_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id, do_normalize=True, return_attention_mask=True)
id2label = model.config.id2label
def preprocess_audio(audio_path, feature_extractor, max_duration=30.0):
audio_array, sampling_rate = librosa.load(audio_path, sr=feature_extractor.sampling_rate)
max_length = int(feature_extractor.sampling_rate * max_duration)
if len(audio_array) > max_length:
audio_array = audio_array[:max_length]
else:
audio_array = np.pad(audio_array, (0, max_length - len(audio_array)))
inputs = feature_extractor(
audio_array,
sampling_rate=feature_extractor.sampling_rate,
max_length=max_length,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
return inputs
def predict_emotion(audio_path, model, feature_extractor, id2label, max_duration=30.0):
inputs = preprocess_audio(audio_path, feature_extractor, max_duration)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
inputs = {key: value.to(device) for key, value in inputs.items()}
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
predicted_id = torch.argmax(logits, dim=-1).item()
predicted_label = id2label[predicted_id]
return predicted_label
audio_path = "/content/drive/MyDrive/Audio/Speech_URDU/Happy/SM5_F4_H058.wav"
predicted_emotion = predict_emotion(audio_path, model, feature_extractor, id2label)
print(f"Predicted Emotion: {predicted_emotion}")
✨ 主な機能
- 音声感情認識:音声データから感情を高精度に認識し、異なる感情カテゴリに分類します。
- 多様なデータセット対応:複数のデータセットを用いて学習されており、多様な音声データに対応します。
- 高精度な評価指標:Accuracy、Precision、Recall、F1 Scoreなどの評価指標が高く、性能が優れています。
📦 インストール
このモデルを使用するには、以下のライブラリが必要です。
必要なライブラリ
Property |
Details |
Transformers |
4.44.2 |
Pytorch |
2.4.1+cu121 |
Datasets |
3.0.0 |
Tokenizers |
0.19.1 |
📚 ドキュメント
🗂 データセット
学習と評価に使用されるデータセットは、複数のデータセットから収集されており、以下のものが含まれます。
データセットには、様々な感情でラベル付けされた録音が含まれています。以下は、データセット内の感情の分布です。
Emotion |
Count |
sad |
752 |
happy |
752 |
angry |
752 |
neutral |
716 |
disgust |
652 |
fearful |
652 |
surprised |
652 |
calm |
192 |
この分布は、データセット内の感情のバランスを反映しており、一部の感情は他の感情よりもサンプル数が多いことがわかります。学習時には、サンプル数が少ない "calm" 感情は除外されました。
🎤 前処理
- 音声読み込み:Librosa を使用して音声ファイルを読み込み、numpy配列に変換します。
- 特徴抽出:音声データは、Wav2Vec2 Feature Extractor を使用して処理され、モデルへの入力用に音声特徴が標準化および正規化されます。
🔧 モデル
使用されるモデルは、Wav2Vec2 Large XLR-53 モデルで、音声分類 タスク用にファインチューニングされています。
感情ラベルは数値IDにマッピングされ、モデルの学習と評価に使用されます。
⚙️ 学習
モデルは、以下のパラメータで学習されます。
- 学習率:
5e-05
- 学習バッチサイズ:
2
- 評価バッチサイズ:
2
- 乱数シード:
42
- 勾配累積ステップ:
5
- 総学習バッチサイズ:
10
(勾配累積後の実効バッチサイズ)
- オプティマイザ:Adam(
betas=(0.9, 0.999)
、epsilon=1e-08
)
- 学習率スケジューラ:
linear
- 学習率スケジューラのウォームアップ率:
0.1
- エポック数:
25
- 混合精度学習:Native AMP (Automatic Mixed Precision)
これらのパラメータは、特に Wav2Vec2 のような大規模データセットと深層モデルを扱う場合に、効率的なモデル学習と安定性を保証します。学習には、実験の追跡とモニタリングに Wandb が利用されています。
📊 評価指標
モデルの学習後に得られた評価指標は以下の通りです。
- Loss:
0.4989
- Accuracy:
0.9168
- Precision:
0.9209
- Recall:
0.9168
- F1 Score:
0.9166
これらの指標は、音声感情認識タスクにおけるモデルの性能を示しています。Accuracy、Precision、Recall、F1 Scoreの値が高いことから、モデルは音声データから感情状態を効果的に識別できていることがわかります。
🧪 結果
学習後、モデルはテストデータセットで評価され、結果は このリンク の Wandb を使用してモニタリングされています。
Training Loss |
Epoch |
Step |
Validation Loss |
Accuracy |
Precision |
Recall |
F1 |
1.9343 |
0.9995 |
394 |
1.9277 |
0.2505 |
0.1425 |
0.2505 |
0.1691 |
1.7944 |
1.9990 |
788 |
1.6446 |
0.4574 |
0.5759 |
0.4574 |
0.4213 |
1.4601 |
2.9985 |
1182 |
1.3242 |
0.5953 |
0.6183 |
0.5953 |
0.5709 |
1.0551 |
3.9980 |
1576 |
1.0764 |
0.6623 |
0.6659 |
0.6623 |
0.6447 |
0.8934 |
5.0 |
1971 |
0.9209 |
0.7059 |
0.7172 |
0.7059 |
0.6825 |
1.1156 |
5.9995 |
2365 |
0.8292 |
0.7465 |
0.7635 |
0.7465 |
0.7442 |
0.6307 |
6.9990 |
2759 |
0.6439 |
0.8043 |
0.8090 |
0.8043 |
0.8020 |
0.774 |
7.9985 |
3153 |
0.6666 |
0.7921 |
0.8117 |
0.7921 |
0.7916 |
0.5537 |
8.9980 |
3547 |
0.5111 |
0.8245 |
0.8268 |
0.8245 |
0.8205 |
0.3762 |
10.0 |
3942 |
0.5506 |
0.8306 |
0.8390 |
0.8306 |
0.8296 |
0.716 |
10.9995 |
4336 |
0.5499 |
0.8276 |
0.8465 |
0.8276 |
0.8268 |
0.5372 |
11.9990 |
4730 |
0.5463 |
0.8377 |
0.8606 |
0.8377 |
0.8404 |
0.3746 |
12.9985 |
5124 |
0.4758 |
0.8611 |
0.8714 |
0.8611 |
0.8597 |
0.4317 |
13.9980 |
5518 |
0.4438 |
0.8742 |
0.8843 |
0.8742 |
0.8756 |
0.2104 |
15.0 |
5913 |
0.4426 |
0.8803 |
0.8864 |
0.8803 |
0.8806 |
0.3193 |
15.9995 |
6307 |
0.4741 |
0.8671 |
0.8751 |
0.8671 |
0.8683 |
0.3445 |
16.9990 |
6701 |
0.3850 |
0.9037 |
0.9047 |
0.9037 |
0.9038 |
0.2777 |
17.9985 |
7095 |
0.4802 |
0.8834 |
0.8923 |
0.8834 |
0.8836 |
0.4406 |
18.9980 |
7489 |
0.4053 |
0.9047 |
0.9096 |
0.9047 |
0.9043 |
0.1707 |
20.0 |
7884 |
0.4434 |
0.9067 |
0.9129 |
0.9067 |
0.9069 |
0.2138 |
20.9995 |
8278 |
0.5051 |
0.9037 |
0.9155 |
0.9037 |
0.9053 |
0.1812 |
21.9990 |
8672 |
0.4238 |
0.8955 |
0.9007 |
0.8955 |
0.8953 |
0.3639 |
22.9985 |
9066 |
0.4021 |
0.9138 |
0.9182 |
0.9138 |
0.9143 |
0.3193 |
23.9980 |
9460 |
0.4989 |
0.9168 |
0.9209 |
0.9168 |
0.9166 |
0.2067 |
24.9873 |
9850 |
0.4959 |
0.8976 |
0.9032 |
0.8976 |
0.8975 |
📄 ライセンス
このモデルは、Apache 2.0ライセンスの下で提供されています。