🚀 XLS-R-300-m 自動語音識別模型
本模型是一個專注於日語自動語音識別的模型,它基於預訓練的 facebook/wav2vec2-xls-r-300m 模型,在 MOZILLA - FOUNDATION/COMMON_VOICE_8_0 - JA 數據集上進行微調。在訓練和評估過程中,使用 pykakasi 庫將漢字轉換為平假名,模型能夠輸出平假名和片假名字符。由於日語中沒有空格,使用字符錯誤率(CER)來評估性能比單詞錯誤率(WER)更合適。
🚀 快速開始
本模型是 facebook/wav2vec2-xls-r-300m 在 MOZILLA - FOUNDATION/COMMON_VOICE_8_0 - JA 數據集上的微調版本。在訓練和評估期間,使用 pykakasi 庫將漢字轉換為平假名。該模型可以輸出平假名和片假名字符。由於日語中沒有空格,WER 不是評估性能的合適指標,CER 更為合適。
在 mozilla - foundation/common_voice_8_0 數據集上,它實現了:
在 speech - recognition - community - v2/dev_data 數據集上,它實現了:
在評估集上,它取得了以下結果:
- 損失: 0.5212
- 單詞錯誤率(WER): 1.3068
✨ 主要特性
- 語言轉換:在訓練和評估時使用 pykakasi 庫將漢字轉換為平假名。
- 輸出字符類型:能夠輸出平假名和片假名字符。
- 評估指標:由於日語無空格,採用 CER 作為更合適的評估指標。
🔧 技術細節
訓練超參數
訓練過程中使用了以下超參數:
- 學習率(learning_rate): 7.5e - 05
- 訓練批次大小(train_batch_size): 48
- 評估批次大小(eval_batch_size): 8
- 隨機種子(seed): 42
- 優化器(optimizer): Adam,β1 = 0.9,β2 = 0.999,ε = 1e - 08
- 學習率調度器類型(lr_scheduler_type): 線性
- 學習率調度器熱身步數(lr_scheduler_warmup_steps): 2000
- 訓練輪數(num_epochs): 50.0
- 混合精度訓練(mixed_precision_training): 原生自動混合精度(Native AMP)
訓練結果
訓練損失 |
輪數 |
步數 |
驗證損失 |
單詞錯誤率(WER) |
4.0974 |
4.72 |
1000 |
4.0178 |
1.9535 |
2.1276 |
9.43 |
2000 |
0.9301 |
1.2128 |
1.7622 |
14.15 |
3000 |
0.7103 |
1.5527 |
1.6397 |
18.87 |
4000 |
0.6729 |
1.4269 |
1.5468 |
23.58 |
5000 |
0.6087 |
1.2497 |
1.4885 |
28.3 |
6000 |
0.5786 |
1.3222 |
1.451 |
33.02 |
7000 |
0.5726 |
1.3768 |
1.3912 |
37.74 |
8000 |
0.5518 |
1.2497 |
1.3617 |
42.45 |
9000 |
0.5352 |
1.2694 |
1.3113 |
47.17 |
10000 |
0.5228 |
1.2781 |
框架版本
- Transformers 4.17.0.dev0
- Pytorch 1.10.2 + cu102
- Datasets 1.18.2.dev0
- Tokenizers 0.11.0
評估命令
- 在
mozilla - foundation/common_voice_8_0
數據集的 test
分割上進行評估
python ./eval.py --model_id AndrewMcDowell/wav2vec2-xls-r-300m-japanese --dataset mozilla-foundation/common_voice_8_0 --config ja --split test --log_outputs
- 在
mozilla - foundation/common_voice_8_0
數據集的 test
分割上進行評估
python ./eval.py --model_id AndrewMcDowell/wav2vec2-xls-r-300m-japanese --dataset speech-recognition-community-v2/dev_data --config de --split validation --chunk_length_s 5.0 --stride_length_s 1.0
📄 許可證
本模型採用 Apache - 2.0 許可證。