🚀 SEW-D-mid
SEW-D-mid 是基于 16kHz 采样语音音频预训练的基础模型。该模型由 ASAPP Research 开发,相关信息可参考 SEW-D by ASAPP Research。使用此模型时,需确保输入的语音也采样为 16kHz。请注意,该模型需要在下游任务(如自动语音识别、说话人识别、意图分类、情感识别等)上进行微调。
论文:Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition
作者:Felix Wu、Kwangyoun Kim、Jing Pan、Kyu Han、Kilian Q. Weinberger、Yoav Artzi
摘要
本文研究了自动语音识别(ASR)预训练模型中的性能 - 效率权衡问题。我们聚焦于 wav2vec 2.0,并对影响模型性能和效率的几种架构设计进行了形式化。综合所有观察结果,我们提出了 SEW(Squeezed and Efficient Wav2vec),这是一种在各种训练设置下,在性能和效率方面都有显著改进的预训练模型架构。例如,在 LibriSpeech 的 100h - 960h 半监督设置下,与 wav2vec 2.0 相比,SEW 的推理速度提高了 1.9 倍,词错误率相对降低了 13.5%。在推理时间相近的情况下,SEW 在不同模型大小下将词错误率降低了 25 - 50%。
原始模型可在 https://github.com/asappresearch/sew#model-checkpoints 找到。
🚀 快速开始
模型信息
属性 |
详情 |
模型类型 |
语音处理模型 |
训练数据 |
LibriSpeech 数据集 |
标签 |
音频、语音、自动语音识别、HF 自动语音识别排行榜 |
许可证 |
Apache-2.0 |
示例音频
评估结果
数据集 |
测试词错误率 (WER) |
LibriSpeech (clean) |
4.94 |
LibriSpeech (other) |
11.51 |
✨ 主要特性
- 基于 16kHz 采样语音音频进行预训练。
- 可应用于多种下游任务,如自动语音识别、说话人识别等。
- 在性能和效率方面有显著改进。
📦 安装指南
文档未提及安装步骤,暂不提供。
💻 使用示例
基础用法
以下代码展示了如何将该模型作为独立的声学模型来转录音频文件:
from transformers import Wav2Vec2Processor, SEWDForCTC
from datasets import load_dataset
import soundfile as sf
import torch
processor = Wav2Vec2Processor.from_pretrained("asapp/sew-d-mid-400k-ft-ls100h")
model = SEWDForCTC.from_pretrained("asapp/sew-d-mid-400k-ft-ls100h")
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
input_values = processor(ds[0]["audio"]["array"], return_tensors="pt").input_values
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
高级用法
以下代码展示了如何在 LibriSpeech 的 “clean” 和 “other” 测试数据上评估 asapp/sew-d-mid-400k-ft-ls100hh 模型:
from datasets import load_dataset
from transformers import SEWDForCTC, Wav2Vec2Processor
import torch
from jiwer import wer
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")
model = SEWDForCTC.from_pretrained("asapp/sew-d-mid-400k-ft-ls100h").to("cuda")
processor = Wav2Vec2Processor.from_pretrained("asapp/sew-d-mid-400k-ft-ls100h")
def map_to_pred(batch):
input_values = processor(batch["audio"][0]["array"], sampling_rate=16000,
return_tensors="pt", padding="longest").input_values
with torch.no_grad():
logits = model(input_values.to("cuda")).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
batch["transcription"] = transcription
return batch
result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])
print("WER:", wer(result["text"], result["transcription"]))
📚 详细文档
注意事项
⚠️ 重要提示
使用该模型时,确保输入的语音采样率为 16kHz。同时,该模型需要在下游任务上进行微调。
论文引用
📄 许可证
本项目采用 Apache-2.0 许可证。