voxforge1-xlsr: wav2vec 2.0 with VoxForge Dataset
This project demonstrates a wav2vec 2.0 model fine-tuned for Brazilian Portuguese on the VoxForge dataset. The accompanying notebook evaluates the model on several publicly available Brazilian Portuguese speech datasets.
Features
- Brazilian Portuguese ASR: transcribes Brazilian Portuguese speech to text.
- Cross-dataset evaluation: tested on seven Brazilian Portuguese datasets (CETUC, Common Voice, LaPS BM, MLS, Multilingual TEDx (Portuguese), SID, and VoxForge).
- Language model integration: optionally decodes with a 4-gram language model, which lowers the word error rate on every test set (see Summary of Results).
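When a language model is supplied, the STT class (see Technical Details) hands decoding to pyctcdecode's beam search. Decoders of this kind usually rank candidate transcriptions with a shallow-fusion score like the one below, where x is the audio, y a candidate transcription, α weights the language model and β rewards each emitted word. This is a general sketch of the technique, not a statement of pyctcdecode's exact internals; STT does not pass alpha or beta to build_ctcdecoder, so the library defaults apply.

```latex
\hat{y} = \arg\max_{y} \Big[ \log P_{\text{CTC}}(y \mid x) + \alpha \log P_{\text{LM}}(y) + \beta \, |y| \Big]
```

Here |y| is the number of words in y. Without the β term, every extra word adds a negative log-probability from the language model, which would push the search toward overly short outputs.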
Installation
Install the required Python libraries, then download the evaluation datasets with gdown. The commands are notebook (Jupyter/Colab) cells; the leading ! runs each line as a shell command.
```
!pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 torchaudio===0.8.2 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
!pip install datasets
!pip install jiwer
!pip install transformers
!pip install soundfile
!pip install pyctcdecode
!pip install https://github.com/kpu/kenlm/archive/master.zip
!gdown --id 1HFECzIizf-bmkQRLiQD0QVqcGtOG5upI
!mkdir bp_dataset
!unzip bp_dataset -d bp_dataset/
!gdown --id 1GJIKseP5ZkTbllQVgOL98R4yYAcIySFP
```
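Before running the evaluation, it can help to confirm that the packages above are importable. A minimal standard-library sketch; the module list and the missing_modules helper are illustrative and not part of the original notebook (note that import names can differ from pip names, e.g. the KenLM archive provides the kenlm module).

```python
import importlib.util

# Import names of the packages installed above (illustrative list)
REQUIRED_MODULES = ["torch", "torchvision", "torchaudio", "datasets", "jiwer",
                    "transformers", "soundfile", "pyctcdecode", "kenlm"]


def missing_modules(names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(REQUIRED_MODULES)
print("missing:", missing if missing else "none")
```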
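Usage Examples
The examples use the STT class and the load_data and calc_metrics helpers defined under Technical Details.
Basic Usage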
```python
MODEL_NAME = "lgris/voxforge1-xlsr"
stt = STT(MODEL_NAME)  # no language model: greedy (argmax) CTC decoding
ds = load_data('cetuc_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("CETUC WER:", wer)
```
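Without a language model, STT picks the most likely token in each frame (argmax) and the processor's batch_decode turns those ids into text. Conceptually, greedy CTC decoding merges consecutive repeated ids and then drops the blank token. A minimal pure-Python sketch, assuming a made-up five-token vocabulary in which "<pad>" plays the role of the CTC blank (as the pad token does for Hugging Face wav2vec 2.0 CTC models) and "|" marks word boundaries; VOCAB and greedy_ctc_decode are hypothetical.

```python
from itertools import groupby

# Toy vocabulary (hypothetical): "<pad>" doubles as the CTC blank,
# "|" is the word delimiter.
VOCAB = {0: "<pad>", 1: "|", 2: "o", 3: "l", 4: "a"}
BLANK_ID = 0


def greedy_ctc_decode(frame_ids, vocab=VOCAB, blank_id=BLANK_ID):
    """Collapse repeated ids, drop blanks, and turn '|' into spaces."""
    collapsed = [token_id for token_id, _ in groupby(frame_ids)]
    chars = [vocab[t] for t in collapsed if t != blank_id]
    return "".join(chars).replace("|", " ").strip()


# Per-frame argmax ids for a short utterance
frames = [2, 2, 0, 3, 4, 4, 1, 1, 2, 0, 3, 3, 4]
print(greedy_ctc_decode(frames))  # prints: ola ola
```

The blank is what lets CTC emit a genuinely doubled letter: [3, 0, 3] decodes to "ll", while [3, 3] collapses to a single "l".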
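Advanced Usage
Passing the path of an ARPA language model file through the lm argument makes STT decode with a pyctcdecode beam-search decoder backed by KenLM instead of greedy argmax decoding.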
```python
# Decode with the 4-gram language model (path to an ARPA file)
stt = STT(MODEL_NAME, lm='pt-BR-wiki.word.4-gram.arpa')
ds = load_data('cetuc_dataset')
result = ds.map(stt.batch_predict, batched=True, batch_size=8)
wer, mer, wil = calc_metrics(result["sentence"], result["predicted"])
print("CETUC WER:", wer)
```
Documentation
Dataset Information
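| Dataset | Train | Valid | Test |
|---|---|---|---|
| CETUC | | -- | 5.4h |
| Common Voice | | -- | 9.5h |
| LaPS BM | | -- | 0.1h |
| MLS | | -- | 3.7h |
| Multilingual TEDx (Portuguese) | | -- | 1.8h |
| SID | | -- | 1.0h |
| VoxForge | 3.9h | -- | 0.1h |
| Total | 3.9h | -- | 21.6h |

Durations are in hours. Only VoxForge contributes training data; the other datasets are used for testing only.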
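As a quick consistency check on the table, the per-dataset test durations add up to the reported 21.6h total. A short sketch; the TEST_HOURS mapping simply copies the Test column.

```python
# Test-set durations in hours, copied from the Dataset Information table
TEST_HOURS = {
    "CETUC": 5.4, "Common Voice": 9.5, "LaPS BM": 0.1, "MLS": 3.7,
    "Multilingual TEDx (Portuguese)": 1.8, "SID": 1.0, "VoxForge": 0.1,
}

# Round to one decimal to remove floating-point noise from the sum
total = round(sum(TEST_HOURS.values()), 1)
print(f"Total test audio: {total}h")  # prints: Total test audio: 21.6h
```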
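Summary of Results
Word error rate (WER, lower is better) on each test set. CV = Common Voice, LaPS = LaPS BM, TEDx = Multilingual TEDx (Portuguese), VF = VoxForge, AVG = average over the seven test sets.

| Model | CETUC | CV | LaPS | MLS | SID | TEDx | VF | AVG |
|---|---|---|---|---|---|---|---|---|
| voxforge_1 | 0.468 | 0.608 | 0.503 | 0.505 | 0.717 | 0.731 | 0.561 | 0.584 |
| voxforge_1 + 4-gram | 0.322 | 0.471 | 0.356 | 0.378 | 0.586 | 0.637 | 0.428 | 0.454 |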
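To express the language model's effect as a relative improvement, one can compute (WER without LM - WER with LM) / WER without LM for each column. A short sketch over the reported numbers; the lists copy the table and relative_reduction is an illustrative helper.

```python
# WER values copied from the Summary of Results table
DATASETS = ["CETUC", "CV", "LaPS", "MLS", "SID", "TEDx", "VF", "AVG"]
NO_LM = [0.468, 0.608, 0.503, 0.505, 0.717, 0.731, 0.561, 0.584]
WITH_LM = [0.322, 0.471, 0.356, 0.378, 0.586, 0.637, 0.428, 0.454]


def relative_reduction(baseline, improved):
    """Fraction of the baseline error removed: (baseline - improved) / baseline."""
    return (baseline - improved) / baseline


reductions = {name: relative_reduction(b, w)
              for name, b, w in zip(DATASETS, NO_LM, WITH_LM)}
for name, r in reductions.items():
    print(f"{name:>5}: {r:.1%}")
```

By this measure the 4-gram model removes roughly 22% of the average error, with the largest relative gain on CETUC (about 31%) and the smallest on TEDx (about 13%).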
Technical Details
Model Class
```python
import torch
from pyctcdecode import build_ctcdecoder
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor


class STT:
    """Wraps a wav2vec 2.0 CTC model with optional KenLM beam-search decoding."""

    def __init__(self,
                 model_name,
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 lm=None):
        self.model_name = model_name
        self.model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
        self.processor = Wav2Vec2Processor.from_pretrained(model_name)
        # Order the vocabulary by token id (so label i matches logit column i)
        # and lowercase the tokens
        self.vocab_dict = self.processor.tokenizer.get_vocab()
        self.sorted_dict = {
            k.lower(): v for k, v in sorted(self.vocab_dict.items(),
                                            key=lambda item: item[1])
        }
        self.device = device
        self.lm = lm
        if self.lm:
            # Beam-search decoder that scores hypotheses with the ARPA language model
            self.lm_decoder = build_ctcdecoder(
                list(self.sorted_dict.keys()),
                self.lm
            )

    def batch_predict(self, batch):
        # Normalize and pad the raw waveforms into a model-ready batch
        features = self.processor(batch["speech"],
                                  sampling_rate=batch["sampling_rate"][0],
                                  padding=True,
                                  return_tensors="pt")
        input_values = features.input_values.to(self.device)
        attention_mask = features.attention_mask.to(self.device)
        with torch.no_grad():
            logits = self.model(input_values, attention_mask=attention_mask).logits
        if self.lm:
            # pyctcdecode works on NumPy arrays, one utterance at a time
            logits = logits.cpu().numpy()
            batch["predicted"] = []
            for sample_logits in logits:
                batch["predicted"].append(self.lm_decoder.decode(sample_logits))
        else:
            # Greedy decoding: most likely token per frame, then CTC collapse
            pred_ids = torch.argmax(logits, dim=-1)
            batch["predicted"] = self.processor.batch_decode(pred_ids)
        return batch
```
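STT builds the decoder's label list by sorting the tokenizer vocabulary by token id, because pyctcdecode expects the i-th label to correspond to the i-th column of the logits. A toy illustration of that step; the vocabulary below is made up, not the model's actual vocabulary.

```python
# Hypothetical tokenizer vocabulary: token -> id, in arbitrary dict order
vocab_dict = {"A": 3, "<pad>": 0, "|": 2, "<unk>": 1, "B": 4}

# Same expression as in STT.__init__: order by id, then lowercase the tokens
sorted_dict = {
    k.lower(): v for k, v in sorted(vocab_dict.items(), key=lambda item: item[1])
}
labels = list(sorted_dict.keys())
print(labels)  # prints: ['<pad>', '<unk>', '|', 'a', 'b']
```

One caveat of this design: if the vocabulary held two tokens differing only in case (say "A" and "a"), lowercasing would merge them into one key and shift every later label out of alignment with the logits. The sketch assumes a single-case character vocabulary.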
Helper Functions
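```python
import re
import torchaudio

# Punctuation stripped from the reference transcripts before scoring
chars_to_ignore_regex = r'[\,\?\.\!\;\:\"]'

def map_to_array(batch):
    speech, _ = torchaudio.load(batch["path"])
    batch["speech"] = speech.squeeze(0).numpy()
    # Assumes the audio is already 16 kHz; no resampling is done here
    batch["sampling_rate"] = 16_000
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower().replace("’", "'")
    batch["target"] = batch["sentence"]
    return batch
```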
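map_to_array normalizes each reference transcript before scoring: it strips the punctuation listed in chars_to_ignore_regex, lowercases the text, and turns typographic apostrophes into plain ones. The same steps pulled into a standalone function; normalize_transcript is a hypothetical name, not one of the original helpers.

```python
import re

# Same punctuation class as chars_to_ignore_regex in the helpers
CHARS_TO_IGNORE = r'[\,\?\.\!\;\:\"]'


def normalize_transcript(text):
    """Strip the listed punctuation, lowercase, and map the curly apostrophe to '."""
    return re.sub(CHARS_TO_IGNORE, '', text).lower().replace("\u2019", "'")


print(normalize_transcript('Olá, mundo! "Tudo bem?"'))  # prints: olá mundo tudo bem
```

Characters outside the class, such as hyphens and accented letters, are kept, so "D\u2019Água" becomes "d'água".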
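```python
import jiwer


def calc_metrics(truths, hypos):
    """Return the mean per-utterance WER, MER and WIL."""
    wers = []
    mers = []
    wils = []
    for t, h in zip(truths, hypos):
        try:
            wers.append(jiwer.wer(t, h))
            mers.append(jiwer.mer(t, h))
            wils.append(jiwer.wil(t, h))
        except Exception:
            # Skip pairs jiwer cannot score (e.g. an empty reference string)
            pass
    wer = sum(wers) / len(wers)
    mer = sum(mers) / len(mers)
    wil = sum(wils) / len(wils)
    return wer, mer, wil
```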
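calc_metrics averages per-utterance scores, so a one-word utterance weighs as much as a long one. Corpus-level WER, another common convention, divides the total number of word errors by the total number of reference words. The sketch below computes both with a plain word-level edit distance so the difference is visible; it is an illustration independent of jiwer, not the method behind the reported results, and the helper names are hypothetical.

```python
def word_errors(reference, hypothesis):
    """Minimum substitutions + deletions + insertions between word sequences."""
    ref, hyp = reference.split(), hypothesis.split()
    # prev[j] = edit distance between the processed ref prefix and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(prev[j] + 1,             # deletion
                          curr[j - 1] + 1,         # insertion
                          prev[j - 1] + (r != h))  # substitution or match
        prev = curr
    return prev[-1]


def mean_utterance_wer(truths, hypos):
    """Average of per-utterance WERs (the convention calc_metrics follows).

    Assumes every reference has at least one word.
    """
    scores = [word_errors(t, h) / len(t.split()) for t, h in zip(truths, hypos)]
    return sum(scores) / len(scores)


def corpus_wer(truths, hypos):
    """Total word errors divided by total reference words."""
    errors = sum(word_errors(t, h) for t, h in zip(truths, hypos))
    words = sum(len(t.split()) for t in truths)
    return errors / words


truths = ["sim", "o gato subiu no telhado ontem"]
hypos = ["não", "o gato subiu no telhado ontem"]
print(f"{mean_utterance_wer(truths, hypos):.3f}")  # prints: 0.500
print(f"{corpus_wer(truths, hypos):.3f}")          # prints: 0.143
```

The single wrong word in the one-word utterance counts as 100% WER for that utterance, so it dominates the per-utterance mean but contributes only 1 error in 7 words at corpus level.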
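```python
from datasets import load_dataset

def load_data(dataset):
    # Expects <dataset>/test.csv with "path" and "sentence" columns (used by map_to_array)
    data_files = {'test': f'{dataset}/test.csv'}
    dataset = load_dataset('csv', data_files=data_files)["test"]
    return dataset.map(map_to_array)
```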
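load_data reads the test split from a CSV file, and map_to_array then needs a path column (the audio file) and a sentence column (the reference transcript). If you prepare your own test set, a small standard-library check of that layout might look like this; check_test_csv and the sample rows are hypothetical.

```python
import csv
import io

REQUIRED_COLUMNS = {"path", "sentence"}  # columns read by map_to_array


def check_test_csv(handle):
    """Return (row count, sorted list of missing columns) for a test.csv-style file."""
    reader = csv.DictReader(handle)
    missing = REQUIRED_COLUMNS - set(reader.fieldnames or [])
    rows = sum(1 for _ in reader)
    return rows, sorted(missing)


sample = io.StringIO(
    "path,sentence\n"
    "audio/0001.wav,bom dia\n"
    "audio/0002.wav,\"olá, tudo bem?\"\n"
)
print(check_test_csv(sample))  # prints: (2, [])
```

Quoting the transcript, as in the second row, keeps commas inside a sentence from being read as column separators.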
License
This project is licensed under the Apache-2.0 license.