🚀 Wav2Vec2 LJSpeech Gruut
Wav2Vec2 LJSpeech Gruut是一個基於wav2vec 2.0架構的自動語音識別模型。該模型是Wav2Vec2-Base在LJSpech Phonemes數據集上的微調版本。它不預測單詞序列,而是預測音素序列,例如["h", "ɛ", "l", "ˈoʊ", "w", "ˈɚ", "l", "d"]
,其詞彙表包含了gruut中的不同國際音標(IPA)音素。
🚀 快速開始
本模型基於HuggingFace的PyTorch框架進行訓練,所有訓練工作在配備Tesla A100 GPU的Google Cloud Engine虛擬機上完成。訓練所需的所有腳本可在文件和版本標籤中找到,訓練指標通過Tensorboard記錄,可在訓練指標查看。
✨ 主要特性
- 基於wav2vec 2.0架構,在語音識別任務上有良好表現。
- 經過微調,可預測音素序列,適用於音素識別場景。
📦 安裝指南
文檔未提及安裝步驟,可參考HuggingFace相關庫的安裝方法來安裝依賴,如transformers
、librosa
、torch
、datasets
等。
💻 使用示例
基礎用法
from transformers import AutoProcessor, AutoModelForCTC, Wav2Vec2Processor
import librosa
import torch
from itertools import groupby
from datasets import load_dataset
def decode_phonemes(
ids: torch.Tensor, processor: Wav2Vec2Processor, ignore_stress: bool = False
) -> str:
"""CTC-like decoding. First removes consecutive duplicates, then removes special tokens."""
ids = [id_ for id_, _ in groupby(ids)]
special_token_ids = processor.tokenizer.all_special_ids + [
processor.tokenizer.word_delimiter_token_id
]
phonemes = [processor.decode(id_) for id_ in ids if id_ not in special_token_ids]
prediction = " ".join(phonemes)
if ignore_stress == True:
prediction = prediction.replace("ˈ", "").replace("ˌ", "")
return prediction
checkpoint = "bookbot/wav2vec2-ljspeech-gruut"
model = AutoModelForCTC.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)
sr = processor.feature_extractor.sampling_rate
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
audio_array = ds[0]["audio"]["array"]
inputs = processor(audio_array, return_tensors="pt", padding=True)
with torch.no_grad():
logits = model(inputs["input_values"]).logits
predicted_ids = torch.argmax(logits, dim=-1)
prediction = decode_phonemes(predicted_ids[0], processor, ignore_stress=True)
📚 詳細文檔
模型信息
屬性 |
詳情 |
模型類型 |
wav2vec2-ljspeech-gruut |
訓練數據 |
LJSpech Phonemes 數據集 |
模型參數數量 |
94M |
架構 |
wav2vec 2.0 |
評估結果
該模型在評估中取得了以下結果:
數據集 |
音素錯誤率(PER,無重音) |
字符錯誤率(CER,無重音) |
LJSpech Phonemes 測試數據 |
0.99% |
0.58% |
訓練過程
訓練超參數
訓練過程中使用了以下超參數:
learning_rate
:0.0001
train_batch_size
:16
eval_batch_size
:8
seed
:42
gradient_accumulation_steps
:2
total_train_batch_size
:32
optimizer
:Adam,betas=(0.9,0.999)
,epsilon=1e-08
lr_scheduler_type
:線性
lr_scheduler_warmup_steps
:1000
num_epochs
:30.0
mixed_precision_training
:Native AMP
訓練結果
訓練損失 |
輪數 |
步數 |
驗證損失 |
詞錯誤率(WER) |
字符錯誤率(CER) |
無記錄 |
1.0 |
348 |
2.2818 |
1.0 |
1.0 |
2.6692 |
2.0 |
696 |
0.2045 |
0.0527 |
0.0299 |
0.2225 |
3.0 |
1044 |
0.1162 |
0.0319 |
0.0189 |
0.2225 |
4.0 |
1392 |
0.0927 |
0.0235 |
0.0147 |
0.0868 |
5.0 |
1740 |
0.0797 |
0.0218 |
0.0143 |
0.0598 |
6.0 |
2088 |
0.0715 |
0.0197 |
0.0128 |
0.0598 |
7.0 |
2436 |
0.0652 |
0.0160 |
0.0103 |
0.0447 |
8.0 |
2784 |
0.0571 |
0.0152 |
0.0095 |
0.0368 |
9.0 |
3132 |
0.0608 |
0.0163 |
0.0112 |
0.0368 |
10.0 |
3480 |
0.0586 |
0.0137 |
0.0083 |
0.0303 |
11.0 |
3828 |
0.0641 |
0.0141 |
0.0085 |
0.0273 |
12.0 |
4176 |
0.0656 |
0.0131 |
0.0079 |
0.0232 |
13.0 |
4524 |
0.0690 |
0.0133 |
0.0082 |
0.0232 |
14.0 |
4872 |
0.0598 |
0.0128 |
0.0079 |
0.0189 |
15.0 |
5220 |
0.0671 |
0.0121 |
0.0074 |
0.017 |
16.0 |
5568 |
0.0654 |
0.0114 |
0.0069 |
0.017 |
17.0 |
5916 |
0.0751 |
0.0118 |
0.0073 |
0.0146 |
18.0 |
6264 |
0.0653 |
0.0112 |
0.0068 |
0.0127 |
19.0 |
6612 |
0.0682 |
0.0112 |
0.0069 |
0.0127 |
20.0 |
6960 |
0.0678 |
0.0114 |
0.0068 |
0.0114 |
21.0 |
7308 |
0.0656 |
0.0111 |
0.0066 |
0.0101 |
22.0 |
7656 |
0.0669 |
0.0109 |
0.0066 |
0.0092 |
23.0 |
8004 |
0.0677 |
0.0108 |
0.0065 |
0.0092 |
24.0 |
8352 |
0.0653 |
0.0104 |
0.0063 |
0.0088 |
25.0 |
8700 |
0.0673 |
0.0102 |
0.0063 |
0.0074 |
26.0 |
9048 |
0.0669 |
0.0105 |
0.0064 |
0.0074 |
27.0 |
9396 |
0.0707 |
0.0101 |
0.0061 |
0.0066 |
28.0 |
9744 |
0.0673 |
0.0100 |
0.0060 |
0.0058 |
29.0 |
10092 |
0.0689 |
0.0100 |
0.0059 |
0.0058 |
30.0 |
10440 |
0.0683 |
0.0099 |
0.0058 |
免責聲明
請考慮預訓練數據集可能帶來的偏差,這些偏差可能會影響本模型的結果。
作者信息
Wav2Vec2 LJSpeech Gruut由Wilson Wongso進行訓練和評估,所有計算和開發工作在Google Cloud上完成。
框架版本
- Transformers 4.26.0.dev0
- Pytorch 1.10.0
- Datasets 2.7.1
- Tokenizers 0.13.2
- Gruut 2.3.4
📄 許可證
本項目採用Apache-2.0許可證。