🚀 Wav2Vec2-2-Bart-Large-Tedlium
本模型是一个序列到序列(seq2seq)模型,在 TEDLIUM 语料库(第3版)上进行训练。它将语音编码器与文本解码器相结合,以执行自动语音识别任务。编码器的权重使用来自 @facebook 的 Wav2Vec2 LV-60k 检查点进行初始化。解码器的权重使用来自 @facebook 的 Bart large 检查点进行初始化。
使用该模型时,请确保您的语音输入采样率为 16Khz。该模型在开发集上的词错误率(WER)为 9.0%,在测试集上为 6.4%。训练日志 记录了 50k 步微调过程中的训练和评估进度。
🚀 快速开始
本模型可作为独立的声学模型用于转录音频文件,使用方法如下:
💻 使用示例
基础用法
from transformers import AutoProcessor, SpeechEncoderDecoderModel
from datasets import load_dataset
import torch
processor = AutoProcessor.from_pretrained("sanchit-gandhi/wav2vec2-2-bart-large-tedlium")
model = SpeechEncoderDecoderModel.from_pretrained("sanchit-gandhi/wav2vec2-2-bart-large-tedlium")
ds = load_dataset("sanchit-gandhi/tedlium_dummy", split="validation")
input_values = processor(ds[0]["audio"]["array"], return_tensors="pt", padding="longest").input_values
generated = model.generate(input_values)
decoded = processor.batch_decode(generated, skip_special_tokens=True)
print("Target: ", ds["text"][0])
print("Transcription: ", decoded[0])
📚 详细文档
评估
以下代码片段展示了如何在 TEDLIUM 测试数据上评估 Wav2Vec2-Large-Tedlium 模型:
from datasets import load_dataset
from transformers import AutoProcessor, SpeechEncoderDecoderModel
import torch
from jiwer import wer
tedlium_eval = load_dataset("LIUM/tedlium", "release3", split="test")
def filter_ds(text):
return text != "ignore_time_segment_in_scoring"
tedlium_eval = tedlium_eval.map(filter_ds, input_columns=["text"])
model = SpeechEncoderDecoderModel.from_pretrained("sanchit-gandhi/wav2vec2-2-bart-large-tedlium").to("cuda")
processor = AutoProcessor.from_pretrained("sanchit-gandhi/wav2vec2-2-bart-large-tedlium")
gen_kwargs = {
"max_length": 200,
"num_beams": 5,
"length_penalty": 1.2
}
def map_to_pred(batch):
input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest").input_values
with torch.no_grad():
generated = model.generate(input_values.to("cuda"), **gen_kwargs)
decoded = processor.batch_decode(generated, skip_special_tokens=True)
batch["transcription"] = decoded[0]
return batch
result = tedlium_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["speech"])
print("WER:", wer(result["text"], result["transcription"]))
📄 许可证
本项目采用 CC BY 4.0 许可证。
📊 模型指标
属性 |
详情 |
模型类型 |
序列到序列(seq2seq)模型 |
训练数据 |
TEDLIUM 语料库(第3版) |
开发集词错误率(WER) |
9.0% |
测试集词错误率(WER) |
6.4% |
⚠️ 重要提示
使用该模型时,请确保您的语音输入采样率为 16Khz。