🚀 Wav2Vec2-2-Bart-Large-Tedlium
本模型是一個序列到序列(seq2seq)模型,在 TEDLIUM 語料庫(第3版)上進行訓練。它將語音編碼器與文本解碼器相結合,以執行自動語音識別任務。編碼器的權重使用來自 @facebook 的 Wav2Vec2 LV-60k 檢查點進行初始化。解碼器的權重使用來自 @facebook 的 Bart large 檢查點進行初始化。
使用該模型時,請確保您的語音輸入採樣率為 16Khz。該模型在開發集上的詞錯誤率(WER)為 9.0%,在測試集上為 6.4%。訓練日誌 記錄了 50k 步微調過程中的訓練和評估進度。
🚀 快速開始
本模型可作為獨立的聲學模型用於轉錄音頻文件,使用方法如下:
💻 使用示例
基礎用法
from transformers import AutoProcessor, SpeechEncoderDecoderModel
from datasets import load_dataset
import torch
processor = AutoProcessor.from_pretrained("sanchit-gandhi/wav2vec2-2-bart-large-tedlium")
model = SpeechEncoderDecoderModel.from_pretrained("sanchit-gandhi/wav2vec2-2-bart-large-tedlium")
ds = load_dataset("sanchit-gandhi/tedlium_dummy", split="validation")
input_values = processor(ds[0]["audio"]["array"], return_tensors="pt", padding="longest").input_values
generated = model.generate(input_values)
decoded = processor.batch_decode(generated, skip_special_tokens=True)
print("Target: ", ds["text"][0])
print("Transcription: ", decoded[0])
📚 詳細文檔
評估
以下代碼片段展示瞭如何在 TEDLIUM 測試數據上評估 Wav2Vec2-Large-Tedlium 模型:
from datasets import load_dataset
from transformers import AutoProcessor, SpeechEncoderDecoderModel
import torch
from jiwer import wer
tedlium_eval = load_dataset("LIUM/tedlium", "release3", split="test")
def filter_ds(text):
return text != "ignore_time_segment_in_scoring"
tedlium_eval = tedlium_eval.map(filter_ds, input_columns=["text"])
model = SpeechEncoderDecoderModel.from_pretrained("sanchit-gandhi/wav2vec2-2-bart-large-tedlium").to("cuda")
processor = AutoProcessor.from_pretrained("sanchit-gandhi/wav2vec2-2-bart-large-tedlium")
gen_kwargs = {
"max_length": 200,
"num_beams": 5,
"length_penalty": 1.2
}
def map_to_pred(batch):
input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest").input_values
with torch.no_grad():
generated = model.generate(input_values.to("cuda"), **gen_kwargs)
decoded = processor.batch_decode(generated, skip_special_tokens=True)
batch["transcription"] = decoded[0]
return batch
result = tedlium_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["speech"])
print("WER:", wer(result["text"], result["transcription"]))
📄 許可證
本項目採用 CC BY 4.0 許可證。
📊 模型指標
屬性 |
詳情 |
模型類型 |
序列到序列(seq2seq)模型 |
訓練數據 |
TEDLIUM 語料庫(第3版) |
開發集詞錯誤率(WER) |
9.0% |
測試集詞錯誤率(WER) |
6.4% |
⚠️ 重要提示
使用該模型時,請確保您的語音輸入採樣率為 16Khz。