模型简介
模型特点
模型能力
使用案例
🚀 Distil-Whisper: Distil-Large-v3.5
Distil-Whisper是OpenAI的Whisper-Large-v3的知识蒸馏版本,相关内容在论文Robust Knowledge Distillation via Large-Scale Pseudo Labelling中有所描述。作为Distil-Whisper英语系列的最新成员,Distil-Large-v3.5在保持前代高效性的同时,展现出了更优的性能。
与早期模型相比,Distil-Large-v3.5在超过4倍(98000小时)更多样化的公共数据上进行了训练,并且在蒸馏过程中使用了具有扩展训练计划和激进数据增强(SpecAugment)的"耐心"教师。这使得它与之前的Distil-Whisper模型相比,鲁棒性和准确性都得到了增强,适合作为直接替代品。
模型 | 参数数量(百万) | 相对实时因子 | 短文本离群词错误率 | 长文本离群词错误率 |
---|---|---|---|---|
large-v3-turbo | 809 | 1.0 | 7.30 | 10.25 |
distil-large-v3 | 756 | 1.44 | 7.53 | 11.6 |
distil-large-v3.5 | 756 | 1.46 | 7.08 | 11.39 |
既然已经有了Whisper-Large-v3-Turbo,为什么还要考虑Distil-Large-v3.5呢?
- 它在准确性和效率之间提供了不同的平衡,比Whisper-Large-v3-Turbo 快约1.5倍,同时在短文本转录上表现略好,在长文本转录上落后约1%。
- 它可以完美地作为与Whisper-Large-v3进行推测解码的草稿模型。通过在训练期间冻结编码器,我们只需要加载两个额外的解码器层,并仅对编码器进行一次前向传播。这使得推理速度比Whisper-Large-v3快约2倍,同时保持输出结果相同。
该模型是Bofeng Huang、Eustache Le Bihan、Steven Zheng和Vaibhav Srivastav在🤗上的合作成果。
🚀 快速开始
Distil-Large-v3.5从Hugging Face 🤗 Transformers库的4.39版本开始得到支持。要运行该模型,首先需要安装最新版本的Transformers。在这个示例中,我们还将安装🤗 Datasets,以便从Hugging Face Hub加载一个玩具音频数据集:
pip install --upgrade pip
pip install --upgrade transformers accelerate datasets[audio]
✨ 主要特性
- 知识蒸馏版本,在保持高效性的同时提升性能。
- 训练数据更多样化,鲁棒性和准确性增强。
- 在准确性和效率之间提供不同平衡。
- 可作为推测解码的草稿模型,提升推理速度。
- 支持多种库集成,方便不同场景使用。
📦 安装指南
安装Transformers及相关库
pip install --upgrade pip
pip install --upgrade transformers accelerate datasets[audio]
不同库集成的安装步骤
Whisper.cpp
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp
make
pip install --upgrade huggingface_hub
Faster-Whisper
pip install --upgrade pip
pip install --upgrade git+https://github.com/SYSTRAN/faster-whisper datasets[audio]
OpenAI Whisper
pip install --upgrade pip
pip install --upgrade openai-whisper datasets[audio]
Candle
git clone https://github.com/huggingface/candle.git
cd candle/candle-examples/examples/whisper
💻 使用示例
基础用法
短文本转录
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3.5"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
转录本地音频文件
- result = pipe(sample)
+ result = pipe("audio.mp3")
获取片段级时间戳
result = pipe(sample, return_timestamps=True)
print(result["chunks"])
高级用法
顺序长文本转录
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3.5"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
分块长文本转录
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3.5"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=25,
batch_size=16,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
推测解码
from transformers import pipeline, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
assistant_model_id = "distil-whisper/distil-large-v3.5"
assistant_model = AutoModelForCausalLM.from_pretrained(
assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
generate_kwargs={"assistant_model": assistant_model},
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
📚 详细文档
性能评估
短文本评估
模型在5个分布内(ID)测试集和2个分布外(OOD)测试集上进行了短文本转录评估,结果如下:
数据集 | 大小(小时) | large-v3 | large-v3-turbo | distil-v3 | distil-v3.5 |
---|---|---|---|---|---|
AMI | 8.68 | 15.95 | 16.13 | 15.16 | 14.63 |
Gigaspeech | 35.36 | 10.02 | 10.14 | 10.08 | 9.84 |
LS Clean | 5.40 | 2.01 | 2.10 | 2.54 | 2.37 |
LS Other | 5.34 | 3.91 | 4.24 | 5.19 | 5.04 |
Tedlium | 2.61 | 3.86 | 3.57 | 3.86 | 3.64 |
----------- | ----- | ----- | ----- | ----- | ----- |
Earnings22 | 5.43 | 11.29 | 11.63 | 11.79 | 11.29 |
SPGISpeech | 100.00 | 2.94 | 2.97 | 3.27 | 2.87 |
----------- | ----- | ----- | ----- | ----- | ----- |
ID平均 | 7.15 | 7.24 | 7.37 | 7.10 | |
OOD平均 | 7.12 | 7.30 | 7.53 | 7.08 | |
总平均 | 7.14 | 7.25 | 7.41 | 7.10 |
长文本评估
模型在1个分布内(ID)测试集和4个分布外(OOD)测试集上进行了长文本转录评估,使用顺序解码算法(condition_on_prev_tokens=False, return_timestamps=True),结果如下:
数据集 | 大小(小时) | large-v3-turbo | distil-v2 | distil-v3 | distil-v3.5 |
---|---|---|---|---|---|
tedlium-long-form | 2.47 | 3.07 | 9.66 | 3.9 | 4.63 |
----------------- | ----- | ----- | ----- | ----- | ----- |
meanwhile | 1.01 | 5.03 | 16.75 | 7.04 | 6.79 |
earnings21 | 39.26 | 9.84 | 15.09 | 10.54 | 10.6 |
earnings22 | 119.89 | 13.32 | 19.11 | 15.06 | 14.19 |
rev16 | 16.16 | 12.82 | 21.15 | 13.76 | 13.98 |
----------------- | ----- | ----- | ----- | ----- | ----- |
ID平均 | 3.07 | 9.66 | 3.9 | 4.63 | |
OOD平均 | 10.25 | 18.03 | 11.6 | 11.39 | |
总平均 | 8.82 | 16.35 | 10.06 | 10.04 |
不同算法使用场景
- 顺序长文本算法:适用于转录准确性是最重要因素,且对延迟考虑较少的场景;或者转录批量长音频文件的场景,此时顺序算法的延迟与分块算法相当,但词错误率可提高达0.5%。
- 分块长文本算法:适用于转录单个大音频文件且需要最快推理速度的场景,比分块算法快达9倍。
不同库集成使用说明
Whisper.cpp
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp
make
pip install --upgrade huggingface_hub
Faster-Whisper
pip install --upgrade pip
pip install --upgrade git+https://github.com/SYSTRAN/faster-whisper datasets[audio]
OpenAI Whisper
pip install --upgrade pip
pip install --upgrade openai-whisper datasets[audio]
Candle
git clone https://github.com/huggingface/candle.git
cd candle/candle-examples/examples/whisper
🔧 技术细节
训练细节
Distil-Whisper继承了Whisper的编码器 - 解码器架构。编码器将语音向量输入序列映射到隐藏状态向量序列,解码器根据所有先前的令牌和编码器隐藏状态自回归地预测文本令牌。因此,编码器只进行一次前向传播,而解码器的运行次数与生成的令牌数量相同。在实践中,这意味着解码器占总推理时间的90%以上。因此,为了优化延迟,重点在于最小化解码器的推理时间。
Distil-Large-v3.5基于先前Distil-Whisper模型的技术,特别是Distil-Large-v2中引入的教师模型蒸馏训练过程,以及Distil-Large-v3中用于改进顺序解码长文本转录的样本打包方法。除了这些已有的方法,我们还进行了以下显著改进:
- 大幅扩展训练数据:将训练数据集从之前版本的22000小时增加到超过98000小时的高质量公共数据,这主要得益于Yodas数据集,它提供了来自YouTube的多样化内容。
- 实现了具有激进数据增强(SpecAugment)的"耐心"教师:这使得训练计划大幅扩展到80个周期,而之前版本仅为11个周期。即使在这个延长的时间段内,模型的评估损失仍在略微下降,表明模型仍在持续改进。
- 降低训练数据中附加先前提示的概率:最初设置为50%,但发现模型在转录包含先前上下文的文本时存在困难,可能是由于解码器大小的限制。随后将其降低到20%,这在提高整体性能的同时,仍然起到了数据增强的作用。
- 增加批量大小和学习率:实现了4096个打包片段的更大批量大小,远大于之前版本使用的256个。还测试了包括余弦、wsd和scheduler-free optimizers在内的替代学习率调度器,但发现线性方法仍然表现最佳。
我们还修改了片段打包顺序,以创建更具逻辑结构的打包片段,并采用了BPE dropout作为正则化方法。我们发现这对短文本转录性能略有下降,但对长文本内容的结果有所改善。
该模型在Jean Zay集群上使用64个H100 GPU进行训练,整个训练过程耗时三天。Tensorboard训练日志可在这里找到。
训练数据
我们最初从Common Voice、LibriSpeech、VoxPopuli、TED-LIUM、People's Speech、GigaSpeech、AMI等来源收集了超过196000小时的公共数据,特别是Yodas。这个多样化的数据集对于确保我们的蒸馏模型在各种音频分布和噪声条件下保持鲁棒性至关重要。
我们将收集到的示例打包成大约30秒的片段,每个片段只包含一个说话者。为了保持数据集之间的高质量和一致格式,我们使用Whisper-Large-v3对这些训练片段进行了伪标记。然后,我们对Whisper伪标记和每个数据集提供的真实标记进行了归一化处理,并计算了它们之间的词错误率(WER)。我们丢弃了任何WER超过10%的示例,最终得到了大约98000小时的高质量训练数据。
过滤后的数据可以在这个多语言数据集中找到,以便进行重现和进一步过滤。
重现Distil-Whisper
用于重现Distil-Whisper的训练和评估代码可在Distil-Whisper仓库中找到。
📄 许可证
Distil-Whisper继承了OpenAI的Whisper模型的MIT许可证。
🔖 引用
如果您使用了这个模型,请考虑引用Distil-Whisper论文:
@misc{gandhi2023distilwhisper,
title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
year={2023},
eprint={2311.00430},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
🙏 致谢
- OpenAI提供了Whisper模型和原始代码库。
- Hugging Face 🤗 Transformers实现了模型集成。
- Sanchit Gandhi建立了Distil-Whisper训练管道,并在这篇论文中进行了详细研究。
- GENCI在分配号2024 - GC011015467下提供了IDRIS HPC资源的访问权限。
- Georgi Gerganov实现了Whisper.cpp集成。
- Systran团队实现了Faster-Whisper集成。
- Joshua Lochner实现了Transformers.js集成。
- Laurent Mazare实现了Candle集成。
- Vaibhav Srivastav负责Distil-Whisper的分发。
- Raghav Sonavane在LibriSpeech数据集上进行了Distil-Whisper的早期迭代。



