🚀 语音情感识别模型
该模型是基于 facebook/wav2vec2-xls-r-300m 微调而来,用于语音情感识别(SER)任务。它能够对俄语语音中的情感进行分类,为语音交互场景提供情感分析支持。
🚀 快速开始
本模型可用于语音情感识别任务,下面将介绍如何使用它。
✨ 主要特性
📦 安装指南
文档未提及具体安装步骤,可参考 transformers
库的安装方法来安装所需依赖。
💻 使用示例
基础用法
from transformers.pipelines import pipeline
pipe = pipeline(model="KELONMYOSA/wav2vec2-xls-r-300m-emotion-ru", trust_remote_code=True)
result = pipe("speech.wav")
print(result)
高级用法
import librosa
import torch
import torch.nn.functional as F
from transformers import AutoConfig, Wav2Vec2Processor, AutoModelForAudioClassification
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name_or_path = "KELONMYOSA/wav2vec2-xls-r-300m-emotion-ru"
config = AutoConfig.from_pretrained(model_name_or_path)
processor = Wav2Vec2Processor.from_pretrained(model_name_or_path)
sampling_rate = processor.feature_extractor.sampling_rate
model = AutoModelForAudioClassification.from_pretrained(model_name_or_path, trust_remote_code=True).to(device)
def predict(path):
speech, sr = librosa.load(path, sr=sampling_rate)
features = processor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
input_values = features.input_values.to(device)
attention_mask = features.attention_mask.to(device)
with torch.no_grad():
logits = model(input_values, attention_mask=attention_mask).logits
scores = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
outputs = [{"label": config.id2label[i], "score": round(score, 5)} for i, score in
enumerate(scores)]
return outputs
print(predict("speech.wav"))
运行上述代码后,输出示例如下:
[{'label': 'neutral', 'score': 0.00318}, {'label': 'positive', 'score': 0.00376}, {'label': 'sad', 'score': 0.00145}, {'label': 'angry', 'score': 0.98984}, {'label': 'other', 'score': 0.00176}]
📚 详细文档
数据集
用于微调预训练模型的数据集是 DUSHA 数据集。该数据集包含约 125,000 条俄语音频记录,涵盖了与虚拟助手对话中常见的四种基本情感:快乐(积极)、悲伤、愤怒和中性情感。
情感标签
emotions = ['neutral', 'positive', 'angry', 'sad', 'other']
🔧 技术细节
评估结果
该模型在评估中取得了以下结果:
- 训练损失:0.528700
- 验证损失:0.349617
- 准确率:0.901369
情感类型 |
精确率 |
召回率 |
F1 分数 |
样本数 |
中性 |
0.92 |
0.94 |
0.93 |
15886 |
积极 |
0.85 |
0.79 |
0.82 |
2481 |
悲伤 |
0.77 |
0.82 |
0.79 |
2506 |
愤怒 |
0.89 |
0.83 |
0.86 |
3072 |
其他 |
0.99 |
0.74 |
0.85 |
226 |
准确率 |
|
|
0.90 |
24171 |
宏平均 |
0.89 |
0.82 |
0.85 |
24171 |
加权平均 |
0.90 |
0.90 |
0.90 |
24171 |
📄 许可证
本项目采用 Apache-2.0 许可证。