模型简介
模型特点
模型能力
使用案例
🚀 Distil-Whisper: distil-large-v3
Distil-Whisper是在论文 Robust Knowledge Distillation via Large-Scale Pseudo Labelling 中提出的。这是Distil-Whisper英语系列的第三个也是最后一个版本,它是OpenAI的 Whisper large-v3 的知识蒸馏版本,Whisper large-v3是迄今为止最新且性能最佳的Whisper模型。
与之前的Distil-Whisper模型相比,distil-large-v3的蒸馏过程经过调整,结合OpenAI的 顺序长格式算法 可实现 卓越的长格式转录准确性。最终得到的蒸馏模型在长格式音频上使用顺序和分块算法时,其字错率(WER)与large-v3相差不到1%,并且在使用顺序算法时比distil-large-v2的WER低4.8%。该模型也比之前的Distil-Whisper模型更快:比large-v3快6.3倍,比distil-large-v2快1.1倍。
模型 | 参数数量(百万) | 相对延迟 | 短格式 | 顺序长格式 | 分块长格式 |
---|---|---|---|---|---|
large-v3 | 1550 | 1.0 | 8.4 | 10.0 | 11.0 |
distil-large-v3 | 756 | 6.3 | 9.7 | 10.8 | 10.9 |
distil-large-v2 | 756 | 5.8 | 10.1 | 15.6 | 11.6 |
由于顺序算法是最流行的Whisper库(Whisper cpp、Faster-Whisper、OpenAI Whisper)中“事实上”的转录算法,因此这个蒸馏模型设计为与这些库兼容。当使用这些库时,从之前的Distil-Whisper检查点切换到distil-large-v3,你可以期待显著的性能提升。为了方便使用,最流行库的权重已经转换好,下面是使用说明。
🚀 快速开始
本项目是Distil-Whisper系列的distil-large-v3模型,是OpenAI的Whisper large-v3的知识蒸馏版本,在长格式转录准确性和速度上有显著提升,且与多种流行库兼容。以下是使用该模型的快速指南。
✨ 主要特性
- 卓越的长格式转录准确性:结合OpenAI的顺序长格式算法,在长格式音频转录上表现出色。
- 更快的推理速度:比Whisper large-v3快6.3倍,比distil-large-v2快1.1倍。
- 广泛的兼容性:与多种流行的Whisper库(如Whisper cpp、Faster-Whisper、OpenAI Whisper等)兼容。
- 支持多种算法:支持顺序长格式、分块长格式和推测解码等算法。
📦 安装指南
安装Transformers库
distil-large-v3从Hugging Face 🤗 Transformers库的4.39版本开始支持。要运行该模型,首先需要安装最新版本的Transformers库。以下是安装命令:
pip install --upgrade pip
pip install --upgrade transformers accelerate datasets[audio]
其他库的安装
根据不同的使用场景,可能还需要安装其他库,如Whisper.cpp、Faster-Whisper、OpenAI Whisper等,具体安装步骤在后续使用示例中会详细说明。
💻 使用示例
基础用法
短格式转录
模型可以使用 pipeline
类对短格式音频文件(< 30秒)进行转录,示例代码如下:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
要转录本地音频文件,只需在调用pipeline时传入音频文件的路径:
- result = pipe(sample)
+ result = pipe("audio.mp3")
要获取分段级别的时间戳,传入参数 return_timestamps=True
并返回 "chunks"
输出:
result = pipe(sample, return_timestamps=True)
print(result["chunks"])
顺序长格式转录
distil-large-v3专门设计为与OpenAI的顺序长格式转录算法兼容。以下是使用 pipeline
类对长音频文件进行顺序转录的示例代码:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
分块长格式转录
distil-large-v3仍然与Transformers分块长格式算法兼容。当需要转录单个大音频文件并追求最快推理速度时,应使用此算法。以下是启用分块转录的示例代码:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=25,
batch_size=16,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
推测解码
distil-large-v3是第一个可以作为Whisper large-v3的辅助模型用于 推测解码 的Distil-Whisper模型。推测解码在数学上保证了与Whisper相同的输出,同时速度快2倍。以下是使用推测解码的示例代码:
from transformers import pipeline, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
assistant_model_id = "distil-whisper/distil-large-v3"
assistant_model = AutoModelForCausalLM.from_pretrained(
assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
generate_kwargs={"assistant_model": assistant_model},
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
高级用法
更多生成参数控制
若需要对生成参数进行更多控制,可以直接使用模型 + 处理器API。以下是示例代码:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from datasets import Audio, load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
dataset = dataset.cast_column("audio", Audio(processor.feature_extractor.sampling_rate))
sample = dataset[0]["audio"]
input_features = processor(
sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
).input_features
input_features = input_features.to(device, dtype=torch_dtype)
gen_kwargs = {
"max_new_tokens": 128,
"num_beams": 1,
"return_timestamps": False,
}
pred_ids = model.generate(input_features, **gen_kwargs)
pred_text = processor.batch_decode(pred_ids, skip_special_tokens=True, decode_with_timestamps=gen_kwargs["return_timestamps"])
print(pred_text)
额外的速度和内存优化
可以对Distil-Whisper应用额外的速度和内存优化,以进一步降低推理速度和显存要求。主要有以下几种优化方法:
Flash Attention 2
如果你的GPU支持,建议使用 Flash-Attention 2。首先需要安装 Flash Attention:
pip install flash-attn --no-build-isolation
然后在 from_pretrained
中传入 attn_implementation="flash_attention_2"
:
- model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="flash_attention_2")
Torch Scale-Product-Attention (SDPA)
如果你的GPU不支持Flash Attention,建议使用PyTorch scaled dot-product attention (SDPA)。对于PyTorch 2.1.1或更高版本,此注意力实现默认启用。可以使用以下代码检查是否有兼容的PyTorch版本:
from transformers.utils import is_torch_sdpa_available
print(is_torch_sdpa_available())
如果返回 True
,则已安装有效的PyTorch版本,SDPA默认启用;如果返回 False
,则需要根据 官方说明 升级PyTorch版本。安装有效版本后,SDPA默认启用,也可以通过指定 attn_implementation="sdpa"
显式设置:
- model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="sdpa")
📚 详细文档
与其他库的集成
Whisper.cpp
Distil-Whisper可以使用 Whisper.cpp 包结合原始顺序长格式转录算法运行。在Mac M1上的临时基准测试中,distil-large-v3比Whisper large-v3快5倍以上,在长格式音频上的WER相差不到0.8%。以下是使用步骤:
- 克隆Whisper.cpp仓库:
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp
- 安装Hugging Face Hub Python包:
pip install --upgrade huggingface_hub
使用以下Python代码片段下载distil-large-v3的GGML权重:
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id='distil-whisper/distil-large-v3-ggml', filename='ggml-distil-large-v3.bin', local_dir='./models')
如果你没有设置Python环境,也可以使用 wget
直接下载权重:
wget https://huggingface.co/distil-whisper/distil-large-v3-ggml/resolve/main/ggml-distil-large-v3.bin -P ./models
- 使用提供的示例音频运行推理:
make -j && ./main -m models/ggml-distil-large-v3.bin -f samples/jfk.wav
Faster-Whisper
Faster-Whisper是使用 CTranslate2 重新实现的Whisper,是一个快速的Transformer模型推理引擎。首先,根据 官方说明 安装Faster-Whisper包。以下是使用示例代码:
import torch
from faster_whisper import WhisperModel
from datasets import load_dataset
# define our torch configuration
device = "cuda:0" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if torch.cuda.is_available() else "float32"
# load model on GPU if available, else cpu
model = WhisperModel("distil-large-v3", device=device, compute_type=compute_type)
# load toy dataset for example
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[1]["audio"]["path"]
segments, info = model.transcribe(sample, beam_size=1)
for segment in segments:
print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
要转录本地音频文件,只需在 transcribe
中传入音频文件的路径:
segments, info = model.transcribe("audio.mp3", beam_size=1)
OpenAI Whisper
要使用原始Whisper格式的模型,首先确保安装了 openai-whisper
包。以下是使用示例代码:
from huggingface_hub import hf_hub_download
from datasets import load_dataset
from whisper import load_model, transcribe
model_path = hf_hub_download(repo_id="distil-whisper/distil-large-v3-openai", filename="model.bin")
model = load_model(model_path)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]["path"]
pred_out = transcribe(model, audio=sample, language="en")
print(pred_out["text"])
注意,第一次运行示例时,模型权重将被下载并保存到缓存中。随后可以重复使用相同的示例,权重将直接从缓存中加载,无需再次下载。要转录本地音频文件,只需在 transcribe
中传入音频文件的路径:
pred_out = transcribe(model, audio=sample, language="en")
Distil-Whisper模型也可以与OpenAI Whisper CLI一起使用,具体说明请参考 这里。
Transformers.js
Distil-Whisper可以使用 Transformers.js 完全在你的Web浏览器中运行。以下是使用步骤:
- 从 NPM 安装Transformers.js:
npm i @xenova/transformers
- 导入库并使用pipeline API进行推理:
import { pipeline } from '@xenova/transformers';
const transcriber = await pipeline('automatic-speech-recognition', 'distil-whisper/distil-large-v3');
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
const output = await transcriber(url);
// { text: " And so, my fellow Americans, ask not what your country can do for you. Ask what you can do for your country." }
可以尝试在线 Distil-Whisper Web演示,它在你的浏览器中本地运行,无需服务器!更多信息请参考Transformers.js 文档。
Candle
通过与Hugging Face Candle 🕯️ 的集成,Distil-Whisper可用于Rust库 🦀。以下是使用步骤:
- 按照 这里 的说明安装
candle-core
。 - 本地克隆
candle
仓库:
git clone https://github.com/huggingface/candle.git
- 进入 Whisper 示例目录:
cd candle/candle-examples/examples/whisper
- 运行示例:
cargo run --example whisper --release --features symphonia -- --model distil-large-v3
- 要指定自己的音频文件,添加
--input
标志:
cargo run --example whisper --release --features symphonia -- --model distil-large-v3 --input audio.wav
提示:如果使用Apple Metal编译,在运行示例时指定 metal
特性:
cargo run --example whisper --release --features="symphonia,metal" -- --model distil-large-v3
如果遇到以下错误:
error: target `whisper` in package `candle-examples` requires the features: `symphonia`
Consider enabling them by passing, e.g., `--features="symphonia"`
你应该清理 cargo
安装:
cargo clean
然后重新编译:
cargo run --example whisper --release --features symphonia -- --model distil-large-v3
模型评估
以下是在LibriSpeech验证-clean数据集上使用 流式模式 评估Distil-Whisper模型的示例代码,这意味着无需将音频数据下载到本地设备:
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from datasets import load_dataset
from evaluate import load
import torch
from tqdm import tqdm
# define our torch configuration
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
# load the model + processor
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, use_safetensors=True, low_cpu_mem_usage=True)
model = model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
# load the dataset with streaming mode
dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
# define the evaluation metric
wer_metric = load("wer")
def inference(batch):
# 1. Pre-process the audio data to log-mel spectrogram inputs
audio = [sample["array"] for sample in batch["audio"]]
input_features = processor(audio, sampling_rate=batch["audio"][0]["sampling_rate"], return_tensors="pt").input_features
input_features = input_features.to(device, dtype=torch_dtype)
# 2. Auto-regressively generate the predicted token ids
pred_ids = model.generate(input_features, max_new_tokens=128)
# 3. Decode the token ids to the final transcription
batch["transcription"] = processor.batch_decode(pred_ids, skip_special_tokens=True)
batch["reference"] = batch["text"]
return batch
# batch size 16 inference
dataset = dataset.map(function=inference, batched=True, batch_size=16)
all_transcriptions = []
all_references = []
# iterate over the dataset and run inference
for result in tqdm(dataset, desc="Evaluating..."):
all_transcriptions.append(result["transcription"])
all_references.append(result["reference"])
# normalize predictions and references
all_transcriptions = [processor.normalize(transcription) for transcription in all_transcriptions]
all_references = [processor.normalize(reference) for reference in all_references]
# compute the WER metric
wer = 100 * wer_metric.compute(predictions=all_transcriptions, references=all_references)
print(wer)
打印输出:
2.428920763531516
预期用途
Distil-Whisper旨在作为Whisper large-v3在英语语音识别中的直接替代品。特别是,它在分布外(OOD)测试数据上实现了相当的WER结果,同时在短格式和长格式音频上快6倍。
数据
Distil-Whisper在Hugging Face Hub上的九个开源、许可宽松的语音数据集的22,000小时音频数据上进行训练:
数据集 | 大小(小时) | 说话者数量 | 领域 | 许可证 |
---|---|---|---|---|
People's Speech | 12,000 | 未知 | Internet Archive | CC-BY-SA-4.0 |
Common Voice 13 | 3,000 | 未知 | Narrated Wikipedia | CC0-1.0 |
GigaSpeech | 2,500 | 未知 | Audiobook, podcast, YouTube | apache-2.0 |
Fisher | 1,960 | 11,900 | Telephone conversations | LDC |
LibriSpeech | 960 | 2,480 | Audiobooks | CC-BY-4.0 |
VoxPopuli | 540 | 1,310 | European Parliament | CC0 |
TED-LIUM | 450 | 2,030 | TED talks | CC-BY-NC-ND 3.0 |
SwitchBoard | 260 | 540 | Telephone conversations | LDC |
AMI | 100 | 未知 | Meetings | CC-BY-4.0 |
总计 | 21,770 | 18,260+ |
组合数据集涵盖10个不同领域和超过50k个说话者。这个数据集的多样性对于确保蒸馏模型对音频分布和噪声具有鲁棒性至关重要。然后,使用Whisper large-v3模型对音频数据进行伪标签:我们使用Whisper为训练集中的所有音频生成预测,并在训练期间将这些预测用作目标标签。使用伪标签确保转录在数据集之间格式一致,并在训练期间提供序列级蒸馏信号。
WER过滤
Whisper伪标签预测可能存在误转录和幻觉问题。为了确保只在准确的伪标签上进行训练,我们在训练期间采用了简单的WER启发式方法。首先,对Whisper伪标签和每个数据集提供的真实标签进行归一化。然后计算这些标签之间的WER。如果WER超过指定阈值,则丢弃该训练示例;否则,保留用于训练。Distil-Whisper论文 的第9.2节展示了此过滤器对于提高蒸馏模型下游性能的有效性。我们还部分将Distil-Whisper对幻觉的鲁棒性归因于这个过滤器。
训练
模型在批量大小为256的情况下进行了80,000次优化步骤(或11个周期)的训练。Tensorboard训练日志可以在以下链接找到:https://huggingface.co/distil-whisper/distil-large-v3/tensorboard?params=scalars#frame
结果
蒸馏模型在分布外(OOD)短格式音频上的WER与Whisper large-v3相差不到1.5%,在顺序长格式解码上相差不到1%,在分块长格式上比large-v3高0.1%。这种性能提升归因于更低的幻觉率。有关每个数据集评估结果的详细细分,请参考 Distil-Whisper论文 的表16和表17。Distil-Whisper还在 ESB基准 数据集上进行了评估,作为 OpenASR排行榜 的一部分,其WER与Whisper相差不到0.2%。
复现Distil-Whisper
复现Distil-Whisper的训练和评估代码可在Distil-Whisper仓库中找到:https://github.com/huggingface/distil-whisper/tree/main/training。此代码将很快更新,以包含 与distil-large-v2的差异 部分中描述的训练更新。
🔧 技术细节
模型架构
Distil-Whisper继承了Whisper的编码器 - 解码器架构。编码器将语音向量输入序列映射到隐藏状态向量序列。解码器根据所有先前的标记和编码器隐藏状态自回归地预测文本标记。因此,编码器只向前运行一次,而解码器运行的次数与生成的标记数量相同。实际上,这意味着解码器占总推理时间的90%以上。因此,为了优化延迟,重点是最小化解码器的推理时间。 为了蒸馏Whisper模型,我们在保持编码器固定的同时减少解码器层数。编码器(以绿色显示)从教师模型完全复制到学生模型,并在训练期间冻结。学生的解码器由教师解码器层的一个子集组成,这些层从最大间隔的层初始化。然后,模型在KL散度和伪标签损失项的加权和上进行训练。
在顺序解码算法下,还采用了其他技巧来提高distil-large-v3的性能,这些技巧将在即将发布的博客文章中详细解释。
📄 许可证
Distil-Whisper继承了OpenAI的Whisper模型的 MIT许可证。
引用
如果你使用此模型,请考虑引用 Distil-Whisper论文:
@misc{gandhi2023distilwhisper,
title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
year={2023},
eprint={2311.00430},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
致谢
- OpenAI提供了Whisper 模型,特别是Jong Wook Kim提供了 原始代码库 并参与了训练讨论。
- Hugging Face 🤗 Transformers 实现了模型集成。
- Georgi Gerganov 实现了Whisper cpp集成。
- Systran团队 实现了Faster-Whisper集成。
- Joshua Lochner 实现了Transformers.js集成。
- Laurent Mazare 实现了Candle集成。
- Vaibhav Srivastav 负责Distil-Whisper的分发。
- Google的 TPU Research Cloud (TRC) 计划提供了Cloud TPU v4计算资源。
- Raghav Sonavane 提供了Distil-Whisper在LibriSpeech数据集上的早期版本。



