模型简介
模型特点
模型能力
使用案例
🚀 Wav2Vec2-Large-XLSR-53-rw
该模型基于facebook/wav2vec2-large-xlsr-53在卢旺达语上进行微调,使用了Common Voice数据集,能有效进行语音识别,还可尝试预测缩写中的撇号。
🚀 快速开始
本模型是在卢旺达语上对 facebook/wav2vec2-large-xlsr-53 进行微调得到的。训练使用了 Common Voice 数据集约 25% 的训练数据(仅限于无反对票且时长少于 9.5 秒的语音),并在验证集的 2048 条语音上进行了验证。与 lucio/wav2vec2-large-xlsr-kinyarwanda 模型不同,该模型不预测任何标点符号,而本模型尝试预测用于标记代词与以元音开头的单词缩写的撇号,但可能存在过度泛化的情况。 使用此模型时,请确保语音输入的采样率为 16kHz。
✨ 主要特性
- 语言微调:针对卢旺达语进行了微调,更适合该语言的语音识别任务。
- 标点预测:尝试预测缩写中的撇号,增强了识别结果的完整性。
📦 安装指南
文档未提及安装相关内容,故跳过此章节。
💻 使用示例
基础用法
模型可以直接使用(无需语言模型),示例代码如下:
import torch
import torchaudio
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
# WARNING! This will download and extract to use about 80GB on disk.
test_dataset = load_dataset("common_voice", "rw", split="test[:2%]")
processor = Wav2Vec2Processor.from_pretrained("lucio/wav2vec2-large-xlsr-kinyarwanda")
model = Wav2Vec2ForCTC.from_pretrained("lucio/wav2vec2-large-xlsr-kinyarwanda")
resampler = torchaudio.transforms.Resample(48_000, 16_000)
# Preprocessing the datasets.
# We need to read the audio files as arrays
def speech_file_to_array_fn(batch):
speech_array, sampling_rate = torchaudio.load(batch["path"])
batch["speech"] = resampler(speech_array).squeeze().numpy()
return batch
test_dataset = test_dataset.map(speech_file_to_array_fn)
inputs = processor(test_dataset[:2]["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
predicted_ids = torch.argmax(logits, dim=-1)
print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset["sentence"][:2])
运行结果:
Prediction: ['yaherukago gukora igitaramo yiki mujyiwa na mor mu bubiligi', "ibi rero ntibizashoboka kandi n'umudabizi"]
Reference: ['Yaherukaga gukora igitaramo nk’iki mu Mujyi wa Namur mu Bubiligi.', 'Ibi rero, ntibizashoboka, kandi nawe arabizi.']
📚 详细文档
评估
可以在 Common Voice 的卢旺达语测试数据上对模型进行评估。需要注意的是,仅加载测试数据就需要下载并解压整个 40GB 的卢旺达语数据集到另一个 40GB 的目录,因此磁盘需要有足够的空间(例如,在 Google Colab 的免费版本中无法完成)。此脚本使用了 pcuenq 的 chunked_wer
函数。
import jiwer
import torch
import torchaudio
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import re
import unidecode
test_dataset = load_dataset("common_voice", "rw", split="test")
wer = load_metric("wer")
processor = Wav2Vec2Processor.from_pretrained("lucio/wav2vec2-large-xlsr-kinyarwanda-apostrophied")
model = Wav2Vec2ForCTC.from_pretrained("lucio/wav2vec2-large-xlsr-kinyarwanda-apostrophied")
model.to("cuda")
chars_to_ignore_regex = r'[!"#$%&()*+,./:;<=>?@\[\]\\_{}|~£¤¨©ª«¬®¯°·¸»¼½¾ðʺ˜˝ˮ‐–—―‚“”„‟•…″‽₋€™−√�]'
def remove_special_characters(batch):
batch["text"] = re.sub(r'[ʻʽʼ‘’´`]', r"'", batch["sentence"]) # normalize apostrophes
batch["text"] = re.sub(chars_to_ignore_regex, "", batch["text"]).lower().strip() # remove all other punctuation
batch["text"] = re.sub(r"([b-df-hj-np-tv-z])' ([aeiou])", r"\1'\2", batch["text"]) # remove spaces where apostrophe marks a deleted vowel
batch["text"] = re.sub(r"(-| '|' | +)", " ", batch["text"]) # treat dash and other apostrophes as word boundary
batch["text"] = unidecode.unidecode(batch["text"]) # strip accents from loanwords
return batch
## Audio pre-processing
resampler = torchaudio.transforms.Resample(48_000, 16_000)
def speech_file_to_array_fn(batch):
speech_array, sampling_rate = torchaudio.load(batch["path"])
batch["speech"] = resampler(speech_array).squeeze().numpy()
batch["sampling_rate"] = 16_000
return batch
def cv_prepare(batch):
batch = remove_special_characters(batch)
batch = speech_file_to_array_fn(batch)
return batch
test_dataset = test_dataset.map(cv_prepare)
# Preprocessing the datasets.
# We need to read the audio files as arrays
def evaluate(batch):
inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
pred_ids = torch.argmax(logits, dim=-1)
batch["pred_strings"] = processor.batch_decode(pred_ids)
return batch
result = test_dataset.map(evaluate, batched=True, batch_size=8)
def chunked_wer(targets, predictions, chunk_size=None):
if chunk_size is None: return jiwer.wer(targets, predictions)
start = 0
end = chunk_size
H, S, D, I = 0, 0, 0, 0
while start < len(targets):
chunk_metrics = jiwer.compute_measures(targets[start:end], predictions[start:end])
H = H + chunk_metrics["hits"]
S = S + chunk_metrics["substitutions"]
D = D + chunk_metrics["deletions"]
I = I + chunk_metrics["insertions"]
start += chunk_size
end += chunk_size
return float(S + D + I) / float(H + S + D)
print("WER: {:2f}".format(100 * chunked_wer(result["sentence"], result["pred_strings"], chunk_size=4000)))
测试结果:39.92 %
训练
训练使用了 Common Voice 训练数据集中过滤掉有 down_vote
或时长超过 9.5 秒的语音后的示例。总共使用了约 12.5 万个示例,占可用数据的 25%,在 OVHcloud 提供的 1 个 V100 GPU 上进行了约 60 小时的训练:在 3.2 万个示例的一个批次上训练 20 个 epoch,然后在另外 3 个 3.2 万个示例的批次上各训练 10 个 epoch。验证使用了验证数据集的 2048 个示例。
训练脚本 改编自 transformers 仓库中的示例脚本。
📄 许可证
本项目采用 Apache-2.0 许可证。
属性 | 详情 |
---|---|
模型类型 | 基于 XLSR 的 Wav2Vec2 语音识别模型 |
训练数据 | Common Voice 数据集(过滤后约 12.5 万个示例,占可用数据的 25%) |
评估指标 | 词错误率(WER),测试结果为 39.92% |
许可证 | Apache-2.0 |



