🚀 Data2Vec-Audio-Large-960h
Data2Vec-Audio-Large-960h是一个在16kHz采样的语音音频上,基于960小时的Librispeech数据集进行预训练和微调的大型模型。使用该模型时,请确保输入的语音也采样为16kHz。
🚀 快速开始
本模型源自Facebook的Data2Vec,相关研究论文可查看Paper 。作者包括Alexei Baevski、Wei-Ning Hsu、Qiantong Xu、Arun Babu、Jiatao Gu和Michael Auli。
摘要
虽然自监督学习的总体思路在不同模态之间是相同的,但实际的算法和目标却大相径庭,因为它们是针对单一模态开发的。为了更接近通用的自监督学习,我们提出了data2vec框架,该框架对语音、自然语言处理或计算机视觉使用相同的学习方法。其核心思想是在使用标准Transformer架构的自蒸馏设置中,基于输入的掩码视图来预测整个输入数据的潜在表示。与预测特定模态的目标(如单词、视觉标记或人类语音单元,这些本质上是局部的)不同,data2vec预测包含整个输入信息的上下文潜在表示。在语音识别、图像分类和自然语言理解等主要基准测试上的实验表明,该方法达到了新的技术水平,或者与主流方法具有竞争力。
原始模型可在此处找到。
✨ 主要特性
- 多模态适用性:使用相同的学习方法适用于语音、NLP或计算机视觉。
- 高性能表现:在语音识别、图像分类和自然语言理解的主要基准测试中达到新的技术水平或具有竞争力。
📦 安装指南
文档未提及具体安装步骤,可参考原始模型仓库https://github.com/pytorch/fairseq/tree/main/examples/data2vec 进行安装。
💻 使用示例
基础用法
以下代码展示了如何将该模型作为独立的声学模型来转录音频文件:
from transformers import Wav2Vec2Processor, Data2VecAudioForCTC
from datasets import load_dataset
import torch
processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-large-960h")
model = Data2VecAudioForCTC.from_pretrained("facebook/data2vec-audio-large-960h")
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
input_values = processor(ds[0]["audio"]["array"], return_tensors="pt", padding="longest").input_values
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
高级用法
以下代码片段展示了如何在LibriSpeech的“clean”和“other”测试数据上评估 facebook/data2vec-audio-large-960h 模型:
from transformers import Wav2Vec2Processor, Data2VecAudioForCTC
from datasets import load_dataset
import torch
from jiwer import wer
processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-large-960h").to("cuda")
model = Data2VecAudioForCTC.from_pretrained("facebook/data2vec-audio-large-960h")
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")
def map_to_pred(batch):
input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest").input_values
with torch.no_grad():
logits = model(input_values.to("cuda")).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
batch["transcription"] = transcription
return batch
result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])
print("WER:", wer(result["text"], result["transcription"]))
评估结果(字错误率WER):
"clean" |
"other" |
1.89 |
4.07 |
🔧 技术细节
预训练方法

更多信息请查看官方论文。
📄 许可证
本项目采用Apache 2.0许可证。