🚀 wav2vec2-large-xlsr-53-th
このプロジェクトは、タイ語の音声認識に特化しており、wav2vec2-large-xlsr-53
をタイ語の Common Voice 7.0 でファインチューニングしたものです。詳細な情報はブログで確認できます。
🚀 クイックスタート
このモデルは、wav2vec2-large-xlsr-53 をベースに、Fine-tuning Wav2Vec2 for English ASR の方法を参考に、Common Voice Corpus 7.0 のタイ語のサンプルを使ってファインチューニングされています。ノートブックとスクリプトは vistec-ai/wav2vec2-large-xlsr-53-th で、事前学習済みモデルとプロセッサは airesearch/wav2vec2-large-xlsr-53-th で見つけることができます。
✨ 主な機能
robust-speech-event
の eval.py
に syllable_tokenize
、word_tokenize
(PyThaiNLP) と deepcut のトークナイザーを追加しています。
- タイ語の音声認識に特化したファインチューニングを行っています。
📦 インストール
インストールに関する具体的なコマンドは提供されていないため、このセクションは省略されます。
💻 使用例
基本的な使用法
processor = Wav2Vec2Processor.from_pretrained("airesearch/wav2vec2-large-xlsr-53-th")
model = Wav2Vec2ForCTC.from_pretrained("airesearch/wav2vec2-large-xlsr-53-th")
def speech_file_to_array_fn(batch,
text_col="sentence",
fname_col="path",
resampling_to=16000):
speech_array, sampling_rate = torchaudio.load(batch[fname_col])
resampler=torchaudio.transforms.Resample(sampling_rate, resampling_to)
batch["speech"] = resampler(speech_array)[0].numpy()
batch["sampling_rate"] = resampling_to
batch["target_text"] = batch[text_col]
return batch
test_dataset = test_dataset.map(speech_file_to_array_fn)
inputs = processor(test_dataset["speech"][:2], sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
logits = model(inputs.input_values,).logits
predicted_ids = torch.argmax(logits, dim=-1)
print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset["sentence"][:2])
>> Prediction: ['และ เขา ก็ สัมผัส ดีบุก', 'คุณ สามารถ รับทราบ เมื่อ ข้อความ นี้ ถูก อ่าน แล้ว']
>> Reference: ['และเขาก็สัมผัสดีบุก', 'คุณสามารถรับทราบเมื่อข้อความนี้ถูกอ่านแล้ว']
📚 ドキュメント
データセット
Common Voice Corpus 7.0 は、5GBで133時間の検証済みタイ語データ(合計255時間)を含んでいます。pythainlp.tokenize.word_tokenize
で事前トークナイズを行い、@tann9949 が notebooks/cv-preprocess.ipynb
で記述したクリーニングルールを使って前処理を行っています。その後、ekapolc/Thai_commonvoice_split の方法で重複排除と分割を行い、データリークを防ぎ、訓練セットに多くのデータを残しています。データセットの読み込みスクリプトは scripts/th_common_voice_70.py
です。
トレーニング
以下の設定を使用して、単一のV100 GPUでファインチューニングを行い、検証損失が最も低いチェックポイントを選択しています。ファインチューニングスクリプトは scripts/wav2vec2_finetune.py
です。
model = Wav2Vec2ForCTC.from_pretrained(
"facebook/wav2vec2-large-xlsr-53",
attention_dropout=0.1,
hidden_dropout=0.1,
feat_proj_dropout=0.0,
mask_time_prob=0.05,
layerdrop=0.1,
gradient_checkpointing=True,
ctc_loss_reduction="mean",
pad_token_id=processor.tokenizer.pad_token_id,
vocab_size=len(processor.tokenizer)
)
model.freeze_feature_extractor()
training_args = TrainingArguments(
output_dir="../data/wav2vec2-large-xlsr-53-thai",
group_by_length=True,
per_device_train_batch_size=32,
gradient_accumulation_steps=1,
per_device_eval_batch_size=16,
metric_for_best_model='wer',
evaluation_strategy="steps",
eval_steps=1000,
logging_strategy="steps",
logging_steps=1000,
save_strategy="steps",
save_steps=1000,
num_train_epochs=100,
fp16=True,
learning_rate=1e-4,
warmup_steps=1000,
save_total_limit=3,
report_to="tensorboard"
)
評価
テストセットで、PyThaiNLP 2.3.1 と deepcut でトークナイズされた単語を使ったWERとCERでベンチマークを行っています。また、TNC のn-gramを使ったスペル訂正を適用した場合のパフォーマンスも測定しています。評価コードは notebooks/wav2vec2_finetuning_tutorial.ipynb
で見つけることができます。ベンチマークは test-unique
分割で行われています。
モデル |
WER PyThaiNLP 2.3.1 |
WER deepcut |
CER |
Kaldi from scratch |
23.04 |
|
7.57 |
スペル訂正なし |
13.634024 |
8.152052 |
2.813019 |
スペル訂正あり |
17.996397 |
14.167975 |
5.225761 |
Google Web Speech API※ |
13.711234 |
10.860058 |
7.357340 |
Microsoft Bing Speech API※ |
12.578819 |
9.620991 |
5.016620 |
Amazon Transcribe※ |
21.86334 |
14.487553 |
7.077562 |
NECTEC AI for Thai Partii API※ |
20.105887 |
15.515631 |
9.551027 |
※ APIs are not finetuned with Common Voice 7.0 data
📄 ライセンス
このプロジェクトは cc-by-sa 4.0 ライセンスの下で公開されています。
謝辞