模型简介
模型特点
模型能力
使用案例
🚀 wav2vec2-large-xlsr-53-th
本项目是基于泰语 Common Voice 7.0 对 wav2vec2-large-xlsr-53
进行微调的成果。该项目在自动语音识别领域具有重要价值,可有效提升泰语语音识别的准确性。
🚀 快速开始
本项目基于 Fine-tuning Wav2Vec2 for English ASR,使用 Common Voice Corpus 7.0 中的泰语示例对 wav2vec2-large-xlsr-53 进行微调。相关的笔记本和脚本可在 vistec-ai/wav2vec2-large-xlsr-53-th 中找到,预训练模型和处理器可在 airesearch/wav2vec2-large-xlsr-53-th 中获取。
✨ 主要特性
- 在
robust-speech-event
的eval.py
中添加了syllable_tokenize
、word_tokenize
(PyThaiNLP)和 deepcut 分词器。 - 对泰语语音识别进行了微调,提升了识别性能。
- 提供了详细的训练和评估脚本及配置。
💻 使用示例
基础用法
#load pretrained processor and model
processor = Wav2Vec2Processor.from_pretrained("airesearch/wav2vec2-large-xlsr-53-th")
model = Wav2Vec2ForCTC.from_pretrained("airesearch/wav2vec2-large-xlsr-53-th")
#function to resample to 16_000
def speech_file_to_array_fn(batch,
text_col="sentence",
fname_col="path",
resampling_to=16000):
speech_array, sampling_rate = torchaudio.load(batch[fname_col])
resampler=torchaudio.transforms.Resample(sampling_rate, resampling_to)
batch["speech"] = resampler(speech_array)[0].numpy()
batch["sampling_rate"] = resampling_to
batch["target_text"] = batch[text_col]
return batch
#get 2 examples as sample input
test_dataset = test_dataset.map(speech_file_to_array_fn)
inputs = processor(test_dataset["speech"][:2], sampling_rate=16_000, return_tensors="pt", padding=True)
#infer
with torch.no_grad():
logits = model(inputs.input_values,).logits
predicted_ids = torch.argmax(logits, dim=-1)
print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset["sentence"][:2])
>> Prediction: ['และ เขา ก็ สัมผัส ดีบุก', 'คุณ สามารถ รับทราบ เมื่อ ข้อความ นี้ ถูก อ่าน แล้ว']
>> Reference: ['และเขาก็สัมผัสดีบุก', 'คุณสามารถรับทราบเมื่อข้อความนี้ถูกอ่านแล้ว']
📦 安装指南
文档中未提及具体安装步骤,暂无法提供。
📚 详细文档
数据集
Common Voice Corpus 7.0 包含 133 小时经过验证的泰语语音数据(总计 255 小时),大小为 5GB。使用 pythainlp.tokenize.word_tokenize
进行预分词,并按照 @tann9949 在 notebooks/cv-preprocess.ipynb
中描述的清理规则对数据集进行预处理。然后,按照 ekapolc/Thai_commonvoice_split 中的方法进行去重和划分,以避免 Common Voice Corpus 7.0 中清理后随机划分导致的数据泄露,并将大部分数据保留给训练集。数据集加载脚本为 scripts/th_common_voice_70.py
,你可以结合 train_cleand.tsv
、validation_cleaned.tsv
和 test_cleaned.tsv
使用该脚本,以获得与本项目相同的划分。划分后的数据集如下:
DatasetDict({
train: Dataset({
features: ['path', 'sentence'],
num_rows: 86586
})
test: Dataset({
features: ['path', 'sentence'],
num_rows: 2502
})
validation: Dataset({
features: ['path', 'sentence'],
num_rows: 3027
})
})
训练
在单个 V100 GPU 上使用以下配置进行微调,并选择验证损失最低的检查点。微调脚本为 scripts/wav2vec2_finetune.py
:
# create model
model = Wav2Vec2ForCTC.from_pretrained(
"facebook/wav2vec2-large-xlsr-53",
attention_dropout=0.1,
hidden_dropout=0.1,
feat_proj_dropout=0.0,
mask_time_prob=0.05,
layerdrop=0.1,
gradient_checkpointing=True,
ctc_loss_reduction="mean",
pad_token_id=processor.tokenizer.pad_token_id,
vocab_size=len(processor.tokenizer)
)
model.freeze_feature_extractor()
training_args = TrainingArguments(
output_dir="../data/wav2vec2-large-xlsr-53-thai",
group_by_length=True,
per_device_train_batch_size=32,
gradient_accumulation_steps=1,
per_device_eval_batch_size=16,
metric_for_best_model='wer',
evaluation_strategy="steps",
eval_steps=1000,
logging_strategy="steps",
logging_steps=1000,
save_strategy="steps",
save_steps=1000,
num_train_epochs=100,
fp16=True,
learning_rate=1e-4,
warmup_steps=1000,
save_total_limit=3,
report_to="tensorboard"
)
评估
在测试集上使用 PyThaiNLP 2.3.1 和 deepcut 进行分词后的单词错误率(WER)以及字符错误率(CER)进行基准测试。还测量了应用 TNC 三元组拼写校正时的性能。评估代码可在 notebooks/wav2vec2_finetuning_tutorial.ipynb
中找到,基准测试在 test-unique
划分上进行。
模型 | WER PyThaiNLP 2.3.1 | WER deepcut | CER |
---|---|---|---|
Kaldi from scratch | 23.04 | 7.57 | |
本项目无拼写校正 | 13.634024 | 8.152052 | 2.813019 |
本项目有拼写校正 | 17.996397 | 14.167975 | 5.225761 |
Google Web Speech API※ | 13.711234 | 10.860058 | 7.357340 |
Microsoft Bing Speech API※ | 12.578819 | 9.620991 | 5.016620 |
Amazon Transcribe※ | 21.86334 | 14.487553 | 7.077562 |
NECTEC AI for Thai Partii API※ | 20.105887 | 15.515631 | 9.551027 |
※ APIs 未使用 Common Voice 7.0 数据进行微调
robust-speech-event
评估
> python eval.py --model_id ./ --dataset mozilla-foundation/common_voice_7_0 --config th --split test --log_outputs --thai_tokenizer newmm/syllable/deepcut/cer
Common Voice 7 "test" 评估结果
评估方式 | WER PyThaiNLP 2.3.1 | WER deepcut | SER | CER |
---|---|---|---|---|
仅分词 | 0.9524% | 2.5316% | 1.2346% | 0.1623% |
清理规则和分词 | TBD | TBD | TBD | TBD |
📄 许可证
👏 致谢
- 模型训练和验证笔记本/脚本:@cstorm125
- 数据集清理脚本:@tann9949
- 数据集划分:@ekapolc 和 @14mss
- 运行训练:@mrpeerat
- 拼写校正:@wannaphong



