🚀 Wav2vec 2.0 搭配巴西葡萄牙语开放数据集 v2
本项目展示了一个针对巴西葡萄牙语微调的 Wav2vec 模型,使用了以下数据集进行训练:
- CETUC:包含约 145 小时的巴西葡萄牙语语音,分布在 50 名男性和 50 名女性说话者中,每人朗读约 1000 个从 CETEN - Folha 语料库中挑选的语音平衡句子。
- 多语言 Librispeech (MLS):一个包含多种语言的大型数据集。MLS 基于 LibriVox 等公共领域的有声读物录音。该数据集总共包含 6000 小时多种语言的转录数据。本项目中使用的葡萄牙语数据集(主要是巴西变体)点击此处,约有 284 小时的语音,来自 62 位朗读者朗读的 55 本有声读物。
- VoxForge:一个旨在构建声学模型开放数据集的项目。该语料库包含约 100 位说话者的 4130 条巴西葡萄牙语语音,采样率从 16kHz 到 44.1kHz 不等。
- Common Voice 6.1:由 Mozilla 基金会发起的项目,旨在创建多种语言的广泛开放数据集以训练自动语音识别(ASR)模型。在这个项目中,志愿者通过官方网站捐赠和验证语音数据。本项目使用的葡萄牙语数据集(主要是巴西变体)是 6.1 版本(pt_63h_2020 - 12 - 11),包含约 50 小时经过验证的语音和 1120 位独特的说话者。
- [Lapsbm](https://github.com/falabrasil/gitlab - resources):“Falabrasil - UFPA”是 Fala Brasil 团队用于评估巴西葡萄牙语 ASR 系统的数据集。包含 35 位说话者(其中 10 位女性),每人朗读 20 条独特的句子,总共 700 条巴西葡萄牙语语音。音频以 22.05kHz 采样,未进行环境控制。
这些数据集被合并以构建一个更大的巴西葡萄牙语数据集。除了 Common Voice 的开发集和测试集分别用于验证和测试外,所有数据都用于训练。
原始模型使用 fairseq 进行微调。本项目使用的是原始模型的转换版本。
⚠️ 重要提示
Common Voice 测试报告的字错率(WER)为 10%,然而,该模型使用了 Common Voice 中除测试集实例之外的所有验证实例进行训练。这意味着训练集中的一些说话者可能出现在测试集中。
🚀 快速开始
安装依赖
%%capture
!pip install datasets
!pip install jiwer
!pip install torchaudio
!pip install transformers
!pip install soundfile
导入必要的库
import torchaudio
from datasets import load_dataset, load_metric
from transformers import (
Wav2Vec2ForCTC,
Wav2Vec2Processor,
)
import torch
import re
import sys
准备工作
chars_to_ignore_regex = '[\,\?\.\!\;\:\"]'
wer = load_metric("wer")
device = "cuda"
加载模型和处理器
model_name = 'lgris/wav2vec2-large-xlsr-open-brazilian-portuguese-v2'
model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
processor = Wav2Vec2Processor.from_pretrained(model_name)
定义预测函数
def map_to_pred(batch):
features = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0], padding=True, return_tensors="pt")
input_values = features.input_values.to(device)
attention_mask = features.attention_mask.to(device)
with torch.no_grad():
logits = model(input_values, attention_mask=attention_mask).logits
pred_ids = torch.argmax(logits, dim=-1)
batch["predicted"] = processor.batch_decode(pred_ids)
batch["predicted"] = [pred.lower() for pred in batch["predicted"]]
batch["target"] = batch["sentence"]
return batch
💻 使用示例
针对 Common Voice 进行测试(领域内)
dataset = load_dataset("common_voice", "pt", split="test", data_dir="./cv-corpus-6.1-2020-12-11")
resampler = torchaudio.transforms.Resample(orig_freq=48_000, new_freq=16_000)
def map_to_array(batch):
speech, _ = torchaudio.load(batch["path"])
batch["speech"] = resampler.forward(speech.squeeze(0)).numpy()
batch["sampling_rate"] = resampler.new_freq
batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower().replace("’", "'")
return batch
ds = dataset.map(map_to_array)
result = ds.map(map_to_pred, batched=True, batch_size=1, remove_columns=list(ds.features.keys()))
print(wer.compute(predictions=result["predicted"], references=result["target"]))
for pred, target in zip(result["predicted"][:10], result["target"][:10]):
print(pred, "|", target)
结果:10.69%
针对 TEDx 进行测试(领域外)
!gdown --id 1HJEnvthaGYwcV_whHEywgH2daIN4bQna
!tar -xf tedx.tar.gz
dataset = load_dataset('csv', data_files={'test': 'test.csv'})['test']
def map_to_array(batch):
speech, _ = torchaudio.load(batch["path"])
batch["speech"] = speech.squeeze(0).numpy()
batch["sampling_rate"] = resampler.new_freq
batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower().replace("’", "'")
return batch
ds = dataset.map(map_to_array)
result = ds.map(map_to_pred, batched=True, batch_size=1, remove_columns=list(ds.features.keys()))
print(wer.compute(predictions=result["predicted"], references=result["target"]))
for pred, target in zip(result["predicted"][:10], result["target"][:10]):
print(pred, "|", target)
结果:34.53%
📚 详细文档
模型信息
属性 |
详情 |
模型类型 |
wav2vec2 - large - xlsr - open - brazilian - portuguese - v2 |
训练数据 |
Common Voice、Multilingual Librispeech (MLS)、CETUC、Lapsbm、VoxForge |
评估指标 |
字错率(WER) |
标签信息
- 音频
- 语音
- wav2vec2
- 葡萄牙语
- 葡萄牙语语音语料库
- 自动语音识别
- PyTorch
- hf - asr - leaderboard
模型索引
模型名称为 wav2vec2 - large - xlsr - open - brazilian - portuguese - v2,在 Common Voice 数据集上的测试字错率(WER)为 10.69%。
📄 许可证
本项目采用 Apache 2.0 许可证。