🚀 whisper-small-sp
このモデルは、commonvoice dataset v11
データセット上で openai/whisper-small をファインチューニングしたバージョンです。
評価セットでは以下の結果を達成しています。
- 損失: 0.4485
- 単語誤り率 (Wer): 20.6842
📦 インストール
必要なライブラリをインストールするには、以下のコマンドを使用します。
pip install transformers datasets evaluate torch
💻 使用例
基本的な使用法
from datasets import load_dataset, Audio
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = WhisperProcessor.from_pretrained("clu-ling/whisper-small-spanish")
model = WhisperForConditionalGeneration.from_pretrained("clu-ling/whisper-small-spanish").to(device)
forced_decoder_ids = processor.get_decoder_prompt_ids(language="es", task="transcribe")
commonvoice_eval = load_dataset("mozilla-foundation/common_voice_11_0", "es", split="validation", streaming=True)
commonvoice_eval = commonvoice_eval.cast_column("audio", Audio(sampling_rate=16000))
sample = next(iter(commonvoice_eval))["audio"]
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
predicted_ids = model.generate(input_features.to(device), forced_decoder_ids=forced_decoder_ids)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
print(transcription)
高度な使用法
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from datasets import load_dataset, Audio
import evaluate
import torch
import re
from transformers import WhisperProcessor, WhisperForConditionalGeneration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
wer_metric = evaluate.load("wer")
processor = WhisperProcessor.from_pretrained("clu-ling/whisper-small-spanish")
model = WhisperForConditionalGeneration.from_pretrained("clu-ling/whisper-small-spanish")
dataset = load_dataset("mozilla-foundation/common_voice_11_0", "es", split="test", )
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
def normalize(batch):
batch["gold_text"] = whisper_norm(batch['sentence'])
return batch
def map_wer(batch):
model.to(device)
forced_decoder_ids = processor.get_decoder_prompt_ids(language = "es", task = "transcribe")
inputs = processor(batch["audio"]["array"], sampling_rate=batch["audio"]["sampling_rate"], return_tensors="pt").input_features
with torch.no_grad():
generated_ids = model.generate(inputs=inputs.to(device), forced_decoder_ids=forced_decoder_ids)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
batch["predicted_text"] = whisper_norm(transcription)
return batch
processed_dataset = dataset.map(normalize)
predicted = processed_dataset.map(map_wer)
wer = wer_metric.compute(references=predicted['gold_text'], predictions=predicted['predicted_text'])
wer = round(100 * wer, 2)
print("WER:", wer)
🔧 技術詳細
トレーニングハイパーパラメータ
トレーニング中に以下のハイパーパラメータが使用されました。
- 学習率 (learning_rate): 0.0005
- トレーニングバッチサイズ (train_batch_size): 16
- 評価バッチサイズ (eval_batch_size): 8
- 乱数シード (seed): 42
- オプティマイザ (optimizer): Adam (betas=(0.9,0.999), epsilon=1e-08)
- 学習率スケジューラの種類 (lr_scheduler_type): linear
- 学習率スケジューラのウォームアップステップ数 (lr_scheduler_warmup_steps): 500
- トレーニングステップ数 (training_steps): 25000
- 混合精度トレーニング (mixed_precision_training): Native AMP
トレーニング結果
トレーニング損失 |
エポック |
ステップ |
検証損失 |
単語誤り率 (Wer) |
2.2671 |
0.13 |
1000 |
2.2108 |
76.2667 |
1.4465 |
0.26 |
2000 |
1.6057 |
67.8753 |
1.0997 |
0.39 |
3000 |
1.1928 |
54.2433 |
0.9389 |
0.52 |
4000 |
1.0020 |
47.8307 |
0.7881 |
0.65 |
5000 |
0.8933 |
46.0046 |
0.7596 |
0.78 |
6000 |
0.7721 |
38.5595 |
0.5678 |
0.91 |
7000 |
0.6903 |
36.2897 |
0.4412 |
1.04 |
8000 |
0.6476 |
32.7473 |
0.4239 |
1.17 |
9000 |
0.5973 |
30.8142 |
0.3935 |
1.3 |
10000 |
0.5444 |
29.0208 |
0.3307 |
1.43 |
11000 |
0.5024 |
27.0434 |
0.2937 |
1.56 |
12000 |
0.4608 |
24.7318 |
0.2471 |
1.69 |
13000 |
0.4259 |
22.8940 |
0.2357 |
1.82 |
14000 |
0.3936 |
21.6018 |
0.2292 |
1.95 |
15000 |
0.3776 |
20.8004 |
0.1493 |
2.08 |
16000 |
0.4599 |
24.0491 |
0.1708 |
2.21 |
17000 |
0.4370 |
23.3443 |
0.1385 |
2.34 |
18000 |
0.4277 |
22.3171 |
0.1288 |
2.47 |
19000 |
0.4050 |
21.0118 |
0.1627 |
2.6 |
20000 |
0.4507 |
23.4004 |
0.1675 |
2.73 |
21000 |
0.4346 |
22.8261 |
0.159 |
2.86 |
22000 |
0.4179 |
22.2949 |
0.1458 |
2.99 |
23000 |
0.3978 |
21.0810 |
0.0487 |
3.12 |
24000 |
0.4456 |
20.8617 |
0.0401 |
3.25 |
25000 |
0.4485 |
20.6842 |
📄 ライセンス
このモデルは Apache-2.0 ライセンスの下で提供されています。