🚀 SEW-tiny
SEW-tiny是基于16kHz采样语音音频预训练的基础模型。该模型可用于自动语音识别、说话人识别、意图分类、情感识别等下游任务。使用模型时,请确保输入的语音也是16kHz采样的。
🚀 快速开始
本模型基于 SEW by ASAPP Research 开发。相关论文为 Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition,作者包括 Felix Wu、Kwangyoun Kim、Jing Pan、Kyu Han、Kilian Q. Weinberger、Yoav Artzi。
摘要:本文研究了自动语音识别(ASR)预训练模型中的性能 - 效率权衡问题。聚焦于wav2vec 2.0,我们对影响模型性能和效率的几种架构设计进行了形式化。综合所有观察结果,我们推出了SEW(Squeezed and Efficient Wav2vec),这是一种预训练模型架构,在各种训练设置下,在性能和效率方面都有显著提升。例如,在LibriSpeech的100h - 960h半监督设置下,与wav2vec 2.0相比,SEW的推理速度提高了1.9倍,单词错误率相对降低了13.5%。在推理时间相近的情况下,SEW在不同模型规模下将单词错误率降低了25 - 50%。
原始模型可在 https://github.com/asappresearch/sew#model-checkpoints 找到。
✨ 主要特性
- 适用音频格式:适用于16kHz采样的语音音频。
- 下游任务广泛:可用于自动语音识别、说话人识别、意图分类、情感识别等下游任务。
- 性能与效率提升:在性能和效率方面相较于wav2vec 2.0有显著提升。
📦 安装指南
文档中未提及具体安装步骤,可参考相关依赖库的安装方式,如 transformers
、datasets
、soundfile
、torch
、jiwer
等。
💻 使用示例
基础用法
以下代码展示了如何将该模型作为独立的声学模型来转录音频文件:
from transformers import Wav2Vec2Processor, SEWForCTC
from datasets import load_dataset
import soundfile as sf
import torch
processor = Wav2Vec2Processor.from_pretrained("asapp/sew-tiny-100k-ft-ls100h")
model = SEWForCTC.from_pretrained("asapp/sew-tiny-100k-ft-ls100h")
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
input_values = processor(ds[0]["audio"]["array"], return_tensors="pt").input_values
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
高级用法
以下代码展示了如何在LibriSpeech的“clean”和“other”测试数据上评估 asapp/sew-tiny-100k-ft-ls100h 模型:
from datasets import load_dataset
from transformers import SEWForCTC, Wav2Vec2Processor
import torch
from jiwer import wer
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")
model = SEWForCTC.from_pretrained("asapp/sew-tiny-100k-ft-ls100h").to("cuda")
processor = Wav2Vec2Processor.from_pretrained("asapp/sew-tiny-100k-ft-ls100h")
def map_to_pred(batch):
input_values = processor(batch["audio"][0]["array"], sampling_rate=16000,
return_tensors="pt", padding="longest").input_values
with torch.no_grad():
logits = model(input_values.to("cuda")).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
batch["transcription"] = transcription
return batch
result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])
print("WER:", wer(result["text"], result["transcription"]))
📚 详细文档
- 数据集:使用了
librispeech_asr
数据集。
- 评估指标:使用单词错误率(WER)进行评估。
🔧 技术细节
本文聚焦于wav2vec 2.0,对影响模型性能和效率的几种架构设计进行了形式化。推出的SEW模型在各种训练设置下,在性能和效率方面都有显著提升。例如,在LibriSpeech的100h - 960h半监督设置下,与wav2vec 2.0相比,SEW的推理速度提高了1.9倍,单词错误率相对降低了13.5%。在推理时间相近的情况下,SEW在不同模型规模下将单词错误率降低了25 - 50%。
📄 许可证
本项目采用 apache-2.0
许可证。
相关信息表格
属性 |
详情 |
模型类型 |
SEW-tiny |
训练数据 |
librispeech_asr |
标签 |
音频、语音、自动语音识别、hf-asr-leaderboard |
许可证 |
apache-2.0 |
模型评估结果
数据集 |
测试WER |
LibriSpeech (clean) |
10.61 |
LibriSpeech (other) |
23.74 |