🚀 XLS-R-300-m 自动语音识别模型
本模型是一个专注于日语自动语音识别的模型,它基于预训练的 facebook/wav2vec2-xls-r-300m 模型,在 MOZILLA - FOUNDATION/COMMON_VOICE_8_0 - JA 数据集上进行微调。在训练和评估过程中,使用 pykakasi 库将汉字转换为平假名,模型能够输出平假名和片假名字符。由于日语中没有空格,使用字符错误率(CER)来评估性能比单词错误率(WER)更合适。
🚀 快速开始
本模型是 facebook/wav2vec2-xls-r-300m 在 MOZILLA - FOUNDATION/COMMON_VOICE_8_0 - JA 数据集上的微调版本。在训练和评估期间,使用 pykakasi 库将汉字转换为平假名。该模型可以输出平假名和片假名字符。由于日语中没有空格,WER 不是评估性能的合适指标,CER 更为合适。
在 mozilla - foundation/common_voice_8_0 数据集上,它实现了:
在 speech - recognition - community - v2/dev_data 数据集上,它实现了:
在评估集上,它取得了以下结果:
- 损失: 0.5212
- 单词错误率(WER): 1.3068
✨ 主要特性
- 语言转换:在训练和评估时使用 pykakasi 库将汉字转换为平假名。
- 输出字符类型:能够输出平假名和片假名字符。
- 评估指标:由于日语无空格,采用 CER 作为更合适的评估指标。
🔧 技术细节
训练超参数
训练过程中使用了以下超参数:
- 学习率(learning_rate): 7.5e - 05
- 训练批次大小(train_batch_size): 48
- 评估批次大小(eval_batch_size): 8
- 随机种子(seed): 42
- 优化器(optimizer): Adam,β1 = 0.9,β2 = 0.999,ε = 1e - 08
- 学习率调度器类型(lr_scheduler_type): 线性
- 学习率调度器热身步数(lr_scheduler_warmup_steps): 2000
- 训练轮数(num_epochs): 50.0
- 混合精度训练(mixed_precision_training): 原生自动混合精度(Native AMP)
训练结果
训练损失 |
轮数 |
步数 |
验证损失 |
单词错误率(WER) |
4.0974 |
4.72 |
1000 |
4.0178 |
1.9535 |
2.1276 |
9.43 |
2000 |
0.9301 |
1.2128 |
1.7622 |
14.15 |
3000 |
0.7103 |
1.5527 |
1.6397 |
18.87 |
4000 |
0.6729 |
1.4269 |
1.5468 |
23.58 |
5000 |
0.6087 |
1.2497 |
1.4885 |
28.3 |
6000 |
0.5786 |
1.3222 |
1.451 |
33.02 |
7000 |
0.5726 |
1.3768 |
1.3912 |
37.74 |
8000 |
0.5518 |
1.2497 |
1.3617 |
42.45 |
9000 |
0.5352 |
1.2694 |
1.3113 |
47.17 |
10000 |
0.5228 |
1.2781 |
框架版本
- Transformers 4.17.0.dev0
- Pytorch 1.10.2 + cu102
- Datasets 1.18.2.dev0
- Tokenizers 0.11.0
评估命令
- 在
mozilla - foundation/common_voice_8_0
数据集的 test
分割上进行评估
python ./eval.py --model_id AndrewMcDowell/wav2vec2-xls-r-300m-japanese --dataset mozilla-foundation/common_voice_8_0 --config ja --split test --log_outputs
- 在
mozilla - foundation/common_voice_8_0
数据集的 test
分割上进行评估
python ./eval.py --model_id AndrewMcDowell/wav2vec2-xls-r-300m-japanese --dataset speech-recognition-community-v2/dev_data --config de --split validation --chunk_length_s 5.0 --stride_length_s 1.0
📄 许可证
本模型采用 Apache - 2.0 许可证。