🚀 Wav2vec 2.0 搭配巴西葡萄牙語開放數據集 v2
本項目展示了一個針對巴西葡萄牙語微調的 Wav2vec 模型,使用了以下數據集進行訓練:
- CETUC:包含約 145 小時的巴西葡萄牙語語音,分佈在 50 名男性和 50 名女性說話者中,每人朗讀約 1000 個從 CETEN - Folha 語料庫中挑選的語音平衡句子。
- 多語言 Librispeech (MLS):一個包含多種語言的大型數據集。MLS 基於 LibriVox 等公共領域的有聲讀物錄音。該數據集總共包含 6000 小時多種語言的轉錄數據。本項目中使用的葡萄牙語數據集(主要是巴西變體)點擊此處,約有 284 小時的語音,來自 62 位朗讀者朗讀的 55 本有聲讀物。
- VoxForge:一個旨在構建聲學模型開放數據集的項目。該語料庫包含約 100 位說話者的 4130 條巴西葡萄牙語語音,採樣率從 16kHz 到 44.1kHz 不等。
- Common Voice 6.1:由 Mozilla 基金會發起的項目,旨在創建多種語言的廣泛開放數據集以訓練自動語音識別(ASR)模型。在這個項目中,志願者通過官方網站捐贈和驗證語音數據。本項目使用的葡萄牙語數據集(主要是巴西變體)是 6.1 版本(pt_63h_2020 - 12 - 11),包含約 50 小時經過驗證的語音和 1120 位獨特的說話者。
- [Lapsbm](https://github.com/falabrasil/gitlab - resources):“Falabrasil - UFPA”是 Fala Brasil 團隊用於評估巴西葡萄牙語 ASR 系統的數據集。包含 35 位說話者(其中 10 位女性),每人朗讀 20 條獨特的句子,總共 700 條巴西葡萄牙語語音。音頻以 22.05kHz 採樣,未進行環境控制。
這些數據集被合併以構建一個更大的巴西葡萄牙語數據集。除了 Common Voice 的開發集和測試集分別用於驗證和測試外,所有數據都用於訓練。
原始模型使用 fairseq 進行微調。本項目使用的是原始模型的轉換版本。
⚠️ 重要提示
Common Voice 測試報告的字錯率(WER)為 10%,然而,該模型使用了 Common Voice 中除測試集實例之外的所有驗證實例進行訓練。這意味著訓練集中的一些說話者可能出現在測試集中。
🚀 快速開始
安裝依賴
%%capture
!pip install datasets
!pip install jiwer
!pip install torchaudio
!pip install transformers
!pip install soundfile
導入必要的庫
import torchaudio
from datasets import load_dataset, load_metric
from transformers import (
Wav2Vec2ForCTC,
Wav2Vec2Processor,
)
import torch
import re
import sys
準備工作
chars_to_ignore_regex = '[\,\?\.\!\;\:\"]'
wer = load_metric("wer")
device = "cuda"
加載模型和處理器
model_name = 'lgris/wav2vec2-large-xlsr-open-brazilian-portuguese-v2'
model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
processor = Wav2Vec2Processor.from_pretrained(model_name)
定義預測函數
def map_to_pred(batch):
features = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0], padding=True, return_tensors="pt")
input_values = features.input_values.to(device)
attention_mask = features.attention_mask.to(device)
with torch.no_grad():
logits = model(input_values, attention_mask=attention_mask).logits
pred_ids = torch.argmax(logits, dim=-1)
batch["predicted"] = processor.batch_decode(pred_ids)
batch["predicted"] = [pred.lower() for pred in batch["predicted"]]
batch["target"] = batch["sentence"]
return batch
💻 使用示例
針對 Common Voice 進行測試(領域內)
dataset = load_dataset("common_voice", "pt", split="test", data_dir="./cv-corpus-6.1-2020-12-11")
resampler = torchaudio.transforms.Resample(orig_freq=48_000, new_freq=16_000)
def map_to_array(batch):
speech, _ = torchaudio.load(batch["path"])
batch["speech"] = resampler.forward(speech.squeeze(0)).numpy()
batch["sampling_rate"] = resampler.new_freq
batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower().replace("’", "'")
return batch
ds = dataset.map(map_to_array)
result = ds.map(map_to_pred, batched=True, batch_size=1, remove_columns=list(ds.features.keys()))
print(wer.compute(predictions=result["predicted"], references=result["target"]))
for pred, target in zip(result["predicted"][:10], result["target"][:10]):
print(pred, "|", target)
結果:10.69%
針對 TEDx 進行測試(領域外)
!gdown --id 1HJEnvthaGYwcV_whHEywgH2daIN4bQna
!tar -xf tedx.tar.gz
dataset = load_dataset('csv', data_files={'test': 'test.csv'})['test']
def map_to_array(batch):
speech, _ = torchaudio.load(batch["path"])
batch["speech"] = speech.squeeze(0).numpy()
batch["sampling_rate"] = resampler.new_freq
batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower().replace("’", "'")
return batch
ds = dataset.map(map_to_array)
result = ds.map(map_to_pred, batched=True, batch_size=1, remove_columns=list(ds.features.keys()))
print(wer.compute(predictions=result["predicted"], references=result["target"]))
for pred, target in zip(result["predicted"][:10], result["target"][:10]):
print(pred, "|", target)
結果:34.53%
📚 詳細文檔
模型信息
屬性 |
詳情 |
模型類型 |
wav2vec2 - large - xlsr - open - brazilian - portuguese - v2 |
訓練數據 |
Common Voice、Multilingual Librispeech (MLS)、CETUC、Lapsbm、VoxForge |
評估指標 |
字錯率(WER) |
標籤信息
- 音頻
- 語音
- wav2vec2
- 葡萄牙語
- 葡萄牙語語音語料庫
- 自動語音識別
- PyTorch
- hf - asr - leaderboard
模型索引
模型名稱為 wav2vec2 - large - xlsr - open - brazilian - portuguese - v2,在 Common Voice 數據集上的測試字錯率(WER)為 10.69%。
📄 許可證
本項目採用 Apache 2.0 許可證。