🚀 xls-r-300m-km 語音識別模型
本模型是基於Transformer架構的語音識別模型,在openslr數據集上微調了facebook/wav2vec2-xls-r-300m模型,可用於高棉語的自動語音識別任務。
🚀 快速開始
本模型是 facebook/wav2vec2-xls-r-300m 在 openslr 數據集上的微調版本。它在評估集上取得了以下結果:
- 損失:0.3281
- 字錯率(Wer):0.3462
OpenSLR “測試” 集(自劃分 10%)評估結果(運行 ./eval.py)
- 字錯率(WER):0.3216977389924633
- 字符錯誤率(CER):0.08653361193169537
使用語言模型在 OpenSLR “測試” 集(自劃分 10%)的評估結果(運行 ./eval.py)
- 字錯率(WER):0.257040856802856
- 字符錯誤率(CER):0.07025001801282513
✨ 主要特性
- 小數據集表現良好:儘管僅使用約4小時的錄音數據進行訓練,但模型性能不錯。
- 支持語言模型:可結合語言模型進一步提升識別效果。
📦 安裝指南
為支持語言模型,需在 HuggingFace Transformers 基礎上安裝以下庫:
pip install pyctcdecode
pip install https://github.com/kpu/kenlm/archive/master.zip
💻 使用示例
基礎用法
使用 HuggingFace 的 pipeline,可實現從原始音頻輸入到文本輸出的端到端處理:
from transformers import pipeline
pipe = pipeline(model="vitouphy/wav2vec2-xls-r-300m-khmer")
output = pipe("sound_file.wav", chunk_length_s=10, stride_length_s=(4, 2))
高級用法
更自定義的方式來預測音素:
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import librosa
import torch
processor = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-khmer")
model = Wav2Vec2ForCTC.from_pretrained("vitouphy/wav2vec2-xls-r-300m-khmer")
speech_array, sampling_rate = librosa.load("sound_file.wav", sr=16_000)
inputs = processor(speech_array, sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
predicted_ids = torch.argmax(logits, axis=-1)
predicted_sentences = processor.batch_decode(predicted_ids)
print(predicted_sentences)
📚 詳細文檔
預期用途和限制
本模型使用的數據僅約4小時的錄音:
- 數據按 80/10/10 劃分,訓練時長僅 3.2 小時,數據量非常小。
- 儘管如此,模型性能不錯,小數據集能有此表現很有趣,你可嘗試使用。
- 模型的侷限性:
- 對稀有字符(如 ឬស្សី ឪឡឹក)識別效果不佳。
- 要求語音清晰、發音準確。
- 增加數據以覆蓋更多詞彙和字符,可能有助於提升系統性能。
訓練過程
訓練超參數
訓練過程中使用了以下超參數:
- 學習率(learning_rate):5e-05
- 訓練批次大小(train_batch_size):8
- 評估批次大小(eval_batch_size):8
- 隨機種子(seed):42
- 梯度累積步數(gradient_accumulation_steps):4
- 總訓練批次大小(total_train_batch_size):32
- 優化器(optimizer):Adam,β=(0.9, 0.999),ε=1e-08
- 學習率調度器類型(lr_scheduler_type):線性
- 學習率調度器熱身步數(lr_scheduler_warmup_steps):1000
- 訓練輪數(num_epochs):100
- 混合精度訓練(mixed_precision_training):Native AMP
訓練結果
訓練損失 |
輪數 |
步數 |
驗證損失 |
字錯率(Wer) |
5.0795 |
5.47 |
400 |
4.4121 |
1.0 |
3.5658 |
10.95 |
800 |
3.5203 |
1.0 |
3.3689 |
16.43 |
1200 |
2.8984 |
0.9996 |
2.01 |
21.91 |
1600 |
1.0041 |
0.7288 |
1.6783 |
27.39 |
2000 |
0.6941 |
0.5989 |
1.527 |
32.87 |
2400 |
0.5599 |
0.5282 |
1.4278 |
38.35 |
2800 |
0.4827 |
0.4806 |
1.3458 |
43.83 |
3200 |
0.4429 |
0.4532 |
1.2893 |
49.31 |
3600 |
0.4156 |
0.4330 |
1.2441 |
54.79 |
4000 |
0.4020 |
0.4040 |
1.188 |
60.27 |
4400 |
0.3777 |
0.3866 |
1.1628 |
65.75 |
4800 |
0.3607 |
0.3858 |
1.1324 |
71.23 |
5200 |
0.3534 |
0.3604 |
1.0969 |
76.71 |
5600 |
0.3428 |
0.3624 |
1.0897 |
82.19 |
6000 |
0.3387 |
0.3567 |
1.0625 |
87.66 |
6400 |
0.3339 |
0.3499 |
1.0601 |
93.15 |
6800 |
0.3288 |
0.3446 |
1.0474 |
98.62 |
7200 |
0.3281 |
0.3462 |
框架版本
- Transformers 4.17.0.dev0
- Pytorch 1.10.2+cu102
- Datasets 1.18.2.dev0
- Tokenizers 0.11.0
📄 許可證
本項目採用 Apache-2.0 許可證。