🚀 西班牙语版Whisper-large-v2模型
本模型是基于Transformer架构的语音识别模型,在特定数据集上微调了openai/whisper-large-v2
模型,在评估集上取得了较好的损失和字错率(WER)指标。
🚀 快速开始
此模型是 openai/whisper-large-v2 在特定数据集上的微调版本。它在评估集上取得了以下结果:
- 损失值:0.1466
- 字错率(WER):0.0855
✨ 主要特性
- 基于
openai/whisper-large-v2
模型微调,在特定任务上表现更优。
- 提供了训练超参数和训练结果,方便了解模型训练过程。
- 给出了转录和评估的代码示例,便于使用和验证模型。
📦 安装指南
文档未提供具体安装步骤,可参考openai/whisper-large-v2
的安装说明,并确保安装以下框架版本:
- Transformers 4.26.0.dev0
- Pytorch 1.13.1+cu117
- Datasets 2.8.1.dev0
- Tokenizers 0.13.2
💻 使用示例
基础用法
from datasets import load_dataset, Audio
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
processor = WhisperProcessor.from_pretrained("clu-ling/whisper-large-v2-spanish")
model = WhisperForConditionalGeneration.from_pretrained("clu-ling/whisper-large-v2-spanish").to(device)
forced_decoder_ids = processor.get_decoder_prompt_ids(language="es", task="transcribe")
commonvoice_eval = load_dataset("mozilla-foundation/common_voice_11_0", "es", split="validation", streaming=True)
commonvoice_eval = commonvoice_eval.cast_column("audio", Audio(sampling_rate=16000))
sample = next(iter(commonvoice_eval))["audio"]
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
predicted_ids = model.generate(input_features.to(device), forced_decoder_ids=forced_decoder_ids)
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
print(transcription)
高级用法
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
from datasets import load_dataset, Audio
import evaluate
import torch
import re
from transformers import WhisperProcessor, WhisperForConditionalGeneration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
wer_metric = evaluate.load("wer")
processor = WhisperProcessor.from_pretrained("clu-ling/whisper-large-v2-spanish")
model = WhisperForConditionalGeneration.from_pretrained("clu-ling/whisper-large-v2-spanish")
dataset = load_dataset("mozilla-foundation/common_voice_11_0", "es", split="test", )
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
def normalize(batch):
batch["gold_text"] = whisper_norm(batch['sentence'])
return batch
def map_wer(batch):
model.to(device)
forced_decoder_ids = processor.get_decoder_prompt_ids(language = "es", task = "transcribe")
inputs = processor(batch["audio"]["array"], sampling_rate=batch["audio"]["sampling_rate"], return_tensors="pt").input_features
with torch.no_grad():
generated_ids = model.generate(inputs=inputs.to(device), forced_decoder_ids=forced_decoder_ids)
transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
batch["predicted_text"] = whisper_norm(transcription)
return batch
processed_dataset = dataset.map(normalize)
predicted = processed_dataset.map(map_wer)
wer = wer_metric.compute(references=predicted['gold_text'], predictions=predicted['predicted_text'])
wer = round(100 * wer, 2)
print("WER:", wer)
🔧 技术细节
训练超参数
训练过程中使用了以下超参数:
- 学习率:1e-05
- 训练批次大小:16
- 评估批次大小:16
- 随机种子:42
- 优化器:Adam(β1=0.9,β2=0.999,ε=1e-08)
- 学习率调度器类型:线性
- 学习率调度器热身步数:500
- 训练步数:25000
- 混合精度训练:Native AMP
训练结果
训练损失 |
轮数 |
步数 |
验证损失 |
字错率(WER) |
0.1908 |
0.03 |
1000 |
0.2235 |
0.1154 |
0.1888 |
0.07 |
2000 |
0.2132 |
0.1131 |
0.167 |
0.1 |
3000 |
0.2115 |
0.1133 |
0.1752 |
0.14 |
4000 |
0.2081 |
0.1146 |
0.1656 |
0.17 |
5000 |
0.2002 |
0.1073 |
0.1535 |
0.21 |
6000 |
0.1971 |
0.1086 |
0.1854 |
0.24 |
7000 |
0.1927 |
0.1048 |
0.1722 |
0.28 |
8000 |
0.1889 |
0.1043 |
0.166 |
0.31 |
9000 |
0.1850 |
0.1022 |
0.1277 |
0.35 |
10000 |
0.1820 |
0.1032 |
0.1457 |
0.38 |
11000 |
0.1777 |
0.0998 |
0.169 |
0.42 |
12000 |
0.1771 |
0.0982 |
0.1612 |
0.45 |
13000 |
0.1724 |
0.0976 |
0.1616 |
0.49 |
14000 |
0.1693 |
0.0956 |
0.1556 |
0.52 |
15000 |
0.1671 |
0.0942 |
0.1448 |
0.56 |
16000 |
0.1646 |
0.0930 |
0.117 |
0.59 |
17000 |
0.1613 |
0.0914 |
0.1441 |
0.62 |
18000 |
0.1596 |
0.0899 |
0.148 |
0.66 |
19000 |
0.1571 |
0.0895 |
0.1255 |
0.69 |
20000 |
0.1547 |
0.0874 |
0.1479 |
0.73 |
21000 |
0.1525 |
0.0885 |
0.1304 |
0.76 |
22000 |
0.1503 |
0.0861 |
0.1111 |
0.8 |
23000 |
0.1486 |
0.0867 |
0.1337 |
0.83 |
24000 |
0.1472 |
0.0854 |
0.1289 |
0.87 |
25000 |
0.1466 |
0.0855 |
📄 许可证
本模型采用Apache-2.0许可证。