模型概述
模型特點
模型能力
使用案例
🚀 Wav2Vec2-Large-XLSR-53-rw
該模型基於facebook/wav2vec2-large-xlsr-53在盧旺達語上進行微調,使用了Common Voice數據集,能有效進行語音識別,還可嘗試預測縮寫中的撇號。
🚀 快速開始
本模型是在盧旺達語上對 facebook/wav2vec2-large-xlsr-53 進行微調得到的。訓練使用了 Common Voice 數據集約 25% 的訓練數據(僅限於無反對票且時長少於 9.5 秒的語音),並在驗證集的 2048 條語音上進行了驗證。與 lucio/wav2vec2-large-xlsr-kinyarwanda 模型不同,該模型不預測任何標點符號,而本模型嘗試預測用於標記代詞與以元音開頭的單詞縮寫的撇號,但可能存在過度泛化的情況。 使用此模型時,請確保語音輸入的採樣率為 16kHz。
✨ 主要特性
- 語言微調:針對盧旺達語進行了微調,更適合該語言的語音識別任務。
- 標點預測:嘗試預測縮寫中的撇號,增強了識別結果的完整性。
📦 安裝指南
文檔未提及安裝相關內容,故跳過此章節。
💻 使用示例
基礎用法
模型可以直接使用(無需語言模型),示例代碼如下:
import torch
import torchaudio
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
# WARNING! This will download and extract to use about 80GB on disk.
test_dataset = load_dataset("common_voice", "rw", split="test[:2%]")
processor = Wav2Vec2Processor.from_pretrained("lucio/wav2vec2-large-xlsr-kinyarwanda")
model = Wav2Vec2ForCTC.from_pretrained("lucio/wav2vec2-large-xlsr-kinyarwanda")
resampler = torchaudio.transforms.Resample(48_000, 16_000)
# Preprocessing the datasets.
# We need to read the audio files as arrays
def speech_file_to_array_fn(batch):
speech_array, sampling_rate = torchaudio.load(batch["path"])
batch["speech"] = resampler(speech_array).squeeze().numpy()
return batch
test_dataset = test_dataset.map(speech_file_to_array_fn)
inputs = processor(test_dataset[:2]["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
predicted_ids = torch.argmax(logits, dim=-1)
print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset["sentence"][:2])
運行結果:
Prediction: ['yaherukago gukora igitaramo yiki mujyiwa na mor mu bubiligi', "ibi rero ntibizashoboka kandi n'umudabizi"]
Reference: ['Yaherukaga gukora igitaramo nk’iki mu Mujyi wa Namur mu Bubiligi.', 'Ibi rero, ntibizashoboka, kandi nawe arabizi.']
📚 詳細文檔
評估
可以在 Common Voice 的盧旺達語測試數據上對模型進行評估。需要注意的是,僅加載測試數據就需要下載並解壓整個 40GB 的盧旺達語數據集到另一個 40GB 的目錄,因此磁盤需要有足夠的空間(例如,在 Google Colab 的免費版本中無法完成)。此腳本使用了 pcuenq 的 chunked_wer
函數。
import jiwer
import torch
import torchaudio
from datasets import load_dataset, load_metric
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import re
import unidecode
test_dataset = load_dataset("common_voice", "rw", split="test")
wer = load_metric("wer")
processor = Wav2Vec2Processor.from_pretrained("lucio/wav2vec2-large-xlsr-kinyarwanda-apostrophied")
model = Wav2Vec2ForCTC.from_pretrained("lucio/wav2vec2-large-xlsr-kinyarwanda-apostrophied")
model.to("cuda")
chars_to_ignore_regex = r'[!"#$%&()*+,./:;<=>?@\[\]\\_{}|~£¤¨©ª«¬®¯°·¸»¼½¾ðʺ˜˝ˮ‐–—―‚“”„‟•…″‽₋€™−√�]'
def remove_special_characters(batch):
batch["text"] = re.sub(r'[ʻʽʼ‘’´`]', r"'", batch["sentence"]) # normalize apostrophes
batch["text"] = re.sub(chars_to_ignore_regex, "", batch["text"]).lower().strip() # remove all other punctuation
batch["text"] = re.sub(r"([b-df-hj-np-tv-z])' ([aeiou])", r"\1'\2", batch["text"]) # remove spaces where apostrophe marks a deleted vowel
batch["text"] = re.sub(r"(-| '|' | +)", " ", batch["text"]) # treat dash and other apostrophes as word boundary
batch["text"] = unidecode.unidecode(batch["text"]) # strip accents from loanwords
return batch
## Audio pre-processing
resampler = torchaudio.transforms.Resample(48_000, 16_000)
def speech_file_to_array_fn(batch):
speech_array, sampling_rate = torchaudio.load(batch["path"])
batch["speech"] = resampler(speech_array).squeeze().numpy()
batch["sampling_rate"] = 16_000
return batch
def cv_prepare(batch):
batch = remove_special_characters(batch)
batch = speech_file_to_array_fn(batch)
return batch
test_dataset = test_dataset.map(cv_prepare)
# Preprocessing the datasets.
# We need to read the audio files as arrays
def evaluate(batch):
inputs = processor(batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
pred_ids = torch.argmax(logits, dim=-1)
batch["pred_strings"] = processor.batch_decode(pred_ids)
return batch
result = test_dataset.map(evaluate, batched=True, batch_size=8)
def chunked_wer(targets, predictions, chunk_size=None):
if chunk_size is None: return jiwer.wer(targets, predictions)
start = 0
end = chunk_size
H, S, D, I = 0, 0, 0, 0
while start < len(targets):
chunk_metrics = jiwer.compute_measures(targets[start:end], predictions[start:end])
H = H + chunk_metrics["hits"]
S = S + chunk_metrics["substitutions"]
D = D + chunk_metrics["deletions"]
I = I + chunk_metrics["insertions"]
start += chunk_size
end += chunk_size
return float(S + D + I) / float(H + S + D)
print("WER: {:2f}".format(100 * chunked_wer(result["sentence"], result["pred_strings"], chunk_size=4000)))
測試結果:39.92 %
訓練
訓練使用了 Common Voice 訓練數據集中過濾掉有 down_vote
或時長超過 9.5 秒的語音後的示例。總共使用了約 12.5 萬個示例,佔可用數據的 25%,在 OVHcloud 提供的 1 個 V100 GPU 上進行了約 60 小時的訓練:在 3.2 萬個示例的一個批次上訓練 20 個 epoch,然後在另外 3 個 3.2 萬個示例的批次上各訓練 10 個 epoch。驗證使用了驗證數據集的 2048 個示例。
訓練腳本 改編自 transformers 倉庫中的示例腳本。
📄 許可證
本項目採用 Apache-2.0 許可證。
屬性 | 詳情 |
---|---|
模型類型 | 基於 XLSR 的 Wav2Vec2 語音識別模型 |
訓練數據 | Common Voice 數據集(過濾後約 12.5 萬個示例,佔可用數據的 25%) |
評估指標 | 詞錯誤率(WER),測試結果為 39.92% |
許可證 | Apache-2.0 |



