模型简介
模型特点
模型能力
使用案例
🚀 Distil-Whisper: distil-large-v3
Distil-Whisper是一个用于医学语音识别的微调工作空间。该模型会经常更新,若你觉得当前版本对你有用,可直接复制该工作空间。Distil-Whisper在论文 Robust Knowledge Distillation via Large-Scale Pseudo Labelling 中被提出,本版本distil-large-v3是Distil-Whisper英语系列的第三个也是最后一个版本,它是OpenAI的 Whisper large-v3 的知识蒸馏版本,是目前最新且性能最佳的Whisper模型。
🚀 快速开始
本项目是Distil-Whisper的distil-large-v3版本,用于自动语音识别任务。以下是使用该模型的基本步骤:
- 安装必要的库,如
transformers
、datasets
等。 - 加载模型和处理器。
- 准备音频数据。
- 进行语音识别推理。
以下是一个简单的示例代码:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
✨ 主要特性
- 长文本转录准确性高:与之前的Distil-Whisper模型相比,distil-large-v3的蒸馏过程经过调整,使用OpenAI的顺序长文本算法时,具有卓越的长文本转录准确性。
- 速度更快:比之前的Distil-Whisper模型更快,比large-v3快6.3倍,比distil-large-v2快1.1倍。
- 兼容性强:与最流行的Whisper库(Whisper cpp、Faster-Whisper、OpenAI Whisper)兼容,使用这些库时,从之前的Distil-Whisper检查点切换到distil-large-v3可获得显著的性能提升。
- 支持多种算法:支持短文本转录、顺序长文本转录、分块长文本转录和推测解码等多种算法。
模型性能对比
模型 | 参数数量(百万) | 相对延迟 | 短文本 | 顺序长文本 | 分块长文本 |
---|---|---|---|---|---|
large-v3 | 1550 | 1.0 | 8.4 | 10.0 | 11.0 |
distil-large-v3 | 756 | 6.3 | 9.7 | 10.8 | 10.9 |
distil-large-v2 | 756 | 5.8 | 10.1 | 15.6 | 11.6 |
📦 安装指南
安装依赖库
pip install --upgrade pip
pip install --upgrade transformers accelerate datasets[audio]
安装特定库以支持不同使用场景
Whisper.cpp
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp
pip install --upgrade huggingface_hub
下载GGML权重:
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id='distil-whisper/distil-large-v3-ggml', filename='ggml-distil-large-v3.bin', local_dir='./models')
或使用 wget
下载:
wget https://huggingface.co/distil-whisper/distil-large-v3-ggml/resolve/main/ggml-distil-large-v3.bin -P ./models
Faster-Whisper
pip install --upgrade pip
pip install --upgrade git+https://github.com/SYSTRAN/faster-whisper datasets[audio]
OpenAI Whisper
pip install --upgrade pip
pip install --upgrade openai-whisper datasets[audio]
Transformers.js
npm i @xenova/transformers
Candle
- 安装
candle-core
:按照 这里 的说明安装。 - 克隆
candle
仓库:
git clone https://github.com/huggingface/candle.git
- 进入Whisper示例目录:
cd candle/candle-examples/examples/whisper
💻 使用示例
基础用法
短文本转录
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
转录本地音频文件:
- result = pipe(sample)
+ result = pipe("audio.mp3")
获取分段级时间戳:
result = pipe(sample, return_timestamps=True)
print(result["chunks"])
顺序长文本转录
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
分块长文本转录
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=25,
batch_size=16,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
高级用法
推测解码
from transformers import pipeline, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
assistant_model_id = "distil-whisper/distil-large-v3"
assistant_model = AutoModelForCausalLM.from_pretrained(
assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
generate_kwargs={"assistant_model": assistant_model},
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
额外的速度和内存优化
Flash Attention 2
- model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="flash_attention_2")
Torch Scale-Product-Attention (SDPA)
- model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="sdpa")
📚 详细文档
模型细节
Distil-Whisper继承了Whisper的编码器 - 解码器架构。编码器将语音向量输入序列映射到隐藏状态向量序列,解码器根据所有先前的令牌和编码器隐藏状态自回归地预测文本令牌。因此,编码器只向前运行一次,而解码器运行的次数与生成的令牌数量相同。在实践中,这意味着解码器占总推理时间的90%以上。为了优化延迟,重点是最小化解码器的推理时间。
为了蒸馏Whisper模型,我们在保持编码器固定的同时减少了解码器层的数量。编码器(绿色部分)完全从教师模型复制到学生模型,并在训练期间冻结。学生模型的解码器由教师解码器层的一个子集组成,这些层从最大间隔的层初始化。然后,模型在KL散度和伪标签损失项的加权和上进行训练。
与distil-large-v2的差异
与之前的Distil-Whisper版本相比,distil-large-v3专门针对OpenAI顺序长文本转录算法进行了设计。与distil-large-v2相比,除了模型层是从最新的large-v3模型而不是较旧的large-v2模型初始化之外,没有架构上的差异。差异在于模型的训练方式。
之前的Distil-Whisper模型在平均输入长度为7秒的情况下进行训练,而原始的Whisper模型在30秒的输入上进行预训练。在蒸馏过程中,我们将模型权重的分布转移到训练数据的分布上。如果我们的训练数据包含较短的话语(例如,平均7秒的音频而不是30秒),那么预测分布将转移到这个较短的上下文长度。在推理时,distil-large-v2的最佳上下文窗口是这两个值的插值:15秒。超过这个时间,distil-large-v2模型的预测在很大程度上是不准确的,特别是对于时间戳预测。然而,顺序长文本算法使用30秒的滑动窗口进行推理,窗口根据最后预测的时间戳进行移动。由于最后一个时间戳通常发生在15秒标记之后,其预测的准确性较低,导致长文本转录经常失败。
为了保留Whisper转录30秒滑动窗口的能力,就像顺序解码那样,我们需要确保distil-large-v3的上下文长度也是30秒。这主要通过以下四种策略实现:
- 将训练数据集中的音频样本打包到30秒:由于模型在打包到30秒的音频数据上进行预训练和蒸馏,distil-large-v3现在在与Whisper相同的理想上下文窗口上运行,能够准确预测长达30秒的时间戳。
- 冻结解码器输入嵌入:我们使用与原始模型相同的输入嵌入表示,该表示旨在处理比之前的Distil-Whisper迭代更长的上下文长度。
- 在训练期间使用更长的最大上下文长度:我们在最大目标长度为256的情况下进行训练,而不是128。这有助于distil-large-v3转录可能超过128个令牌的30秒片段。
- 将提示条件附加到50%的训练样本上:使模型能够与
condition_on_prev_tokens
参数一起使用,以及处理长达448个令牌的上下文窗口。
评估
以下代码展示了如何在LibriSpeech验证清洁数据集上使用 流式模式 评估Distil-Whisper模型,即无需将音频数据下载到本地设备。
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from datasets import load_dataset
from evaluate import load
import torch
from tqdm import tqdm
# define our torch configuration
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
# load the model + processor
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, use_safetensors=True, low_cpu_mem_usage=True)
model = model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
# load the dataset with streaming mode
dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
# define the evaluation metric
wer_metric = load("wer")
def inference(batch):
# 1. Pre-process the audio data to log-mel spectrogram inputs
audio = [sample["array"] for sample in batch["audio"]]
input_features = processor(audio, sampling_rate=batch["audio"][0]["sampling_rate"], return_tensors="pt").input_features
input_features = input_features.to(device, dtype=torch_dtype)
# 2. Auto-regressively generate the predicted token ids
pred_ids = model.generate(input_features, max_new_tokens=128)
# 3. Decode the token ids to the final transcription
batch["transcription"] = processor.batch_decode(pred_ids, skip_special_tokens=True)
batch["reference"] = batch["text"]
return batch
# batch size 16 inference
dataset = dataset.map(function=inference, batched=True, batch_size=16)
all_transcriptions = []
all_references = []
# iterate over the dataset and run inference
for result in tqdm(dataset, desc="Evaluating..."):
all_transcriptions.append(result["transcription"])
all_references.append(result["reference"])
# normalize predictions and references
all_transcriptions = [processor.normalize(transcription) for transcription in all_transcriptions]
all_references = [processor.normalize(reference) for reference in all_references]
# compute the WER metric
wer = 100 * wer_metric.compute(predictions=all_transcriptions, references=all_references)
print(wer)
预期用途
Distil-Whisper旨在作为Whisper large-v3在英语语音识别中的直接替代品。特别是,它在分布外(OOD)测试数据上实现了可比的WER结果,同时在短文本和长文本音频上快6倍。
数据
Distil-Whisper在来自Hugging Face Hub上九个开源、许可宽松的语音数据集的22,000小时音频数据上进行训练:
数据集 | 大小(小时) | 说话者数量 | 领域 | 许可证 |
---|---|---|---|---|
People's Speech | 12,000 | 未知 | 互联网存档 | CC-BY-SA-4.0 |
Common Voice 13 | 3,000 | 未知 | 旁白维基百科 | CC0-1.0 |
GigaSpeech | 2,500 | 未知 | 有声读物、播客、YouTube | apache-2.0 |
Fisher | 1,960 | 11,900 | 电话对话 | LDC |
LibriSpeech | 960 | 2,480 | 有声读物 | CC-BY-4.0 |
VoxPopuli | 540 | 1,310 | 欧洲议会 | CC0 |
TED-LIUM | 450 | 2,030 | TED演讲 | CC-BY-NC-ND 3.0 |
SwitchBoard | 260 | 540 | 电话对话 | LDC |
AMI | 100 | 未知 | 会议 | CC-BY-4.0 |
总计 | 21,770 | 18,260+ |
组合数据集涵盖10个不同领域和超过50k个说话者。这种数据集的多样性对于确保蒸馏模型对音频分布和噪声具有鲁棒性至关重要。
音频数据然后使用Whisper large-v3模型进行伪标签:我们使用Whisper为训练集中的所有音频生成预测,并在训练期间将这些预测用作目标标签。使用伪标签确保转录在数据集之间具有一致的格式,并在训练期间提供序列级蒸馏信号。
WER过滤
Whisper伪标签预测可能会出现误转录和幻觉。为了确保我们只在准确的伪标签上进行训练,我们在训练期间采用了一种简单的WER启发式方法。首先,我们对Whisper伪标签和每个数据集提供的真实标签进行归一化。然后,我们计算这些标签之间的WER。如果WER超过指定的阈值,我们丢弃该训练示例。否则,我们将其保留用于训练。
Distil-Whisper论文 的第9.2节展示了这种过滤方法对于提高蒸馏模型下游性能的有效性。我们还将Distil-Whisper对幻觉的鲁棒性部分归因于这种过滤方法。
训练
模型在批量大小为256的情况下进行了80,000次优化步骤(或11个epoch)的训练。Tensorboard训练日志可在 这里 找到。
结果
蒸馏模型在分布外(OOD)短文本音频上的WER与Whisper large-v3相差在1.5%以内,在顺序长文本解码上相差在1%以内,在分块长文本上比large-v3高0.1%。这种性能提升归因于较低的幻觉率。
有关评估结果的详细数据集细分,请参考 Distil-Whisper论文 的表16和表17。
Distil-Whisper还在 ESB基准 数据集上进行了评估,作为 OpenASR排行榜 的一部分,其WER与Whisper相差在0.2%以内。
复现Distil-Whisper
复现Distil-Whisper的训练和评估代码可在 Distil-Whisper仓库 中找到。该代码将很快更新,以包括 与distil-large-v2的差异 部分中描述的训练更新。
🔧 技术细节
模型架构
Distil-Whisper继承了Whisper的编码器 - 解码器架构。编码器将语音向量输入序列映射到隐藏状态向量序列,解码器根据所有先前的令牌和编码器隐藏状态自回归地预测文本令牌。
蒸馏方法
在蒸馏过程中,保持编码器固定,减少解码器层的数量。编码器完全从教师模型复制到学生模型,并在训练期间冻结。学生模型的解码器由教师解码器层的一个子集组成,这些层从最大间隔的层初始化。
训练数据处理
使用Whisper large-v3模型为训练数据生成伪标签,确保转录在数据集之间具有一致的格式,并在训练期间提供序列级蒸馏信号。同时,采用WER过滤方法,丢弃WER超过指定阈值的训练示例,以确保只在准确的伪标签上进行训练。
📄 许可证
Distil-Whisper继承了OpenAI的Whisper模型的 MIT许可证。
引用
如果您使用此模型,请考虑引用 Distil-Whisper论文:
@misc{gandhi2023distilwhisper,
title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
year={2023},
eprint={2311.00430},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
致谢
- OpenAI提供了Whisper 模型,特别是Jong Wook Kim提供了 原始代码库 并参与了训练讨论。
- Hugging Face 🤗 Transformers 实现了模型集成。
- Georgi Gerganov 实现了Whisper cpp集成。
- Systran团队 实现了Faster-Whisper集成。
- Joshua Lochner 实现了Transformers.js集成。
- Laurent Mazare 实现了Candle集成。
- Vaibhav Srivastav 负责Distil-Whisper的分发。
- Google的 TPU Research Cloud (TRC) 计划提供了Cloud TPU v4计算资源。
- Raghav Sonavane 提供了Distil-Whisper在LibriSpeech数据集上的早期版本。



