模型概述
模型特點
模型能力
使用案例
🚀 wav2vec2-large-xlsr-53-th
本項目是基於泰語 Common Voice 7.0 對 wav2vec2-large-xlsr-53
進行微調的成果。該項目在自動語音識別領域具有重要價值,可有效提升泰語語音識別的準確性。
🚀 快速開始
本項目基於 Fine-tuning Wav2Vec2 for English ASR,使用 Common Voice Corpus 7.0 中的泰語示例對 wav2vec2-large-xlsr-53 進行微調。相關的筆記本和腳本可在 vistec-ai/wav2vec2-large-xlsr-53-th 中找到,預訓練模型和處理器可在 airesearch/wav2vec2-large-xlsr-53-th 中獲取。
✨ 主要特性
- 在
robust-speech-event
的eval.py
中添加了syllable_tokenize
、word_tokenize
(PyThaiNLP)和 deepcut 分詞器。 - 對泰語語音識別進行了微調,提升了識別性能。
- 提供了詳細的訓練和評估腳本及配置。
💻 使用示例
基礎用法
#load pretrained processor and model
processor = Wav2Vec2Processor.from_pretrained("airesearch/wav2vec2-large-xlsr-53-th")
model = Wav2Vec2ForCTC.from_pretrained("airesearch/wav2vec2-large-xlsr-53-th")
#function to resample to 16_000
def speech_file_to_array_fn(batch,
text_col="sentence",
fname_col="path",
resampling_to=16000):
speech_array, sampling_rate = torchaudio.load(batch[fname_col])
resampler=torchaudio.transforms.Resample(sampling_rate, resampling_to)
batch["speech"] = resampler(speech_array)[0].numpy()
batch["sampling_rate"] = resampling_to
batch["target_text"] = batch[text_col]
return batch
#get 2 examples as sample input
test_dataset = test_dataset.map(speech_file_to_array_fn)
inputs = processor(test_dataset["speech"][:2], sampling_rate=16_000, return_tensors="pt", padding=True)
#infer
with torch.no_grad():
logits = model(inputs.input_values,).logits
predicted_ids = torch.argmax(logits, dim=-1)
print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset["sentence"][:2])
>> Prediction: ['และ เขา ก็ สัมผัส ดีบุก', 'คุณ สามารถ รับทราบ เมื่อ ข้อความ นี้ ถูก อ่าน แล้ว']
>> Reference: ['และเขาก็สัมผัสดีบุก', 'คุณสามารถรับทราบเมื่อข้อความนี้ถูกอ่านแล้ว']
📦 安裝指南
文檔中未提及具體安裝步驟,暫無法提供。
📚 詳細文檔
數據集
Common Voice Corpus 7.0 包含 133 小時經過驗證的泰語語音數據(總計 255 小時),大小為 5GB。使用 pythainlp.tokenize.word_tokenize
進行預分詞,並按照 @tann9949 在 notebooks/cv-preprocess.ipynb
中描述的清理規則對數據集進行預處理。然後,按照 ekapolc/Thai_commonvoice_split 中的方法進行去重和劃分,以避免 Common Voice Corpus 7.0 中清理後隨機劃分導致的數據洩露,並將大部分數據保留給訓練集。數據集加載腳本為 scripts/th_common_voice_70.py
,你可以結合 train_cleand.tsv
、validation_cleaned.tsv
和 test_cleaned.tsv
使用該腳本,以獲得與本項目相同的劃分。劃分後的數據集如下:
DatasetDict({
train: Dataset({
features: ['path', 'sentence'],
num_rows: 86586
})
test: Dataset({
features: ['path', 'sentence'],
num_rows: 2502
})
validation: Dataset({
features: ['path', 'sentence'],
num_rows: 3027
})
})
訓練
在單個 V100 GPU 上使用以下配置進行微調,並選擇驗證損失最低的檢查點。微調腳本為 scripts/wav2vec2_finetune.py
:
# create model
model = Wav2Vec2ForCTC.from_pretrained(
"facebook/wav2vec2-large-xlsr-53",
attention_dropout=0.1,
hidden_dropout=0.1,
feat_proj_dropout=0.0,
mask_time_prob=0.05,
layerdrop=0.1,
gradient_checkpointing=True,
ctc_loss_reduction="mean",
pad_token_id=processor.tokenizer.pad_token_id,
vocab_size=len(processor.tokenizer)
)
model.freeze_feature_extractor()
training_args = TrainingArguments(
output_dir="../data/wav2vec2-large-xlsr-53-thai",
group_by_length=True,
per_device_train_batch_size=32,
gradient_accumulation_steps=1,
per_device_eval_batch_size=16,
metric_for_best_model='wer',
evaluation_strategy="steps",
eval_steps=1000,
logging_strategy="steps",
logging_steps=1000,
save_strategy="steps",
save_steps=1000,
num_train_epochs=100,
fp16=True,
learning_rate=1e-4,
warmup_steps=1000,
save_total_limit=3,
report_to="tensorboard"
)
評估
在測試集上使用 PyThaiNLP 2.3.1 和 deepcut 進行分詞後的單詞錯誤率(WER)以及字符錯誤率(CER)進行基準測試。還測量了應用 TNC 三元組拼寫校正時的性能。評估代碼可在 notebooks/wav2vec2_finetuning_tutorial.ipynb
中找到,基準測試在 test-unique
劃分上進行。
模型 | WER PyThaiNLP 2.3.1 | WER deepcut | CER |
---|---|---|---|
Kaldi from scratch | 23.04 | 7.57 | |
本項目無拼寫校正 | 13.634024 | 8.152052 | 2.813019 |
本項目有拼寫校正 | 17.996397 | 14.167975 | 5.225761 |
Google Web Speech API※ | 13.711234 | 10.860058 | 7.357340 |
Microsoft Bing Speech API※ | 12.578819 | 9.620991 | 5.016620 |
Amazon Transcribe※ | 21.86334 | 14.487553 | 7.077562 |
NECTEC AI for Thai Partii API※ | 20.105887 | 15.515631 | 9.551027 |
※ APIs 未使用 Common Voice 7.0 數據進行微調
robust-speech-event
評估
> python eval.py --model_id ./ --dataset mozilla-foundation/common_voice_7_0 --config th --split test --log_outputs --thai_tokenizer newmm/syllable/deepcut/cer
Common Voice 7 "test" 評估結果
評估方式 | WER PyThaiNLP 2.3.1 | WER deepcut | SER | CER |
---|---|---|---|---|
僅分詞 | 0.9524% | 2.5316% | 1.2346% | 0.1623% |
清理規則和分詞 | TBD | TBD | TBD | TBD |
📄 許可證
👏 致謝
- 模型訓練和驗證筆記本/腳本:@cstorm125
- 數據集清理腳本:@tann9949
- 數據集劃分:@ekapolc 和 @14mss
- 運行訓練:@mrpeerat
- 拼寫校正:@wannaphong



