模型概述
模型特點
模型能力
使用案例
🚀 Distil-Whisper: distil-medium.en
Distil-Whisper是一個用於英語自動語音識別的模型。它是Whisper模型的蒸餾版本,速度提升了6倍,模型大小減小了49%,並且在分佈外評估集上的字錯率(WER)與原模型相差在1%以內。
🚀 快速開始
Distil-Whisper從Hugging Face 🤗 Transformers的4.35版本開始得到支持。要運行該模型,首先需要安裝最新版本的Transformers庫。在這個示例中,我們還將安裝🤗 Datasets,以便從Hugging Face Hub加載玩具音頻數據集:
pip install --upgrade pip
pip install --upgrade transformers accelerate datasets[audio]
✨ 主要特性
- 速度快:相比原Whisper模型,速度提升了6倍。
- 模型小:模型大小減小了49%。
- 性能接近:在分佈外評估集上的字錯率(WER)與原模型相差在1%以內。
- 支持多種使用方式:支持短音頻轉錄、長音頻轉錄、推測解碼等。
- 可進行額外優化:可以通過Flash Attention、Torch Scale-Product-Attention等方式進一步提升速度和內存使用效率。
💻 使用示例
基礎用法
短音頻轉錄
模型可以使用pipeline
類對短音頻文件(< 30秒)進行轉錄,示例代碼如下:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-medium.en"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
要轉錄本地音頻文件,只需在調用pipeline時傳入音頻文件的路徑:
- result = pipe(sample)
+ result = pipe("audio.mp3")
長音頻轉錄
Distil-Whisper使用分塊算法對長音頻文件(> 30秒)進行轉錄。實際上,這種分塊長音頻算法比OpenAI在Whisper論文中提出的順序算法快9倍(見Distil-Whisper論文的表7)。
要啟用分塊功能,需要在pipeline
中傳入chunk_length_s
參數。對於Distil-Whisper,15秒的分塊長度是最優的。要啟用批處理,需要傳入batch_size
參數:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-medium.en"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=15,
batch_size=16,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "default", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
推測解碼
Distil-Whisper可以作為Whisper的輔助模型用於推測解碼。推測解碼在數學上保證了與Whisper相同的輸出,同時速度提高了2倍。這使得它成為現有Whisper管道的完美替代品,因為可以保證相同的輸出。 在以下代碼片段中,我們將輔助Distil-Whisper模型獨立加載到主Whisper管道中。然後,我們將其指定為生成的“輔助模型”:
from transformers import pipeline, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
assistant_model_id = "distil-whisper/distil-medium.en"
assistant_model = AutoModelForCausalLM.from_pretrained(
assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)
model_id = "openai/whisper-medium.en"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
generate_kwargs={"assistant_model": assistant_model},
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
高級用法
額外的速度和內存優化
你可以對Distil-Whisper應用額外的速度和內存優化,具體如下:
Flash Attention
如果你的GPU支持,我們建議使用Flash-Attention 2。為此,你首先需要安裝Flash Attention:
pip install flash-attn --no-build-isolation
然後,你只需要在from_pretrained
中傳入use_flash_attention_2=True
:
- model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, use_flash_attention_2=True)
Torch Scale-Product-Attention (SDPA)
如果你的GPU不支持Flash Attention,我們建議使用BetterTransformers。為此,你首先需要安裝optimum:
pip install --upgrade optimum
然後,在使用模型之前將其轉換為“BetterTransformer”模型:
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = model.to_bettertransformer()
在openai-whisper
中運行Distil-Whisper
要以原始Whisper格式使用該模型,首先要確保你已經安裝了openai-whisper
包:
pip install --upgrade openai-whisper
以下代碼片段展示瞭如何轉錄使用🤗 Datasets加載的LibriSpeech數據集中的樣本文件:
import torch
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from whisper import load_model, transcribe
medium_en = hf_hub_download(repo_id="distil-whisper/distil-medium.en", filename="original-model.bin")
model = load_model(medium_en)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]["array"]
sample = torch.from_numpy(sample).float()
pred_out = transcribe(model, audio=sample)
print(pred_out["text"])
要轉錄本地音頻文件,只需將音頻文件的路徑作為audio
參數傳遞給transcribe
:
pred_out = transcribe(model, audio="audio.mp3")
Whisper.cpp
Distil-Whisper可以使用Whisper.cpp倉庫中的原始順序長音頻轉錄算法運行。在Mac M1上的臨時基準測試中,distil-medium.en
比large-v2
快4倍,同時在長音頻上的字錯率(WER)相差在1%以內。
開始使用的步驟如下:
- 克隆Whisper.cpp倉庫:
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp
- 從Hugging Face Hub下載
distil-medium.en
的ggml權重:
python -c "from huggingface_hub import hf_hub_download; hf_hub_download(repo_id='distil-whisper/distil-medium.en', filename='ggml-medium-32-2.en.bin', local_dir='./models')"
如果你沒有安裝huggingface_hub
包,也可以使用wget
下載權重:
wget https://huggingface.co/distil-whisper/distil-medium.en/resolve/main/ggml-medium-32-2.en.bin -P ./models
- 使用提供的示例音頻運行推理:
make -j && ./main -m models/ggml-medium-32-2.en.bin -f samples/jfk.wav
Transformers.js
import { pipeline } from '@xenova/transformers';
let transcriber = await pipeline('automatic-speech-recognition', 'distil-whisper/distil-medium.en');
let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/jfk.wav';
let output = await transcriber(url);
// { text: " And so my fellow Americans, ask not what your country can do for you. Ask what you can do for your country." }
更多信息請參閱文檔。
Candle
通過與Hugging Face Candle 🕯️集成,Distil-Whisper現在可以在Rust庫🦀中使用。 優點如下:
- 優化的CPU後端,支持x86的MKL和Mac的Accelerate。
- CUDA後端,可在GPU上高效運行,支持通過NCCL進行多GPU分發。
- WASM支持:可以在瀏覽器中運行Distil-Whisper。 開始使用的步驟如下:
- 按照這裡的說明安裝
candle-core
。 - 本地克隆
candle
倉庫:
git clone https://github.com/huggingface/candle.git
- 進入Whisper的示例目錄:
cd candle/candle-examples/examples/whisper
- 運行示例:
cargo run --example whisper --release -- --model distil-medium.en
- 要指定自己的音頻文件,添加
--input
標誌:
cargo run --example whisper --release -- --model distil-medium.en --input audio.wav
🔧 技術細節
Distil-Whisper繼承了Whisper的編碼器 - 解碼器架構。編碼器將語音向量輸入序列映射到隱藏狀態向量序列。解碼器根據所有先前的標記和編碼器隱藏狀態自迴歸地預測文本標記。因此,編碼器只向前運行一次,而解碼器運行的次數與生成的標記數量相同。實際上,這意味著解碼器佔總推理時間的90%以上。因此,為了優化延遲,應該重點關注最小化解碼器的推理時間。 為了蒸餾Whisper模型,我們在保持編碼器固定的同時減少了解碼器層的數量。編碼器(綠色部分)完全從教師模型複製到學生模型,並在訓練期間凍結。學生模型的解碼器僅由兩個解碼器層組成,這些層從教師模型的第一個和最後一個解碼器層初始化(紅色部分)。教師模型的所有其他解碼器層都被丟棄。然後,模型在KL散度和偽標籤損失項的加權和上進行訓練。
📚 詳細文檔
評估
以下代碼片段展示瞭如何在LibriSpeech驗證集的clean
子集上使用流式模式評估Distil-Whisper模型,這意味著不需要將音頻數據下載到本地設備。
首先,我們需要安裝所需的包,包括🤗 Datasets用於流式加載音頻數據,以及🤗 Evaluate用於進行WER計算:
pip install --upgrade pip
pip install --upgrade transformers datasets[audio] evaluate jiwer
然後可以使用以下示例端到端地運行評估:
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
from datasets import load_dataset
from evaluate import load
import torch
from tqdm import tqdm
# define our torch configuration
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-medium.en"
# load the model + processor
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, use_safetensors=True, low_cpu_mem_usage=True)
model = model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
# load the dataset with streaming mode
dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
# define the evaluation metric
wer_metric = load("wer")
normalizer = EnglishTextNormalizer(processor.tokenizer.english_spelling_normalizer)
def inference(batch):
# 1. Pre-process the audio data to log-mel spectrogram inputs
audio = [sample["array"] for sample in batch["audio"]]
input_features = processor(audio, sampling_rate=batch["audio"][0]["sampling_rate"], return_tensors="pt").input_features
input_features = input_features.to(device, dtype=torch_dtype)
# 2. Auto-regressively generate the predicted token ids
pred_ids = model.generate(input_features, max_new_tokens=128)
# 3. Decode the token ids to the final transcription
batch["transcription"] = processor.batch_decode(pred_ids, skip_special_tokens=True)
batch["reference"] = batch["text"]
return batch
dataset = dataset.map(function=inference, batched=True, batch_size=16)
all_transcriptions = []
all_references = []
# iterate over the dataset and run inference
for i, result in tqdm(enumerate(dataset), desc="Evaluating..."):
all_transcriptions.append(result["transcription"])
all_references.append(result["reference"])
# normalize predictions and references
all_transcriptions = [normalizer(transcription) for transcription in all_transcriptions]
all_references = [normalizer(reference) for reference in all_references]
# compute the WER metric
wer = 100 * wer_metric.compute(predictions=all_transcriptions, references=all_references)
print(wer)
打印輸出:
3.593196832001168
預期用途
Distil-Whisper旨在作為Whisper在英語語音識別中的直接替代品。特別是,它在分佈外測試數據上實現了相當的WER結果,同時在短音頻和長音頻上都快6倍。
數據
Distil-Whisper在Hugging Face Hub上的9個開源、許可寬鬆的語音數據集的22,000小時音頻數據上進行訓練:
屬性 | 詳情 |
---|---|
模型類型 | Distil-Whisper是Whisper模型的蒸餾版本,繼承了編碼器 - 解碼器架構。 |
訓練數據 | 訓練數據來自9個開源、許可寬鬆的語音數據集,包括People's Speech、Common Voice 13、GigaSpeech等,總時長約21,770小時,涵蓋10個不同領域和超過50k個說話者。 |
WER過濾
Whisper的偽標籤預測可能存在誤轉錄和幻覺問題。為了確保我們只在準確的偽標籤上進行訓練,我們在訓練期間採用了簡單的WER啟發式方法。首先,我們對Whisper的偽標籤和每個數據集提供的真實標籤進行歸一化。然後,我們計算這些標籤之間的WER。如果WER超過指定的閾值,我們將丟棄該訓練示例。否則,我們將其保留用於訓練。 Distil-Whisper論文的第9.2節展示了該過濾器對提高蒸餾模型下游性能的有效性。我們還部分將Distil-Whisper對幻覺的魯棒性歸因於該過濾器。
訓練
模型進行了80,000次優化步驟(即8個epoch)的訓練。Tensorboard訓練日誌可以在以下鏈接找到:https://huggingface.co/distil-whisper/distil-medium.en/tensorboard?params=scalars#frame
結果
蒸餾後的模型在分佈外(OOD)短音頻上的WER與Whisper相差在1%以內,在OOD長音頻上比Whisper高0.1%。這種性能提升歸因於較低的幻覺。 有關每個數據集評估結果的詳細細分,請參閱Distil-Whisper論文的表16和表17。 Distil-Whisper還在ESB基準數據集上進行了評估,作為OpenASR排行榜的一部分,其WER與Whisper相差在0.2%以內。
復現Distil-Whisper
復現Distil-Whisper的訓練和評估代碼可以在Distil-Whisper倉庫中找到:https://github.com/huggingface/distil-whisper/tree/main/training
📄 許可證
Distil-Whisper繼承了OpenAI的Whisper模型的MIT許可證。
引用
如果你使用了這個模型,請考慮引用Distil-Whisper論文:
@misc{gandhi2023distilwhisper,
title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
year={2023},
eprint={2311.00430},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
致謝
- OpenAI提供了Whisper 模型和原始代碼庫。
- Hugging Face 🤗 Transformers實現了模型集成。
- Google的TPU研究雲(TRC)計劃提供了Cloud TPU v4s。
@rsonavane
在LibriSpeech數據集上發佈了Distil-Whisper的早期版本。



