🚀 Wav2Vec2-Base-100h
基於100小時Librispeech數據集預訓練和微調的語音識別基礎模型,可有效處理16kHz採樣的語音音頻。
🚀 快速開始
本模型是在16kHz採樣的語音音頻上,基於100小時的Librispeech數據集進行預訓練和微調的基礎模型。在使用該模型時,請確保輸入的語音也採樣為16kHz。
✨ 主要特性
- 首次證明了從純語音音頻中學習強大的表徵,然後在轉錄語音上進行微調,能夠在概念上更簡單的同時,超越最佳的半監督方法。
- wav2vec 2.0在潛在空間中對語音輸入進行掩碼處理,並解決了一個基於聯合學習的潛在表徵量化定義的對比任務。
- 使用Librispeech的所有標註數據進行實驗,在乾淨/其他測試集上實現了1.8/3.3的字錯率(WER)。
- 當將標註數據量減少到一小時時,wav2vec 2.0在100小時子集上的表現優於之前的最優方法,同時使用的標註數據少了100倍。
- 僅使用十分鐘的標註數據,並在53000小時的未標註數據上進行預訓練,仍然可以實現4.8/8.2的字錯率(WER),證明了在有限標註數據下進行語音識別的可行性。
📚 詳細文檔
論文信息
- 論文鏈接
- 作者:Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli
摘要
我們首次證明了,從純語音音頻中學習強大的表徵,然後在轉錄語音上進行微調,能夠在概念上更簡單的同時,超越最佳的半監督方法。wav2vec 2.0在潛在空間中對語音輸入進行掩碼處理,並解決了一個基於聯合學習的潛在表徵量化定義的對比任務。使用Librispeech的所有標註數據進行實驗,在乾淨/其他測試集上實現了1.8/3.3的字錯率(WER)。當將標註數據量減少到一小時時,wav2vec 2.0在100小時子集上的表現優於之前的最優方法,同時使用的標註數據少了100倍。僅使用十分鐘的標註數據,並在53000小時的未標註數據上進行預訓練,仍然可以實現4.8/8.2的字錯率(WER)。這證明了在有限標註數據下進行語音識別的可行性。
原始模型
原始模型可在此處找到。
💻 使用示例
基礎用法
以下代碼展示瞭如何將該模型作為獨立的聲學模型來轉錄音頻文件:
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from datasets import load_dataset
import soundfile as sf
import torch
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-100h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-100h")
def map_to_array(batch):
speech, _ = sf.read(batch["file"])
batch["speech"] = speech
return batch
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
ds = ds.map(map_to_array)
input_values = processor(ds[0]["audio"]["array"], return_tensors="pt", padding="longest").input_values
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
高級用法
以下代碼展示瞭如何在LibriSpeech的“乾淨”和“其他”測試數據上評估 facebook/wav2vec2-base-100h 模型:
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import soundfile as sf
import torch
from jiwer import wer
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-100h").to("cuda")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-100h")
def map_to_pred(batch):
input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest").input_values
with torch.no_grad():
logits = model(input_values.to("cuda")).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
batch["transcription"] = transcription
return batch
result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["speech"])
print("WER:", wer(result["text"], result["transcription"]))
評估結果
📄 許可證
本項目採用Apache-2.0許可證。
📦 數據集與標籤
屬性 |
詳情 |
數據集 |
librispeech_asr |
標籤 |
音頻、自動語音識別 |