🚀 Wav2Vec2 LJSpeech Gruut
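Wav2Vec2 LJSpeech Gruut is an automatic speech recognition model based on the wav2vec 2.0 architecture. It is a version of Wav2Vec2-Base fine-tuned on the LJSpeech Phonemes dataset. Instead of predicting a sequence of words, it predicts a sequence of phonemes, e.g. ["h", "ɛ", "l", "ˈoʊ", "w", "ˈɚ", "l", "d"]. Its vocabulary contains the distinct International Phonetic Alphabet (IPA) phonemes found in gruut.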
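🚀 Quick Start
This model was trained with HuggingFace's PyTorch framework. All training was done on a Google Cloud Engine VM with a Tesla A100 GPU. All scripts needed for training can be found in the Files and versions tab, and the training metrics, logged via Tensorboard, can be viewed in the Training metrics tab.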
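✨ Key Features
- Built on the wav2vec 2.0 architecture, starting from the Wav2Vec2-Base checkpoint.
- Fine-tuned to predict phoneme sequences rather than words, which suits phoneme recognition; it reaches a 0.99% phoneme error rate (without stress) on the LJSpeech Phonemes test data.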
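📦 Installation
The model card gives no installation steps. Install the libraries that the usage example imports (transformers, librosa, torch, and datasets) by following the usual HuggingFace installation instructions for each.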
💻 Usage Example
Basic usage
from itertools import groupby

import librosa  # not used below; handy for loading your own audio files
import torch
from datasets import load_dataset
from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor


def decode_phonemes(
    ids: torch.Tensor, processor: Wav2Vec2Processor, ignore_stress: bool = False
) -> str:
    """CTC-like decoding. First removes consecutive duplicates, then removes special tokens."""
    # collapse runs of identical consecutive ids
    ids = [id_ for id_, _ in groupby(ids.tolist())]

    # drop special tokens (e.g. padding) and the word delimiter
    special_token_ids = processor.tokenizer.all_special_ids + [
        processor.tokenizer.word_delimiter_token_id
    ]
    phonemes = [processor.decode(id_) for id_ in ids if id_ not in special_token_ids]

    # join the phonemes with single spaces
    prediction = " ".join(phonemes)

    # optionally strip primary and secondary stress marks
    if ignore_stress:
        prediction = prediction.replace("ˈ", "").replace("ˌ", "")

    return prediction


checkpoint = "bookbot/wav2vec2-ljspeech-gruut"

model = AutoModelForCTC.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)
sr = processor.feature_extractor.sampling_rate

# load a dummy sample from LibriSpeech
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
audio_array = ds[0]["audio"]["array"]

# pass the sampling rate explicitly so a mismatch with the model is reported
inputs = processor(audio_array, sampling_rate=sr, return_tensors="pt", padding=True)

with torch.no_grad():
    logits = model(inputs["input_values"]).logits

predicted_ids = torch.argmax(logits, dim=-1)
prediction = decode_phonemes(predicted_ids[0], processor, ignore_stress=True)
print(prediction)
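The collapse-then-filter logic inside decode_phonemes can be tried without downloading the model. The following is a minimal sketch on plain Python lists. The toy vocabulary, the PAD_ID and DELIM_ID values, and the helper name greedy_ctc_decode are hypothetical and chosen only for illustration; the real ids come from the checkpoint's tokenizer.

```python
from itertools import groupby

# Hypothetical toy vocabulary, NOT the model's real token ids.
VOCAB = {0: "<pad>", 1: "|", 2: "h", 3: "ɛ", 4: "l", 5: "ˈoʊ"}
PAD_ID, DELIM_ID = 0, 1  # assumed ids for the padding token and the word delimiter


def greedy_ctc_decode(frame_ids, ignore_stress=False):
    """Collapse repeated frame ids, drop special ids, join the phonemes."""
    # 1. Merge runs of identical consecutive ids into one id.
    collapsed = [id_ for id_, _ in groupby(frame_ids)]
    # 2. Remove padding and word-delimiter ids.
    kept = [id_ for id_ in collapsed if id_ not in (PAD_ID, DELIM_ID)]
    # 3. Map ids to phonemes and join them with spaces.
    text = " ".join(VOCAB[id_] for id_ in kept)
    # 4. Optionally strip primary and secondary stress marks.
    if ignore_stress:
        text = text.replace("ˈ", "").replace("ˌ", "")
    return text


# Per-frame argmax ids; the padding id between the two "l" runs keeps both.
frames = [2, 2, 0, 3, 3, 4, 0, 4, 5, 5, 1]
print(greedy_ctc_decode(frames))
print(greedy_ctc_decode(frames, ignore_stress=True))
```

The order of the steps matters: because duplicates are merged before the padding id is removed, a padding frame between two identical phonemes is what lets a genuinely repeated phoneme survive decoding.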
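📚 Detailed Documentation
Model information
| Attribute | Details |
| --- | --- |
| Model type | wav2vec2-ljspeech-gruut |
| Training data | LJSpeech Phonemes dataset |
| Number of parameters | 94M |
| Architecture | wav2vec 2.0 |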
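Evaluation results
The model achieves the following results on evaluation:
| Dataset | Phoneme error rate (PER, without stress) | Character error rate (CER, without stress) |
| --- | --- | --- |
| LJSpeech Phonemes test data | 0.99% | 0.58% |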
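A phoneme error rate is conventionally the edit distance between the predicted and reference phoneme sequences divided by the number of reference phonemes. The sketch below illustrates that definition under this assumption; the helper names and the example prediction are hypothetical, and it is not the evaluation script used to produce the figures above.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two token sequences."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def error_rate(references, hypotheses):
    """Total edit distance divided by total reference length."""
    errors = sum(edit_distance(r, h) for r, h in zip(references, hypotheses))
    total = sum(len(r) for r in references)
    return errors / total


def strip_stress(phonemes):
    """Remove primary and secondary stress marks from each phoneme."""
    return [p.replace("ˈ", "").replace("ˌ", "") for p in phonemes]


ref = "h ɛ l ˈoʊ w ˈɚ l d".split()
hyp = "h ɛ l oʊ w ˈɚ d".split()  # made-up prediction: lost stress, dropped "l"

print(error_rate([ref], [hyp]))                              # with stress
print(error_rate([strip_stress(ref)], [strip_stress(hyp)]))  # without stress
```

Stripping stress marks first means that a prediction differing from the reference only in stress is not counted as an error, which is what the "without stress" qualifier in the table refers to. Applying the same distance to characters instead of phoneme tokens gives a CER-style figure.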
Training procedure
Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 0.0001
- train_batch_size: 16
- eval_batch_size: 8
- seed: 42
- gradient_accumulation_steps: 2
- total_train_batch_size: 32
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- lr_scheduler_warmup_steps: 1000
- num_epochs: 30.0
- mixed_precision_training: Native AMP
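Two of these values follow from the others, and the learning-rate schedule can be written out directly. The sketch below assumes the usual behaviour of a linear scheduler with warmup (a linear ramp from 0 to the peak rate, then a linear decay to 0 at the final step) and a single training device. The constant and function names are hypothetical, and the total step count is taken from the last row of the training results table.

```python
PEAK_LR = 1e-4        # learning_rate
WARMUP_STEPS = 1000   # lr_scheduler_warmup_steps
TOTAL_STEPS = 10440   # final step in the training results (30 epochs x 348 steps)


def linear_lr(step):
    """Learning rate at an optimizer step, assuming linear warmup then linear decay."""
    if step < WARMUP_STEPS:
        return PEAK_LR * step / WARMUP_STEPS
    remaining = max(0, TOTAL_STEPS - step)
    return PEAK_LR * remaining / (TOTAL_STEPS - WARMUP_STEPS)


# total_train_batch_size = train_batch_size * gradient_accumulation_steps
# (assumes one device; with more devices the product would be larger)
effective_batch_size = 16 * 2
print(effective_batch_size)

for step in (0, 500, 1000, 5720, 10440):
    print(step, linear_lr(step))
```

Under these assumptions, the warmup ends during the third epoch (step 1000 of 348 steps per epoch), so the first two epochs train entirely below the peak learning rate.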
Training results
| Training Loss | Epoch | Step | Validation Loss | Word Error Rate (WER) | Character Error Rate (CER) |
| --- | --- | --- | --- | --- | --- |
| No log | 1.0 | 348 | 2.2818 | 1.0 | 1.0 |
| 2.6692 | 2.0 | 696 | 0.2045 | 0.0527 | 0.0299 |
| 0.2225 | 3.0 | 1044 | 0.1162 | 0.0319 | 0.0189 |
| 0.2225 | 4.0 | 1392 | 0.0927 | 0.0235 | 0.0147 |
| 0.0868 | 5.0 | 1740 | 0.0797 | 0.0218 | 0.0143 |
| 0.0598 | 6.0 | 2088 | 0.0715 | 0.0197 | 0.0128 |
| 0.0598 | 7.0 | 2436 | 0.0652 | 0.0160 | 0.0103 |
| 0.0447 | 8.0 | 2784 | 0.0571 | 0.0152 | 0.0095 |
| 0.0368 | 9.0 | 3132 | 0.0608 | 0.0163 | 0.0112 |
| 0.0368 | 10.0 | 3480 | 0.0586 | 0.0137 | 0.0083 |
| 0.0303 | 11.0 | 3828 | 0.0641 | 0.0141 | 0.0085 |
| 0.0273 | 12.0 | 4176 | 0.0656 | 0.0131 | 0.0079 |
| 0.0232 | 13.0 | 4524 | 0.0690 | 0.0133 | 0.0082 |
| 0.0232 | 14.0 | 4872 | 0.0598 | 0.0128 | 0.0079 |
| 0.0189 | 15.0 | 5220 | 0.0671 | 0.0121 | 0.0074 |
| 0.017 | 16.0 | 5568 | 0.0654 | 0.0114 | 0.0069 |
| 0.017 | 17.0 | 5916 | 0.0751 | 0.0118 | 0.0073 |
| 0.0146 | 18.0 | 6264 | 0.0653 | 0.0112 | 0.0068 |
| 0.0127 | 19.0 | 6612 | 0.0682 | 0.0112 | 0.0069 |
| 0.0127 | 20.0 | 6960 | 0.0678 | 0.0114 | 0.0068 |
| 0.0114 | 21.0 | 7308 | 0.0656 | 0.0111 | 0.0066 |
| 0.0101 | 22.0 | 7656 | 0.0669 | 0.0109 | 0.0066 |
| 0.0092 | 23.0 | 8004 | 0.0677 | 0.0108 | 0.0065 |
| 0.0092 | 24.0 | 8352 | 0.0653 | 0.0104 | 0.0063 |
| 0.0088 | 25.0 | 8700 | 0.0673 | 0.0102 | 0.0063 |
| 0.0074 | 26.0 | 9048 | 0.0669 | 0.0105 | 0.0064 |
| 0.0074 | 27.0 | 9396 | 0.0707 | 0.0101 | 0.0061 |
| 0.0066 | 28.0 | 9744 | 0.0673 | 0.0100 | 0.0060 |
| 0.0058 | 29.0 | 10092 | 0.0689 | 0.0100 | 0.0059 |
| 0.0058 | 30.0 | 10440 | 0.0683 | 0.0099 | 0.0058 |
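Disclaimer
Consider the biases that may come from the pre-training datasets, which may carry over into this model's results.
Authors
Wav2Vec2 LJSpeech Gruut was trained and evaluated by Wilson Wongso. All computation and development were done on Google Cloud.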
Framework versions
- Transformers 4.26.0.dev0
- PyTorch 1.10.0
- Datasets 2.7.1
- Tokenizers 0.13.2
- Gruut 2.3.4
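To compare a local environment against the versions listed above, the standard library can report what is installed. This is a minimal sketch; the PyPI distribution names in LISTED (for example torch for PyTorch) and the helper installed_version are assumptions made for illustration, since the list above gives only display names.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions listed in this model card, keyed by assumed PyPI distribution name.
LISTED = {
    "transformers": "4.26.0.dev0",
    "torch": "1.10.0",
    "datasets": "2.7.1",
    "tokenizers": "0.13.2",
    "gruut": "2.3.4",
}


def installed_version(dist_name):
    """Return the installed version string, or None if the package is absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


for name, listed in LISTED.items():
    found = installed_version(name)
    print(f"{name}: listed {listed}, installed {found or 'not installed'}")
```

Newer releases of these libraries may well work; the listed versions are simply the ones recorded at training time.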
📄 License
This project is released under the Apache-2.0 license.