模型概述
模型特點
模型能力
使用案例
🚀 Distil-Whisper: distil-small.en
Distil-Whisper是一種經過蒸餾的語音識別模型,它基於論文 Robust Knowledge Distillation via Large-Scale Pseudo Labelling 提出。相比原始的Whisper模型,Distil-Whisper 推理速度快6倍,模型大小縮小49%,並且在分佈外評估集上的字錯率(WER)誤差 控制在1%以內。
本倉庫是distil-small.en的存儲庫,它是 Whisper small.en 的蒸餾變體。distil-small.en是 最小的Distil-Whisper檢查點,僅包含1.66億個參數,非常適合內存受限的應用場景(如設備端應用)。
對於大多數其他應用場景,建議使用 distil-medium.en 或 distil-large-v2 檢查點,因為它們不僅推理速度更快,而且字錯率(WER)表現更好:
模型 | 參數數量 / M | 相對延遲 ↑ | 短文本WER ↓ | 長文本WER ↓ |
---|---|---|---|---|
large-v3 | 1550 | 1.0 | 8.4 | 11.0 |
large-v2 | 1550 | 1.0 | 9.1 | 11.7 |
distil-large-v3 | 756 | 6.3 | 9.7 | 10.8 |
distil-large-v2 | 756 | 5.8 | 10.1 | 11.6 |
distil-medium.en | 394 | 6.8 | 11.1 | 12.4 |
distil-small.en | 166 | 5.6 | 12.1 | 12.8 |
⚠️ 重要提示
Distil-Whisper目前僅支持英語語音識別。我們正在與社區合作,對其他語言的Whisper模型進行蒸餾。如果您有興趣參與特定語言的蒸餾工作,請查看 訓練代碼。待多語言檢查點準備好後,我們會在 Distil-Whisper倉庫 中更新。
為什麼distil-small.en比distil-large-v2慢?
distil-medium.en 和 distil-large-v2 均使用兩層解碼器,而distil-small.en使用四層解碼器。增加解碼器層數可以提高模型的字錯率(WER)表現,但會降低推理速度。我們發現,對於 distil-small.en
,四層解碼器是獲得合理WER性能的最低要求,它在推理速度比Whisper large-v2 快5.6倍的同時,WER誤差控制在3%以內。當我們嘗試使用兩層解碼器進行蒸餾時,模型的WER比large-v2差5%以上,儘管推理速度快7.8倍。我們將蒸餾兩層的small.en模型作為未來的工作方向。
🚀 快速開始
Distil-Whisper從Hugging Face 🤗 Transformers 4.35版本開始得到支持。要運行該模型,首先需要安裝最新版本的Transformers庫。在本示例中,我們還將安裝 🤗 Datasets 庫,以便從Hugging Face Hub加載示例音頻數據集:
pip install --upgrade pip
pip install --upgrade transformers accelerate datasets[audio]
✨ 主要特性
- 速度快:相比原始的Whisper模型,Distil-Whisper推理速度快6倍。
- 模型小:模型大小縮小49%,適合內存受限的應用場景。
- 準確率高:在分佈外評估集上的字錯率(WER)誤差控制在1%以內。
- 支持多種解碼方式:支持短文本轉錄、長文本轉錄和推測解碼等多種解碼方式。
- 支持多種加速方法:支持Flash Attention、Torch Scale-Product-Attention (SDPA) 等加速方法。
- 支持多種框架:支持Hugging Face 🤗 Transformers、openai-whisper、Transformers.js等多種框架。
📦 安裝指南
安裝依賴庫
pip install --upgrade pip
pip install --upgrade transformers accelerate datasets[audio]
安裝其他可選依賴庫
- Flash Attention:
pip install flash-attn --no-build-isolation
- optimum:
pip install --upgrade optimum
- openai-whisper:
pip install --upgrade openai-whisper
- Transformers.js:
npm i @xenova/transformers
💻 使用示例
基礎用法
短文本轉錄
模型可以使用 pipeline
類對短文本音頻文件(< 30秒)進行轉錄,示例代碼如下:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-small.en"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
如果要轉錄本地音頻文件,只需在調用pipeline時傳入音頻文件的路徑:
- result = pipe(sample)
+ result = pipe("audio.mp3")
長文本轉錄
Distil-Whisper使用分塊算法對長文本音頻文件(> 30秒)進行轉錄。實際上,這種分塊長文本算法比OpenAI在Whisper論文中提出的順序算法快9倍(詳見 Distil-Whisper論文 的表7)。
要啟用分塊功能,只需在調用pipeline時傳入 chunk_length_s
參數。對於Distil-Whisper,15秒的分塊長度是最優的。要啟用批量處理,只需傳入 batch_size
參數:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-small.en"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=15,
batch_size=16,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "default", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
高級用法
推測解碼
Distil-Whisper可以作為Whisper的輔助模型,用於 推測解碼。推測解碼在數學上保證了與Whisper相同的輸出結果,同時推理速度快2倍。這使得它可以完美替代現有的Whisper推理流程,因為可以保證輸出結果一致。
在以下代碼示例中,我們將Distil-Whisper輔助模型獨立加載到主Whisper推理流程中,並將其指定為生成過程中的“輔助模型”:
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
assistant_model_id = "distil-whisper/distil-small.en"
assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)
model_id = "openai/whisper-medium.en"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
generate_kwargs={"assistant_model": assistant_model},
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
📚 詳細文檔
額外的速度和內存優化方法
Flash Attention
如果您的GPU支持,我們建議使用 Flash-Attention 2。要使用Flash Attention,首先需要安裝 Flash Attention:
pip install flash-attn --no-build-isolation
然後,在調用 from_pretrained
方法時傳入 use_flash_attention_2=True
參數:
- model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, use_flash_attention_2=True)
Torch Scale-Product-Attention (SDPA)
如果您的GPU不支持Flash Attention,我們建議使用 BetterTransformers。要使用BetterTransformers,首先需要安裝optimum:
pip install --upgrade optimum
然後,在使用模型之前,將其轉換為“BetterTransformer”模型:
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = model.to_bettertransformer()
在 openai-whisper
中運行Distil-Whisper
要在原始的Whisper格式中使用該模型,首先需要確保您已經安裝了 openai-whisper
包:
pip install --upgrade openai-whisper
以下代碼示例展示瞭如何使用 🤗 Datasets 庫加載LibriSpeech數據集中的示例文件並進行轉錄:
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from whisper import load_model, transcribe
distil_small_en = hf_hub_download(repo_id="distil-whisper/distil-small.en", filename="original-model.bin")
model = load_model(distil_small_en)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]["array"]
sample = torch.from_numpy(sample).float()
pred_out = transcribe(model, audio=sample)
print(pred_out["text"])
請注意,首次運行該示例時,模型權重將被下載並保存到緩存中。後續再次運行相同示例時,權重將直接從緩存中加載,無需再次下載。
如果要轉錄本地音頻文件,只需在調用 transcribe
方法時傳入音頻文件的路徑作為 audio
參數:
pred_out = transcribe(model, audio="audio.mp3")
在Transformers.js中運行Distil-Whisper
Distil-Whisper甚至可以使用 Transformers.js 在瀏覽器中完全運行:
- 從 NPM 安裝Transformers.js:
npm i @xenova/transformers
- 導入庫並使用pipeline API進行推理:
import { pipeline } from '@xenova/transformers';
const transcriber = await pipeline('automatic-speech-recognition', 'distil-whisper/distil-small.en');
const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
const output = await transcriber(url);
// { text: " And so my fellow Americans, ask not what your country can do for you. Ask what you can do for your country." }
您可以訪問在線 Distil-Whisper Web演示 親自體驗。您會發現,它可以在本地瀏覽器中運行,無需服務器支持!
更多信息請參考 文檔。
8bit和4bit量化
該功能即將推出!
🔧 技術細節
模型架構
Distil-Whisper繼承了Whisper的編碼器 - 解碼器架構。編碼器將語音向量輸入序列映射到隱藏狀態向量序列,解碼器根據所有先前的標記和編碼器的隱藏狀態自迴歸地預測文本標記。因此,編碼器只需要進行一次前向傳播,而解碼器的運行次數與生成的標記數量相同。實際上,這意味著解碼器在總推理時間中佔比超過90%。因此,為了優化推理延遲,我們的重點是最小化解碼器的推理時間。
為了對Whisper模型進行蒸餾,我們在保持編碼器不變的情況下減少了解碼器的層數。編碼器(綠色部分)完全從教師模型複製到學生模型,並在訓練過程中凍結。學生模型的解碼器由教師模型解碼器層的子集組成,這些層從最大間隔的層初始化。然後,模型在KL散度和偽標籤損失項的加權和上進行訓練。
評估方法
以下代碼示例展示瞭如何使用 流式模式 在LibriSpeech驗證集的clean子集上評估Distil-Whisper模型,這意味著無需將音頻數據下載到本地設備:
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
from datasets import load_dataset
from evaluate import load
import torch
from tqdm import tqdm
# define our torch configuration
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-small.en"
# load the model + processor
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, use_safetensors=True, low_cpu_mem_usage=True)
model = model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
# load the dataset with streaming mode
dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
# define the evaluation metric
wer_metric = load("wer")
normalizer = EnglishTextNormalizer(processor.tokenizer.english_spelling_normalizer)
def inference(batch):
# 1. Pre-process the audio data to log-mel spectrogram inputs
audio = [sample["array"] for sample in batch["audio"]]
input_features = processor(audio, sampling_rate=batch["audio"][0]["sampling_rate"], return_tensors="pt").input_features
input_features = input_features.to(device, dtype=torch_dtype)
# 2. Auto-regressively generate the predicted token ids
pred_ids = model.generate(input_features, max_new_tokens=128)
# 3. Decode the token ids to the final transcription
batch["transcription"] = processor.batch_decode(pred_ids, skip_special_tokens=True)
batch["reference"] = batch["text"]
return batch
dataset = dataset.map(function=inference, batched=True, batch_size=16)
all_transcriptions = []
all_references = []
# iterate over the dataset and run inference
for i, result in tqdm(enumerate(dataset), desc="Evaluating..."):
all_transcriptions.append(result["transcription"])
all_references.append(result["reference"])
# normalize predictions and references
all_transcriptions = [normalizer(transcription) for transcription in all_transcriptions]
all_references = [normalizer(reference) for reference in all_references]
# compute the WER metric
wer = 100 * wer_metric.compute(predictions=all_transcriptions, references=all_references)
print(wer)
輸出結果:
3.4326070294536297
預期用途
Distil-Whisper旨在作為Whisper在英語語音識別任務中的直接替代品。特別是,它在分佈外測試數據上的字錯率(WER)表現與Whisper相當,同時在短文本和長文本音頻上的推理速度快6倍。
訓練數據
Distil-Whisper在Hugging Face Hub上的9個開源、許可寬鬆的語音數據集的22000小時音頻數據上進行訓練:
數據集 | 時長 / h | 說話人數量 | 領域 | 許可證 |
---|---|---|---|---|
People's Speech | 12,000 | 未知 | Internet Archive | CC-BY-SA-4.0 |
Common Voice 13 | 3,000 | 未知 | Narrated Wikipedia | CC0-1.0 |
GigaSpeech | 2,500 | 未知 | Audiobook, podcast, YouTube | apache-2.0 |
Fisher | 1,960 | 11,900 | Telephone conversations | LDC |
LibriSpeech | 960 | 2,480 | Audiobooks | CC-BY-4.0 |
VoxPopuli | 540 | 1,310 | European Parliament | CC0 |
TED-LIUM | 450 | 2,030 | TED talks | CC-BY-NC-ND 3.0 |
SwitchBoard | 260 | 540 | Telephone conversations | LDC |
AMI | 100 | 未知 | Meetings | CC-BY-4.0 |
總計 | 21,770 | 18,260+ |
這些數據集涵蓋了10個不同的領域和超過50000名說話人。數據集的多樣性對於確保蒸餾後的模型對音頻分佈和噪聲具有魯棒性至關重要。
然後,我們使用Whisper large-v2模型對音頻數據進行偽標籤標註:我們使用Whisper為訓練集中的所有音頻生成預測結果,並在訓練過程中使用這些結果作為目標標籤。使用偽標籤可以確保轉錄結果在不同數據集之間保持一致的格式,並在訓練過程中提供序列級別的蒸餾信號。
WER過濾
Whisper的偽標籤預測可能會出現轉錄錯誤和幻覺問題。為了確保我們只在準確的偽標籤上進行訓練,我們在訓練過程中採用了一種簡單的WER啟發式方法。首先,我們對Whisper的偽標籤和每個數據集提供的真實標籤進行歸一化處理。然後,我們計算這些標籤之間的WER。如果WER超過指定的閾值,我們將丟棄該訓練示例;否則,我們將其保留用於訓練。
Distil-Whisper論文 的第9.2節展示了這種過濾方法對於提高蒸餾模型下游性能的有效性。我們還將Distil-Whisper對幻覺問題的魯棒性部分歸因於這種過濾方法。
訓練過程
模型在批量大小為2056的情況下進行了50000次優化步驟(或12個epoch)的訓練。Tensorboard訓練日誌可以在以下鏈接中找到:https://huggingface.co/distil-whisper/distil-small.en/tensorboard?params=scalars#frame
評估結果
蒸餾後的模型在分佈外(OOD)短文本音頻上的字錯率(WER)與Whisper相差在1%以內,在OOD長文本音頻上的表現比Whisper好0.1%。這種性能提升歸因於較低的幻覺率。
有關每個數據集評估結果的詳細細分,請參考 Distil-Whisper論文 的表16和表17。
Distil-Whisper還在 ESB基準 數據集上進行了評估,作為 OpenASR排行榜 的一部分,其WER與Whisper相差在0.2%以內。
復現Distil-Whisper
復現Distil-Whisper的訓練和評估代碼可以在Distil-Whisper倉庫中找到:https://github.com/huggingface/distil-whisper/tree/main/training
📄 許可證
Distil-Whisper繼承了OpenAI的Whisper模型的 MIT許可證。
引用
如果您使用了該模型,請考慮引用 Distil-Whisper論文:
@misc{gandhi2023distilwhisper,
title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
year={2023},
eprint={2311.00430},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
致謝
- OpenAI提供了Whisper 模型 和 原始代碼庫。
- Hugging Face 🤗 Transformers 提供了模型集成支持。
- Google的 TPU Research Cloud (TRC) 項目提供了Cloud TPU v4資源。
@rsonavane
在LibriSpeech數據集上發佈了早期版本的Distil-Whisper。



