🚀 語音情感識別模型
該模型是基於 facebook/wav2vec2-xls-r-300m 微調而來,用於語音情感識別(SER)任務。它能夠對俄語語音中的情感進行分類,為語音交互場景提供情感分析支持。
🚀 快速開始
本模型可用於語音情感識別任務,下面將介紹如何使用它。
✨ 主要特性
📦 安裝指南
文檔未提及具體安裝步驟,可參考 transformers
庫的安裝方法來安裝所需依賴。
💻 使用示例
基礎用法
from transformers.pipelines import pipeline
pipe = pipeline(model="KELONMYOSA/wav2vec2-xls-r-300m-emotion-ru", trust_remote_code=True)
result = pipe("speech.wav")
print(result)
高級用法
import librosa
import torch
import torch.nn.functional as F
from transformers import AutoConfig, Wav2Vec2Processor, AutoModelForAudioClassification
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name_or_path = "KELONMYOSA/wav2vec2-xls-r-300m-emotion-ru"
config = AutoConfig.from_pretrained(model_name_or_path)
processor = Wav2Vec2Processor.from_pretrained(model_name_or_path)
sampling_rate = processor.feature_extractor.sampling_rate
model = AutoModelForAudioClassification.from_pretrained(model_name_or_path, trust_remote_code=True).to(device)
def predict(path):
speech, sr = librosa.load(path, sr=sampling_rate)
features = processor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
input_values = features.input_values.to(device)
attention_mask = features.attention_mask.to(device)
with torch.no_grad():
logits = model(input_values, attention_mask=attention_mask).logits
scores = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
outputs = [{"label": config.id2label[i], "score": round(score, 5)} for i, score in
enumerate(scores)]
return outputs
print(predict("speech.wav"))
運行上述代碼後,輸出示例如下:
[{'label': 'neutral', 'score': 0.00318}, {'label': 'positive', 'score': 0.00376}, {'label': 'sad', 'score': 0.00145}, {'label': 'angry', 'score': 0.98984}, {'label': 'other', 'score': 0.00176}]
📚 詳細文檔
數據集
用於微調預訓練模型的數據集是 DUSHA 數據集。該數據集包含約 125,000 條俄語音頻記錄,涵蓋了與虛擬助手對話中常見的四種基本情感:快樂(積極)、悲傷、憤怒和中性情感。
情感標籤
emotions = ['neutral', 'positive', 'angry', 'sad', 'other']
🔧 技術細節
評估結果
該模型在評估中取得了以下結果:
- 訓練損失:0.528700
- 驗證損失:0.349617
- 準確率:0.901369
情感類型 |
精確率 |
召回率 |
F1 分數 |
樣本數 |
中性 |
0.92 |
0.94 |
0.93 |
15886 |
積極 |
0.85 |
0.79 |
0.82 |
2481 |
悲傷 |
0.77 |
0.82 |
0.79 |
2506 |
憤怒 |
0.89 |
0.83 |
0.86 |
3072 |
其他 |
0.99 |
0.74 |
0.85 |
226 |
準確率 |
|
|
0.90 |
24171 |
宏平均 |
0.89 |
0.82 |
0.85 |
24171 |
加權平均 |
0.90 |
0.90 |
0.90 |
24171 |
📄 許可證
本項目採用 Apache-2.0 許可證。