模型概述
模型特點
模型能力
使用案例
🚀 Distil-Whisper: distil-large-v3
Distil-Whisper是一個用於醫學語音識別的微調工作空間。該模型會經常更新,若你覺得當前版本對你有用,可直接複製該工作空間。Distil-Whisper在論文 Robust Knowledge Distillation via Large-Scale Pseudo Labelling 中被提出,本版本distil-large-v3是Distil-Whisper英語系列的第三個也是最後一個版本,它是OpenAI的 Whisper large-v3 的知識蒸餾版本,是目前最新且性能最佳的Whisper模型。
🚀 快速開始
本項目是Distil-Whisper的distil-large-v3版本,用於自動語音識別任務。以下是使用該模型的基本步驟:
- 安裝必要的庫,如
transformers
、datasets
等。 - 加載模型和處理器。
- 準備音頻數據。
- 進行語音識別推理。
以下是一個簡單的示例代碼:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
✨ 主要特性
- 長文本轉錄準確性高:與之前的Distil-Whisper模型相比,distil-large-v3的蒸餾過程經過調整,使用OpenAI的順序長文本算法時,具有卓越的長文本轉錄準確性。
- 速度更快:比之前的Distil-Whisper模型更快,比large-v3快6.3倍,比distil-large-v2快1.1倍。
- 兼容性強:與最流行的Whisper庫(Whisper cpp、Faster-Whisper、OpenAI Whisper)兼容,使用這些庫時,從之前的Distil-Whisper檢查點切換到distil-large-v3可獲得顯著的性能提升。
- 支持多種算法:支持短文本轉錄、順序長文本轉錄、分塊長文本轉錄和推測解碼等多種算法。
模型性能對比
模型 | 參數數量(百萬) | 相對延遲 | 短文本 | 順序長文本 | 分塊長文本 |
---|---|---|---|---|---|
large-v3 | 1550 | 1.0 | 8.4 | 10.0 | 11.0 |
distil-large-v3 | 756 | 6.3 | 9.7 | 10.8 | 10.9 |
distil-large-v2 | 756 | 5.8 | 10.1 | 15.6 | 11.6 |
📦 安裝指南
安裝依賴庫
pip install --upgrade pip
pip install --upgrade transformers accelerate datasets[audio]
安裝特定庫以支持不同使用場景
Whisper.cpp
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp
pip install --upgrade huggingface_hub
下載GGML權重:
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id='distil-whisper/distil-large-v3-ggml', filename='ggml-distil-large-v3.bin', local_dir='./models')
或使用 wget
下載:
wget https://huggingface.co/distil-whisper/distil-large-v3-ggml/resolve/main/ggml-distil-large-v3.bin -P ./models
Faster-Whisper
pip install --upgrade pip
pip install --upgrade git+https://github.com/SYSTRAN/faster-whisper datasets[audio]
OpenAI Whisper
pip install --upgrade pip
pip install --upgrade openai-whisper datasets[audio]
Transformers.js
npm i @xenova/transformers
Candle
- 安裝
candle-core
:按照 這裡 的說明安裝。 - 克隆
candle
倉庫:
git clone https://github.com/huggingface/candle.git
- 進入Whisper示例目錄:
cd candle/candle-examples/examples/whisper
💻 使用示例
基礎用法
短文本轉錄
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
轉錄本地音頻文件:
- result = pipe(sample)
+ result = pipe("audio.mp3")
獲取分段級時間戳:
result = pipe(sample, return_timestamps=True)
print(result["chunks"])
順序長文本轉錄
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
分塊長文本轉錄
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=25,
batch_size=16,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
高級用法
推測解碼
from transformers import pipeline, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
assistant_model_id = "distil-whisper/distil-large-v3"
assistant_model = AutoModelForCausalLM.from_pretrained(
assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
generate_kwargs={"assistant_model": assistant_model},
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
額外的速度和內存優化
Flash Attention 2
- model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="flash_attention_2")
Torch Scale-Product-Attention (SDPA)
- model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True, attn_implementation="sdpa")
📚 詳細文檔
模型細節
Distil-Whisper繼承了Whisper的編碼器 - 解碼器架構。編碼器將語音向量輸入序列映射到隱藏狀態向量序列,解碼器根據所有先前的令牌和編碼器隱藏狀態自迴歸地預測文本令牌。因此,編碼器只向前運行一次,而解碼器運行的次數與生成的令牌數量相同。在實踐中,這意味著解碼器佔總推理時間的90%以上。為了優化延遲,重點是最小化解碼器的推理時間。
為了蒸餾Whisper模型,我們在保持編碼器固定的同時減少了解碼器層的數量。編碼器(綠色部分)完全從教師模型複製到學生模型,並在訓練期間凍結。學生模型的解碼器由教師解碼器層的一個子集組成,這些層從最大間隔的層初始化。然後,模型在KL散度和偽標籤損失項的加權和上進行訓練。
與distil-large-v2的差異
與之前的Distil-Whisper版本相比,distil-large-v3專門針對OpenAI順序長文本轉錄算法進行了設計。與distil-large-v2相比,除了模型層是從最新的large-v3模型而不是較舊的large-v2模型初始化之外,沒有架構上的差異。差異在於模型的訓練方式。
之前的Distil-Whisper模型在平均輸入長度為7秒的情況下進行訓練,而原始的Whisper模型在30秒的輸入上進行預訓練。在蒸餾過程中,我們將模型權重的分佈轉移到訓練數據的分佈上。如果我們的訓練數據包含較短的話語(例如,平均7秒的音頻而不是30秒),那麼預測分佈將轉移到這個較短的上下文長度。在推理時,distil-large-v2的最佳上下文窗口是這兩個值的插值:15秒。超過這個時間,distil-large-v2模型的預測在很大程度上是不準確的,特別是對於時間戳預測。然而,順序長文本算法使用30秒的滑動窗口進行推理,窗口根據最後預測的時間戳進行移動。由於最後一個時間戳通常發生在15秒標記之後,其預測的準確性較低,導致長文本轉錄經常失敗。
為了保留Whisper轉錄30秒滑動窗口的能力,就像順序解碼那樣,我們需要確保distil-large-v3的上下文長度也是30秒。這主要通過以下四種策略實現:
- 將訓練數據集中的音頻樣本打包到30秒:由於模型在打包到30秒的音頻數據上進行預訓練和蒸餾,distil-large-v3現在在與Whisper相同的理想上下文窗口上運行,能夠準確預測長達30秒的時間戳。
- 凍結解碼器輸入嵌入:我們使用與原始模型相同的輸入嵌入表示,該表示旨在處理比之前的Distil-Whisper迭代更長的上下文長度。
- 在訓練期間使用更長的最大上下文長度:我們在最大目標長度為256的情況下進行訓練,而不是128。這有助於distil-large-v3轉錄可能超過128個令牌的30秒片段。
- 將提示條件附加到50%的訓練樣本上:使模型能夠與
condition_on_prev_tokens
參數一起使用,以及處理長達448個令牌的上下文窗口。
評估
以下代碼展示瞭如何在LibriSpeech驗證清潔數據集上使用 流式模式 評估Distil-Whisper模型,即無需將音頻數據下載到本地設備。
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from datasets import load_dataset
from evaluate import load
import torch
from tqdm import tqdm
# define our torch configuration
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3"
# load the model + processor
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, use_safetensors=True, low_cpu_mem_usage=True)
model = model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
# load the dataset with streaming mode
dataset = load_dataset("librispeech_asr", "clean", split="validation", streaming=True)
# define the evaluation metric
wer_metric = load("wer")
def inference(batch):
# 1. Pre-process the audio data to log-mel spectrogram inputs
audio = [sample["array"] for sample in batch["audio"]]
input_features = processor(audio, sampling_rate=batch["audio"][0]["sampling_rate"], return_tensors="pt").input_features
input_features = input_features.to(device, dtype=torch_dtype)
# 2. Auto-regressively generate the predicted token ids
pred_ids = model.generate(input_features, max_new_tokens=128)
# 3. Decode the token ids to the final transcription
batch["transcription"] = processor.batch_decode(pred_ids, skip_special_tokens=True)
batch["reference"] = batch["text"]
return batch
# batch size 16 inference
dataset = dataset.map(function=inference, batched=True, batch_size=16)
all_transcriptions = []
all_references = []
# iterate over the dataset and run inference
for result in tqdm(dataset, desc="Evaluating..."):
all_transcriptions.append(result["transcription"])
all_references.append(result["reference"])
# normalize predictions and references
all_transcriptions = [processor.normalize(transcription) for transcription in all_transcriptions]
all_references = [processor.normalize(reference) for reference in all_references]
# compute the WER metric
wer = 100 * wer_metric.compute(predictions=all_transcriptions, references=all_references)
print(wer)
預期用途
Distil-Whisper旨在作為Whisper large-v3在英語語音識別中的直接替代品。特別是,它在分佈外(OOD)測試數據上實現了可比的WER結果,同時在短文本和長文本音頻上快6倍。
數據
Distil-Whisper在來自Hugging Face Hub上九個開源、許可寬鬆的語音數據集的22,000小時音頻數據上進行訓練:
數據集 | 大小(小時) | 說話者數量 | 領域 | 許可證 |
---|---|---|---|---|
People's Speech | 12,000 | 未知 | 互聯網存檔 | CC-BY-SA-4.0 |
Common Voice 13 | 3,000 | 未知 | 旁白維基百科 | CC0-1.0 |
GigaSpeech | 2,500 | 未知 | 有聲讀物、播客、YouTube | apache-2.0 |
Fisher | 1,960 | 11,900 | 電話對話 | LDC |
LibriSpeech | 960 | 2,480 | 有聲讀物 | CC-BY-4.0 |
VoxPopuli | 540 | 1,310 | 歐洲議會 | CC0 |
TED-LIUM | 450 | 2,030 | TED演講 | CC-BY-NC-ND 3.0 |
SwitchBoard | 260 | 540 | 電話對話 | LDC |
AMI | 100 | 未知 | 會議 | CC-BY-4.0 |
總計 | 21,770 | 18,260+ |
組合數據集涵蓋10個不同領域和超過50k個說話者。這種數據集的多樣性對於確保蒸餾模型對音頻分佈和噪聲具有魯棒性至關重要。
音頻數據然後使用Whisper large-v3模型進行偽標籤:我們使用Whisper為訓練集中的所有音頻生成預測,並在訓練期間將這些預測用作目標標籤。使用偽標籤確保轉錄在數據集之間具有一致的格式,並在訓練期間提供序列級蒸餾信號。
WER過濾
Whisper偽標籤預測可能會出現誤轉錄和幻覺。為了確保我們只在準確的偽標籤上進行訓練,我們在訓練期間採用了一種簡單的WER啟發式方法。首先,我們對Whisper偽標籤和每個數據集提供的真實標籤進行歸一化。然後,我們計算這些標籤之間的WER。如果WER超過指定的閾值,我們丟棄該訓練示例。否則,我們將其保留用於訓練。
Distil-Whisper論文 的第9.2節展示了這種過濾方法對於提高蒸餾模型下游性能的有效性。我們還將Distil-Whisper對幻覺的魯棒性部分歸因於這種過濾方法。
訓練
模型在批量大小為256的情況下進行了80,000次優化步驟(或11個epoch)的訓練。Tensorboard訓練日誌可在 這裡 找到。
結果
蒸餾模型在分佈外(OOD)短文本音頻上的WER與Whisper large-v3相差在1.5%以內,在順序長文本解碼上相差在1%以內,在分塊長文本上比large-v3高0.1%。這種性能提升歸因於較低的幻覺率。
有關評估結果的詳細數據集細分,請參考 Distil-Whisper論文 的表16和表17。
Distil-Whisper還在 ESB基準 數據集上進行了評估,作為 OpenASR排行榜 的一部分,其WER與Whisper相差在0.2%以內。
復現Distil-Whisper
復現Distil-Whisper的訓練和評估代碼可在 Distil-Whisper倉庫 中找到。該代碼將很快更新,以包括 與distil-large-v2的差異 部分中描述的訓練更新。
🔧 技術細節
模型架構
Distil-Whisper繼承了Whisper的編碼器 - 解碼器架構。編碼器將語音向量輸入序列映射到隱藏狀態向量序列,解碼器根據所有先前的令牌和編碼器隱藏狀態自迴歸地預測文本令牌。
蒸餾方法
在蒸餾過程中,保持編碼器固定,減少解碼器層的數量。編碼器完全從教師模型複製到學生模型,並在訓練期間凍結。學生模型的解碼器由教師解碼器層的一個子集組成,這些層從最大間隔的層初始化。
訓練數據處理
使用Whisper large-v3模型為訓練數據生成偽標籤,確保轉錄在數據集之間具有一致的格式,並在訓練期間提供序列級蒸餾信號。同時,採用WER過濾方法,丟棄WER超過指定閾值的訓練示例,以確保只在準確的偽標籤上進行訓練。
📄 許可證
Distil-Whisper繼承了OpenAI的Whisper模型的 MIT許可證。
引用
如果您使用此模型,請考慮引用 Distil-Whisper論文:
@misc{gandhi2023distilwhisper,
title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
year={2023},
eprint={2311.00430},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
致謝
- OpenAI提供了Whisper 模型,特別是Jong Wook Kim提供了 原始代碼庫 並參與了訓練討論。
- Hugging Face 🤗 Transformers 實現了模型集成。
- Georgi Gerganov 實現了Whisper cpp集成。
- Systran團隊 實現了Faster-Whisper集成。
- Joshua Lochner 實現了Transformers.js集成。
- Laurent Mazare 實現了Candle集成。
- Vaibhav Srivastav 負責Distil-Whisper的分發。
- Google的 TPU Research Cloud (TRC) 計劃提供了Cloud TPU v4計算資源。
- Raghav Sonavane 提供了Distil-Whisper在LibriSpeech數據集上的早期版本。



