🚀 Data2Vec-Audio-Large-960h
Data2Vec-Audio-Large-960h是一個在16kHz採樣的語音音頻上,基於960小時的Librispeech數據集進行預訓練和微調的大型模型。使用該模型時,請確保輸入的語音也採樣為16kHz。
🚀 快速開始
本模型源自Facebook的Data2Vec,相關研究論文可查看Paper 。作者包括Alexei Baevski、Wei-Ning Hsu、Qiantong Xu、Arun Babu、Jiatao Gu和Michael Auli。
摘要
雖然自監督學習的總體思路在不同模態之間是相同的,但實際的算法和目標卻大相徑庭,因為它們是針對單一模態開發的。為了更接近通用的自監督學習,我們提出了data2vec框架,該框架對語音、自然語言處理或計算機視覺使用相同的學習方法。其核心思想是在使用標準Transformer架構的自蒸餾設置中,基於輸入的掩碼視圖來預測整個輸入數據的潛在表示。與預測特定模態的目標(如單詞、視覺標記或人類語音單元,這些本質上是局部的)不同,data2vec預測包含整個輸入信息的上下文潛在表示。在語音識別、圖像分類和自然語言理解等主要基準測試上的實驗表明,該方法達到了新的技術水平,或者與主流方法具有競爭力。
原始模型可在此處找到。
✨ 主要特性
- 多模態適用性:使用相同的學習方法適用於語音、NLP或計算機視覺。
- 高性能表現:在語音識別、圖像分類和自然語言理解的主要基準測試中達到新的技術水平或具有競爭力。
📦 安裝指南
文檔未提及具體安裝步驟,可參考原始模型倉庫https://github.com/pytorch/fairseq/tree/main/examples/data2vec 進行安裝。
💻 使用示例
基礎用法
以下代碼展示瞭如何將該模型作為獨立的聲學模型來轉錄音頻文件:
from transformers import Wav2Vec2Processor, Data2VecAudioForCTC
from datasets import load_dataset
import torch
processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-large-960h")
model = Data2VecAudioForCTC.from_pretrained("facebook/data2vec-audio-large-960h")
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
input_values = processor(ds[0]["audio"]["array"], return_tensors="pt", padding="longest").input_values
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
高級用法
以下代碼片段展示瞭如何在LibriSpeech的“clean”和“other”測試數據上評估 facebook/data2vec-audio-large-960h 模型:
from transformers import Wav2Vec2Processor, Data2VecAudioForCTC
from datasets import load_dataset
import torch
from jiwer import wer
processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-large-960h").to("cuda")
model = Data2VecAudioForCTC.from_pretrained("facebook/data2vec-audio-large-960h")
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")
def map_to_pred(batch):
input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest").input_values
with torch.no_grad():
logits = model(input_values.to("cuda")).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
batch["transcription"] = transcription
return batch
result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])
print("WER:", wer(result["text"], result["transcription"]))
評估結果(字錯誤率WER):
"clean" |
"other" |
1.89 |
4.07 |
🔧 技術細節
預訓練方法

更多信息請查看官方論文。
📄 許可證
本項目採用Apache 2.0許可證。