🚀 wav2vec2-large-xls-r-300m-as
該模型是在common_voice
數據集上對facebook/wav2vec2-xls-r-300m進行微調後的版本。它在評估集上取得了以下結果:
- 損失值:1.9068
- 字錯率(Wer):0.6679
✨ 主要特性
- 基於預訓練模型
facebook/wav2vec2-xls-r-300m
在common_voice
數據集上微調,適用於自動語音識別任務。
- 在評估集上有明確的損失值和字錯率指標。
📦 安裝指南
文檔未提供安裝相關內容,故跳過此章節。
💻 使用示例
基礎用法
以下是使用該模型進行推理的示例代碼:
import torch
from datasets import load_dataset
from transformers import AutoModelForCTC, AutoProcessor
import torchaudio.functional as F
model_id = "anuragshas/wav2vec2-large-xls-r-300m-as"
sample_iter = iter(load_dataset("mozilla-foundation/common_voice_7_0", "as", split="test", streaming=True, use_auth_token=True))
sample = next(sample_iter)
resampled_audio = F.resample(torch.tensor(sample["audio"]["array"]), 48_000, 16_000).numpy()
model = AutoModelForCTC.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)
input_values = processor(resampled_audio, return_tensors="pt").input_values
with torch.no_grad():
logits = model(input_values).logits
transcription = processor.batch_decode(logits.numpy()).text
評估命令
在mozilla-foundation/common_voice_7_0
數據集的test
分割上進行評估:
python eval.py --model_id anuragshas/wav2vec2-large-xls-r-300m-as --dataset mozilla-foundation/common_voice_7_0 --config as --split test
評估結果
在Common Voice 7 "test" 數據集上的字錯率(WER)評估結果如下:
無語言模型 |
使用語言模型(運行 ./eval.py ) |
67 |
56.995 |
📚 詳細文檔
模型描述
更多信息待補充。
預期用途與限制
更多信息待補充。
訓練和評估數據
更多信息待補充。
🔧 技術細節
訓練超參數
訓練過程中使用了以下超參數:
- 學習率(learning_rate):0.0003
- 訓練批次大小(train_batch_size):16
- 評估批次大小(eval_batch_size):8
- 隨機種子(seed):42
- 梯度累積步數(gradient_accumulation_steps):2
- 總訓練批次大小(total_train_batch_size):32
- 優化器(optimizer):Adam,β=(0.9, 0.999),ε=1e - 08
- 學習率調度器類型(lr_scheduler_type):線性
- 學習率調度器熱身比例(lr_scheduler_warmup_ratio):0.12
- 訓練輪數(num_epochs):240
訓練結果
訓練損失 |
輪數 |
步數 |
驗證損失 |
字錯率(Wer) |
5.7027 |
21.05 |
400 |
3.4157 |
1.0 |
1.1638 |
42.1 |
800 |
1.3498 |
0.7461 |
0.2266 |
63.15 |
1200 |
1.6147 |
0.7273 |
0.1473 |
84.21 |
1600 |
1.6649 |
0.7108 |
0.1043 |
105.26 |
2000 |
1.7691 |
0.7090 |
0.0779 |
126.31 |
2400 |
1.8300 |
0.7009 |
0.0613 |
147.36 |
2800 |
1.8681 |
0.6916 |
0.0471 |
168.41 |
3200 |
1.8567 |
0.6875 |
0.0343 |
189.46 |
3600 |
1.9054 |
0.6840 |
0.0265 |
210.51 |
4000 |
1.9020 |
0.6786 |
0.0219 |
231.56 |
4400 |
1.9068 |
0.6679 |
框架版本
- Transformers:4.16.0
- Pytorch:1.10.0+cu111
- Datasets:1.17.0
- Tokenizers:0.10.3
📄 許可證
該模型使用Apache-2.0許可證。