🚀 distil-whisper-large-v3-es
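This repository hosts a distilled version of the Whisper large-v3 model for Spanish, trained on version 16.1 of the Mozilla Common Voice dataset. The model is the result of a collaboration between SandboxAI and Universidad Nacional de Río Negro.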
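🚀 Quick Start
Distil-Whisper is supported in Hugging Face 🤗 Transformers from version 4.35 onwards. To run the model, first install the latest version of the Transformers library. For this example, we also install 🤗 Datasets to load a toy audio dataset from the Hugging Face Hub: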
pip install --upgrade pip
pip install --upgrade transformers accelerate datasets[audio]
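Because Distil-Whisper support starts at Transformers 4.35, it can be useful to check the installed version before running the examples. The snippet below is a minimal sketch that uses only the Python standard library; the helper names `numeric_prefix` and `meets_minimum` are illustrative and not part of any library.

```python
from importlib.metadata import PackageNotFoundError, version


def numeric_prefix(version_string):
    """Return the leading numeric components, e.g. "4.35.0.dev0" -> (4, 35, 0)."""
    parts = []
    for part in version_string.split("."):
        if not part.isdigit():
            break
        parts.append(int(part))
    return tuple(parts)


def meets_minimum(installed, minimum="4.35"):
    """Hypothetical helper: True if `installed` is at least `minimum`."""
    return numeric_prefix(installed) >= numeric_prefix(minimum)


try:
    installed = version("transformers")
    status = "ok" if meets_minimum(installed) else "too old, please upgrade"
    print(f"transformers {installed}: {status}")
except PackageNotFoundError:
    print("transformers is not installed")
```

The comparison only looks at the leading numeric components, so pre-release suffixes such as `.dev0` are ignored.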
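✨ Key Features
- Multiple transcription modes: supports short-form (< 30 s) and long-form (> 30 s) transcription, as well as speculative decoding.
- Efficient long-form transcription: Distil-Whisper uses a chunked algorithm for long audio, which is 9 times faster than the sequential algorithm proposed by OpenAI in the Whisper paper.
- Assistant model for speculative decoding: it can serve as the assistant model to Whisper, giving a 2x speed-up while guaranteeing the same outputs as Whisper.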
💻 Usage Examples
Basic Usage
Short-Form Transcription
The model can be used with the pipeline class to transcribe short audio files (< 30 s) as follows:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset

# Use the GPU with half precision when available, otherwise the CPU with full precision.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "marianbasti/distil-whisper-large-v3-es"

# Load the distilled model and move it to the selected device.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

# The processor bundles the tokenizer and the feature extractor.
processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)

# Transcribe the first sample of a small dummy dataset.
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]

result = pipe(sample)
print(result["text"])
To transcribe a local audio file, pass the path to the file when calling the pipeline:
- result = pipe(sample)
+ result = pipe("audio.mp3")
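The 30-second boundary between short-form and long-form audio can be checked before choosing how to call the pipeline. The following is a rough sketch that handles WAV files only, using the standard library; `wav_duration_seconds` and `pipeline_kwargs_for` are hypothetical helpers, and the chunking values are the ones used in the long-form example of this README.

```python
import os
import tempfile
import wave


def wav_duration_seconds(path):
    """Return the duration of a WAV file in seconds."""
    with wave.open(path, "rb") as wav_file:
        return wav_file.getnframes() / wav_file.getframerate()


def pipeline_kwargs_for(duration_s, threshold_s=30.0):
    """Hypothetical helper: pick chunking arguments for long audio only."""
    if duration_s <= threshold_s:
        return {}  # short-form: no chunking needed
    return {"chunk_length_s": 15, "batch_size": 16}  # long-form settings


# Demo: write two seconds of 16 kHz mono silence and measure it.
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "silence.wav")
    with wave.open(demo_path, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)  # 16-bit samples
        wav_file.setframerate(16000)
        wav_file.writeframes(b"\x00\x00" * 16000 * 2)
    demo_duration = wav_duration_seconds(demo_path)

print(demo_duration, pipeline_kwargs_for(demo_duration))
print(pipeline_kwargs_for(45.0))
```

Other formats such as MP3 would need a different way to read the duration; that part is left out of this sketch.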
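Advanced Usage
Long-Form Transcription
Distil-Whisper uses a chunked algorithm to transcribe long audio files (> 30 s). In practice, this chunked long-form algorithm is 9 times faster than the sequential algorithm proposed by OpenAI in the Whisper paper (see Table 7 of the Distil-Whisper paper).
To enable chunking, pass the chunk_length_s parameter to the pipeline. For Distil-Whisper, a chunk length of 15 seconds is optimal. To enable batching, pass the batch_size parameter: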
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset

# Use the GPU with half precision when available, otherwise the CPU with full precision.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "marianbasti/distil-whisper-large-v3-es"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,  # split long audio into 15-second chunks
    batch_size=16,      # transcribe up to 16 chunks in parallel
    torch_dtype=torch_dtype,
    device=device,
)

# Transcribe the first sample of a long-form test dataset.
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]

result = pipe(sample)
print(result["text"])
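To give an intuition for what chunking does, the sketch below computes overlapping chunk boundaries for a long recording. It is a simplified illustration, not the pipeline's actual code: `chunk_boundaries` is a hypothetical helper, and the default overlap of one sixth of the chunk length on each side is an assumption about the pipeline's default stride.

```python
def chunk_boundaries(total_s, chunk_length_s=15.0, stride_s=None):
    """Return (start, end) times in seconds of overlapping chunks.

    Assumption: neighbouring chunks overlap by `stride_s` on each side,
    defaulting to one sixth of the chunk length.
    """
    if stride_s is None:
        stride_s = chunk_length_s / 6
    step = chunk_length_s - 2 * stride_s
    if step <= 0:
        raise ValueError("stride is too large for the chunk length")

    chunks = []
    start = 0.0
    while True:
        end = min(start + chunk_length_s, total_s)
        chunks.append((start, end))
        if end >= total_s:  # the last chunk reaches the end of the audio
            break
        start += step
    return chunks


for start, end in chunk_boundaries(40, chunk_length_s=15, stride_s=2.5):
    print(f"{start:5.1f}s to {end:5.1f}s")
```

Because the chunks are independent of one another, they can be transcribed as a batch, which is where the speed-up over sequential decoding comes from; the overlapping regions are what allow the per-chunk transcripts to be stitched together.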
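Speculative Decoding
Distil-Whisper can be used as an assistant model to Whisper for speculative decoding. Speculative decoding mathematically guarantees the same outputs as Whisper while being 2 times faster. This makes it a drop-in replacement for existing Whisper pipelines. In the following code snippet, we load the assistant Distil-Whisper model separately from the main Whisper pipeline and then pass it as the "assistant model" for generation: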
from transformers import pipeline, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
from datasets import load_dataset

# Use the GPU with half precision when available, otherwise the CPU with full precision.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Assistant (draft) model: the distilled checkpoint, loaded as a decoder-only model.
assistant_model_id = "marianbasti/distil-whisper-large-v3-es"
assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)

# Main model: the original Whisper large-v3, which determines the final output.
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    generate_kwargs={"assistant_model": assistant_model},  # enables speculative decoding
    torch_dtype=torch_dtype,
    device=device,
)

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]

result = pipe(sample)
print(result["text"])
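The reason the output is guaranteed to match the main model can be seen in a toy version of the draft-and-verify loop. The sketch below is not the Transformers implementation: the two "models" are hypothetical deterministic functions over integer tokens, decoding is greedy, and the main model is called once per token for readability, whereas a real implementation would verify all draft tokens in a single forward pass.

```python
def main_model(tokens):
    """Toy stand-in for the main model: next token from the last token."""
    return (3 * tokens[-1] + 1) % 7


def assistant_model(tokens):
    """Toy stand-in for the assistant: agrees with the main model except on token 5."""
    predicted = main_model(tokens)
    return 0 if predicted == 5 else predicted


def greedy_decode(model, prompt, max_new_tokens):
    tokens = list(prompt)
    for _ in range(max_new_tokens):
        tokens.append(model(tokens))
    return tokens


def speculative_decode(main, assistant, prompt, max_new_tokens, k=4):
    tokens = list(prompt)
    target_len = len(prompt) + max_new_tokens
    rounds = 0  # each round stands for one verification pass of the main model
    while len(tokens) < target_len:
        # 1) The assistant cheaply drafts up to k tokens.
        draft = []
        for _ in range(min(k, target_len - len(tokens))):
            draft.append(assistant(tokens + draft))
        # 2) The main model checks the draft; only its own tokens are kept.
        rounds += 1
        for proposed in draft:
            expected = main(tokens)
            tokens.append(expected)
            if proposed != expected:  # first disagreement: discard the rest of the draft
                break
    return tokens, rounds


reference = greedy_decode(main_model, [1], 12)
speculative, rounds = speculative_decode(main_model, assistant_model, [1], 12, k=4)
print(speculative == reference, rounds)
```

Every appended token is the main model's own prediction, so a poor assistant can only reduce the speed-up, never change the transcript.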
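🔧 Technical Details
The model was trained for 60,000 optimization steps (about 1.47 epochs) on a single RTX 3090, which took roughly 60 hours, with the following training parameters: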
--teacher_model_name_or_path "openai/whisper-large-v3"
--train_dataset_name "mozilla-foundation/common_voice_16_1"
--train_dataset_config_name "es"
--train_split_name "train"
--text_column_name "sentence"
--eval_dataset_name "mozilla-foundation/common_voice_16_1"
--eval_dataset_config_name "es"
--eval_split_name "validation"
--eval_text_column_name "sentence"
--eval_steps 10000
--save_steps 10000
--warmup_steps 500
--learning_rate 1e-4
--lr_scheduler_type "linear"
--logging_steps 25
--save_total_limit 1
--max_steps 60000
--wer_threshold 10
--per_device_train_batch_size 8
--per_device_eval_batch_size 8
--dataloader_num_workers 12
--preprocessing_num_workers 12
--output_dir "./"
--do_train
--do_eval
--gradient_checkpointing
--predict_with_generate
--overwrite_output_dir
--use_pseudo_labels "false"
--freeze_encoder
--streaming False
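The figures above can be cross-checked with some back-of-the-envelope arithmetic. The sketch below assumes a single device and no gradient accumulation (neither is stated explicitly in the parameter list), so the derived numbers are rough estimates rather than reported results.

```python
max_steps = 60_000                 # --max_steps
per_device_train_batch_size = 8    # --per_device_train_batch_size
num_devices = 1                    # single RTX 3090
training_hours = 60                # approximate, as reported
reported_epochs = 1.47             # approximate, as reported

# Assumption: one optimizer step consumes exactly one batch per device.
samples_seen = max_steps * per_device_train_batch_size * num_devices
seconds_per_step = training_hours * 3600 / max_steps
implied_samples_per_epoch = samples_seen / reported_epochs

print(f"samples seen during training: {samples_seen}")
print(f"average time per step: {seconds_per_step:.1f} s")
print(f"implied samples per epoch: {implied_samples_per_epoch:.0f}")
```

The implied epoch size is only an estimate of how many training samples remained after filtering; it would differ if gradient accumulation was used.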
📚 Documentation
Training Results
The distilled model achieves a word error rate (WER) of 5.11% (orthographic WER: 10.15%).
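The gap between the two figures comes from text normalization: orthographic WER compares the raw text, including casing and punctuation, while the normalized WER compares text after normalization. The sketch below illustrates the difference with a word-level edit distance; the `normalize` function is a deliberately simple stand-in and not the normalizer used to produce the numbers above.

```python
import re


def word_error_rate(reference, hypothesis):
    """Word-level edit distance divided by the number of reference words."""
    ref = reference.split()
    hyp = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    previous = list(range(len(hyp) + 1))
    for i, ref_word in enumerate(ref, start=1):
        current = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            current.append(min(
                previous[j] + 1,                          # deletion
                current[j - 1] + 1,                       # insertion
                previous[j - 1] + (ref_word != hyp_word)  # substitution or match
            ))
        previous = current
    return previous[-1] / len(ref)


def normalize(text):
    """Simplified stand-in normalizer: lowercase and strip punctuation."""
    return re.sub(r"[^\w\s]", "", text.lower())


reference = "¿Dónde está la biblioteca?"
hypothesis = "dónde está la biblioteca"

orthographic_wer = word_error_rate(reference, hypothesis)
normalized_wer = word_error_rate(normalize(reference), normalize(hypothesis))
print(f"orthographic WER: {orthographic_wer:.2f}")
print(f"normalized WER:   {normalized_wer:.2f}")
```

In this example the hypothesis has the right words but lacks capitalization and punctuation, so it is penalized only under the orthographic metric.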
📄 License
Distil-Whisper inherits the MIT license from OpenAI's Whisper model.
Citation
If you use this model, please consider citing the Distil-Whisper paper:
@misc{gandhi2023distilwhisper,
title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
year={2023},
eprint={2311.00430},
archivePrefix={arXiv},
primaryClass={cs.CL}
}



