🚀 Data2Vec-Audio-Base-960h
基于Facebook的Data2Vec框架,在960小时的Librispeech语音音频数据上进行预训练和微调的基础模型,可用于自动语音识别任务。
🚀 快速开始
本模型是在16kHz采样的语音音频上,基于960小时的Librispeech数据进行预训练和微调的基础模型。使用该模型时,请确保输入的语音也采样为16kHz。
论文链接:Paper
作者:Alexei Baevski, Wei-Ning Hsu, Qiantong Xu, Arun Babu, Jiatao Gu, Michael Auli
摘要
虽然自监督学习的总体思路在不同模态之间是相同的,但实际的算法和目标却有很大差异,因为它们是针对单一模态开发的。为了更接近通用的自监督学习,我们提出了data2vec框架,该框架对语音、自然语言处理或计算机视觉使用相同的学习方法。其核心思想是在使用标准Transformer架构的自蒸馏设置中,基于输入的掩码视图来预测完整输入数据的潜在表示。与预测特定模态的目标(如单词、视觉标记或人类语音单元,这些本质上是局部的)不同,data2vec预测包含整个输入信息的上下文潜在表示。在语音识别、图像分类和自然语言理解的主要基准测试上的实验表明,该方法达到了新的技术水平,或与主流方法具有竞争力。
原始模型可在 https://github.com/pytorch/fairseq/tree/main/examples/data2vec 找到。
✨ 主要特性
- 多模态通用框架:使用相同的学习方法应用于语音、NLP和计算机视觉领域。
- 上下文潜在表示预测:预测包含整个输入信息的上下文潜在表示,而非局部特定模态目标。
- 优秀的实验表现:在语音识别、图像分类和自然语言理解的主要基准测试中达到新的技术水平或具有竞争力。
📦 安装指南
文档中未提及具体安装步骤,可参考原始模型仓库 https://github.com/pytorch/fairseq/tree/main/examples/data2vec 进行安装。
💻 使用示例
基础用法
以下代码展示了如何将该模型作为独立的声学模型来转录音频文件:
from transformers import Wav2Vec2Processor, Data2VecForCTC
from datasets import load_dataset
import torch
processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h")
model = Data2VecForCTC.from_pretrained("facebook/data2vec-audio-base-960h")
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
input_values = processor(ds[0]["audio"]["array"], return_tensors="pt", padding="longest").input_values
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
高级用法
以下代码展示了如何在LibriSpeech的“clean”和“other”测试数据上评估 facebook/data2vec-audio-base-960h 模型:
from transformers import Wav2Vec2Processor, Data2VecForCTC
from datasets import load_dataset
import torch
from jiwer import wer
processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h").to("cuda")
model = Data2VecForCTC.from_pretrained("facebook/data2vec-audio-base-960h")
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")
def map_to_pred(batch):
input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest").input_values
with torch.no_grad():
logits = model(input_values.to("cuda")).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
batch["transcription"] = transcription
return batch
result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])
print("WER:", wer(result["text"], result["transcription"]))
评估结果
数据集类型 |
字错率 (WER) |
"clean" |
2.77 |
"other" |
7.08 |
📚 详细文档
预训练方法

更多信息,请参考 官方论文。
📄 许可证
本项目采用Apache-2.0许可证。
🔍 其他信息