🚀 Wav2Vec2 LJSpeech Gruut
Wav2Vec2 LJSpeech Gruutは、wav2vec 2.0アーキテクチャに基づく自動音声認識モデルです。このモデルは、Wav2Vec2-BaseをLJSpech Phonemesデータセットでファインチューニングしたものです。
このモデルは、単語のシーケンスを予測するように訓練されるのではなく、音素のシーケンス(例:["h", "ɛ", "l", "ˈoʊ", "w", "ˈɚ", "l", "d"]
)を予測するように訓練されています。したがって、モデルの語彙には、gruutに含まれるさまざまなIPA音素が含まれています。
このモデルは、HuggingFaceのPyTorchフレームワークを使用して訓練されました。すべての訓練は、Tesla A100 GPUを搭載したGoogle Cloud Engine VM上で行われました。訓練に使用されたすべての必要なスクリプトは、Files and versionsタブに、またTensorboardを介して記録されたTraining metricsも確認できます。
✨ 主な機能
📦 インストール
このREADMEには具体的なインストール手順が記載されていないため、このセクションをスキップします。
💻 使用例
基本的な使用法
from transformers import AutoProcessor, AutoModelForCTC, Wav2Vec2Processor
import librosa
import torch
from itertools import groupby
from datasets import load_dataset
def decode_phonemes(
ids: torch.Tensor, processor: Wav2Vec2Processor, ignore_stress: bool = False
) -> str:
"""CTC-like decoding. First removes consecutive duplicates, then removes special tokens."""
ids = [id_ for id_, _ in groupby(ids)]
special_token_ids = processor.tokenizer.all_special_ids + [
processor.tokenizer.word_delimiter_token_id
]
phonemes = [processor.decode(id_) for id_ in ids if id_ not in special_token_ids]
prediction = " ".join(phonemes)
if ignore_stress == True:
prediction = prediction.replace("ˈ", "").replace("ˌ", "")
return prediction
checkpoint = "bookbot/wav2vec2-ljspeech-gruut"
model = AutoModelForCTC.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)
sr = processor.feature_extractor.sampling_rate
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
audio_array = ds[0]["audio"]["array"]
inputs = processor(audio_array, return_tensors="pt", padding=True)
with torch.no_grad():
logits = model(inputs["input_values"]).logits
predicted_ids = torch.argmax(logits, dim=-1)
prediction = decode_phonemes(predicted_ids[0], processor, ignore_stress=True)
📚 ドキュメント
モデル情報
プロパティ |
詳細 |
モデルタイプ |
wav2vec2-ljspeech-gruut |
訓練データ |
LJSpech Phonemes データセット |
評価結果
このモデルは、以下の評価結果を達成しています。
データセット |
音素誤り率 (PER, アクセントなし) |
文字誤り率 (CER, アクセントなし) |
LJSpech Phonemes テストデータ |
0.99% |
0.58% |
訓練手順
訓練ハイパーパラメータ
訓練中に使用されたハイパーパラメータは以下の通りです。
learning_rate
: 0.0001
train_batch_size
: 16
eval_batch_size
: 8
seed
: 42
gradient_accumulation_steps
: 2
total_train_batch_size
: 32
optimizer
: Adam (betas=(0.9,0.999)
および epsilon=1e-08
)
lr_scheduler_type
: linear
lr_scheduler_warmup_steps
: 1000
num_epochs
: 30.0
mixed_precision_training
: Native AMP
訓練結果
訓練損失 |
エポック |
ステップ |
検証損失 |
単語誤り率 (Wer) |
文字誤り率 (Cer) |
No log |
1.0 |
348 |
2.2818 |
1.0 |
1.0 |
2.6692 |
2.0 |
696 |
0.2045 |
0.0527 |
0.0299 |
0.2225 |
3.0 |
1044 |
0.1162 |
0.0319 |
0.0189 |
0.2225 |
4.0 |
1392 |
0.0927 |
0.0235 |
0.0147 |
0.0868 |
5.0 |
1740 |
0.0797 |
0.0218 |
0.0143 |
0.0598 |
6.0 |
2088 |
0.0715 |
0.0197 |
0.0128 |
0.0598 |
7.0 |
2436 |
0.0652 |
0.0160 |
0.0103 |
0.0447 |
8.0 |
2784 |
0.0571 |
0.0152 |
0.0095 |
0.0368 |
9.0 |
3132 |
0.0608 |
0.0163 |
0.0112 |
0.0368 |
10.0 |
3480 |
0.0586 |
0.0137 |
0.0083 |
0.0303 |
11.0 |
3828 |
0.0641 |
0.0141 |
0.0085 |
0.0273 |
12.0 |
4176 |
0.0656 |
0.0131 |
0.0079 |
0.0232 |
13.0 |
4524 |
0.0690 |
0.0133 |
0.0082 |
0.0232 |
14.0 |
4872 |
0.0598 |
0.0128 |
0.0079 |
0.0189 |
15.0 |
5220 |
0.0671 |
0.0121 |
0.0074 |
0.017 |
16.0 |
5568 |
0.0654 |
0.0114 |
0.0069 |
0.017 |
17.0 |
5916 |
0.0751 |
0.0118 |
0.0073 |
0.0146 |
18.0 |
6264 |
0.0653 |
0.0112 |
0.0068 |
0.0127 |
19.0 |
6612 |
0.0682 |
0.0112 |
0.0069 |
0.0127 |
20.0 |
6960 |
0.0678 |
0.0114 |
0.0068 |
0.0114 |
21.0 |
7308 |
0.0656 |
0.0111 |
0.0066 |
0.0101 |
22.0 |
7656 |
0.0669 |
0.0109 |
0.0066 |
0.0092 |
23.0 |
8004 |
0.0677 |
0.0108 |
0.0065 |
0.0092 |
24.0 |
8352 |
0.0653 |
0.0104 |
0.0063 |
0.0088 |
25.0 |
8700 |
0.0673 |
0.0102 |
0.0063 |
0.0074 |
26.0 |
9048 |
0.0669 |
0.0105 |
0.0064 |
0.0074 |
27.0 |
9396 |
0.0707 |
0.0101 |
0.0061 |
0.0066 |
28.0 |
9744 |
0.0673 |
0.0100 |
0.0060 |
0.0058 |
29.0 |
10092 |
0.0689 |
0.0100 |
0.0059 |
0.0058 |
30.0 |
10440 |
0.0683 |
0.0099 |
0.0058 |
🔧 技術詳細
このREADMEには具体的な技術詳細が記載されていないため、このセクションをスキップします。
📄 ライセンス
このモデルは、Apache 2.0ライセンスの下で提供されています。
免責事項
このモデルの結果には、事前訓練データセットに由来するバイアスが含まれている可能性があることに留意してください。
作成者
Wav2Vec2 LJSpeech Gruutは、Wilson Wongsoによって訓練および評価されました。すべての計算と開発はGoogle Cloud上で行われました。
フレームワークのバージョン
- Transformers 4.26.0.dev0
- Pytorch 1.10.0
- Datasets 2.7.1
- Tokenizers 0.13.2
- Gruut 2.3.4