模型概述
模型特點
模型能力
使用案例
🚀 Distil-Whisper: Distil-Large-v3.5
Distil-Whisper是OpenAI的Whisper-Large-v3的知識蒸餾版本,相關內容在論文Robust Knowledge Distillation via Large-Scale Pseudo Labelling中有所描述。作為Distil-Whisper英語系列的最新成員,Distil-Large-v3.5在保持前代高效性的同時,展現出了更優的性能。
與早期模型相比,Distil-Large-v3.5在超過4倍(98000小時)更多樣化的公共數據上進行了訓練,並且在蒸餾過程中使用了具有擴展訓練計劃和激進數據增強(SpecAugment)的"耐心"教師。這使得它與之前的Distil-Whisper模型相比,魯棒性和準確性都得到了增強,適合作為直接替代品。
模型 | 參數數量(百萬) | 相對即時因子 | 短文本離群詞錯誤率 | 長文本離群詞錯誤率 |
---|---|---|---|---|
large-v3-turbo | 809 | 1.0 | 7.30 | 10.25 |
distil-large-v3 | 756 | 1.44 | 7.53 | 11.6 |
distil-large-v3.5 | 756 | 1.46 | 7.08 | 11.39 |
既然已經有了Whisper-Large-v3-Turbo,為什麼還要考慮Distil-Large-v3.5呢?
- 它在準確性和效率之間提供了不同的平衡,比Whisper-Large-v3-Turbo 快約1.5倍,同時在短文本轉錄上表現略好,在長文本轉錄上落後約1%。
- 它可以完美地作為與Whisper-Large-v3進行推測解碼的草稿模型。通過在訓練期間凍結編碼器,我們只需要加載兩個額外的解碼器層,並僅對編碼器進行一次前向傳播。這使得推理速度比Whisper-Large-v3快約2倍,同時保持輸出結果相同。
該模型是Bofeng Huang、Eustache Le Bihan、Steven Zheng和Vaibhav Srivastav在🤗上的合作成果。
🚀 快速開始
Distil-Large-v3.5從Hugging Face 🤗 Transformers庫的4.39版本開始得到支持。要運行該模型,首先需要安裝最新版本的Transformers。在這個示例中,我們還將安裝🤗 Datasets,以便從Hugging Face Hub加載一個玩具音頻數據集:
pip install --upgrade pip
pip install --upgrade transformers accelerate datasets[audio]
✨ 主要特性
- 知識蒸餾版本,在保持高效性的同時提升性能。
- 訓練數據更多樣化,魯棒性和準確性增強。
- 在準確性和效率之間提供不同平衡。
- 可作為推測解碼的草稿模型,提升推理速度。
- 支持多種庫集成,方便不同場景使用。
📦 安裝指南
安裝Transformers及相關庫
pip install --upgrade pip
pip install --upgrade transformers accelerate datasets[audio]
不同庫集成的安裝步驟
Whisper.cpp
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp
make
pip install --upgrade huggingface_hub
Faster-Whisper
pip install --upgrade pip
pip install --upgrade git+https://github.com/SYSTRAN/faster-whisper datasets[audio]
OpenAI Whisper
pip install --upgrade pip
pip install --upgrade openai-whisper datasets[audio]
Candle
git clone https://github.com/huggingface/candle.git
cd candle/candle-examples/examples/whisper
💻 使用示例
基礎用法
短文本轉錄
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3.5"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
轉錄本地音頻文件
- result = pipe(sample)
+ result = pipe("audio.mp3")
獲取片段級時間戳
result = pipe(sample, return_timestamps=True)
print(result["chunks"])
高級用法
順序長文本轉錄
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3.5"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
分塊長文本轉錄
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "distil-whisper/distil-large-v3.5"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
chunk_length_s=25,
batch_size=16,
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
推測解碼
from transformers import pipeline, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
assistant_model_id = "distil-whisper/distil-large-v3.5"
assistant_model = AutoModelForCausalLM.from_pretrained(
assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
max_new_tokens=128,
generate_kwargs={"assistant_model": assistant_model},
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
📚 詳細文檔
性能評估
短文本評估
模型在5個分佈內(ID)測試集和2個分佈外(OOD)測試集上進行了短文本轉錄評估,結果如下:
數據集 | 大小(小時) | large-v3 | large-v3-turbo | distil-v3 | distil-v3.5 |
---|---|---|---|---|---|
AMI | 8.68 | 15.95 | 16.13 | 15.16 | 14.63 |
Gigaspeech | 35.36 | 10.02 | 10.14 | 10.08 | 9.84 |
LS Clean | 5.40 | 2.01 | 2.10 | 2.54 | 2.37 |
LS Other | 5.34 | 3.91 | 4.24 | 5.19 | 5.04 |
Tedlium | 2.61 | 3.86 | 3.57 | 3.86 | 3.64 |
----------- | ----- | ----- | ----- | ----- | ----- |
Earnings22 | 5.43 | 11.29 | 11.63 | 11.79 | 11.29 |
SPGISpeech | 100.00 | 2.94 | 2.97 | 3.27 | 2.87 |
----------- | ----- | ----- | ----- | ----- | ----- |
ID平均 | 7.15 | 7.24 | 7.37 | 7.10 | |
OOD平均 | 7.12 | 7.30 | 7.53 | 7.08 | |
總平均 | 7.14 | 7.25 | 7.41 | 7.10 |
長文本評估
模型在1個分佈內(ID)測試集和4個分佈外(OOD)測試集上進行了長文本轉錄評估,使用順序解碼算法(condition_on_prev_tokens=False, return_timestamps=True),結果如下:
數據集 | 大小(小時) | large-v3-turbo | distil-v2 | distil-v3 | distil-v3.5 |
---|---|---|---|---|---|
tedlium-long-form | 2.47 | 3.07 | 9.66 | 3.9 | 4.63 |
----------------- | ----- | ----- | ----- | ----- | ----- |
meanwhile | 1.01 | 5.03 | 16.75 | 7.04 | 6.79 |
earnings21 | 39.26 | 9.84 | 15.09 | 10.54 | 10.6 |
earnings22 | 119.89 | 13.32 | 19.11 | 15.06 | 14.19 |
rev16 | 16.16 | 12.82 | 21.15 | 13.76 | 13.98 |
----------------- | ----- | ----- | ----- | ----- | ----- |
ID平均 | 3.07 | 9.66 | 3.9 | 4.63 | |
OOD平均 | 10.25 | 18.03 | 11.6 | 11.39 | |
總平均 | 8.82 | 16.35 | 10.06 | 10.04 |
不同算法使用場景
- 順序長文本算法:適用於轉錄準確性是最重要因素,且對延遲考慮較少的場景;或者轉錄批量長音頻文件的場景,此時順序算法的延遲與分塊算法相當,但詞錯誤率可提高達0.5%。
- 分塊長文本算法:適用於轉錄單個大音頻文件且需要最快推理速度的場景,比分塊算法快達9倍。
不同庫集成使用說明
Whisper.cpp
git clone https://github.com/ggerganov/whisper.cpp.git
cd whisper.cpp
make
pip install --upgrade huggingface_hub
Faster-Whisper
pip install --upgrade pip
pip install --upgrade git+https://github.com/SYSTRAN/faster-whisper datasets[audio]
OpenAI Whisper
pip install --upgrade pip
pip install --upgrade openai-whisper datasets[audio]
Candle
git clone https://github.com/huggingface/candle.git
cd candle/candle-examples/examples/whisper
🔧 技術細節
訓練細節
Distil-Whisper繼承了Whisper的編碼器 - 解碼器架構。編碼器將語音向量輸入序列映射到隱藏狀態向量序列,解碼器根據所有先前的令牌和編碼器隱藏狀態自迴歸地預測文本令牌。因此,編碼器只進行一次前向傳播,而解碼器的運行次數與生成的令牌數量相同。在實踐中,這意味著解碼器佔總推理時間的90%以上。因此,為了優化延遲,重點在於最小化解碼器的推理時間。
Distil-Large-v3.5基於先前Distil-Whisper模型的技術,特別是Distil-Large-v2中引入的教師模型蒸餾訓練過程,以及Distil-Large-v3中用於改進順序解碼長文本轉錄的樣本打包方法。除了這些已有的方法,我們還進行了以下顯著改進:
- 大幅擴展訓練數據:將訓練數據集從之前版本的22000小時增加到超過98000小時的高質量公共數據,這主要得益於Yodas數據集,它提供了來自YouTube的多樣化內容。
- 實現了具有激進數據增強(SpecAugment)的"耐心"教師:這使得訓練計劃大幅擴展到80個週期,而之前版本僅為11個週期。即使在這個延長的時間段內,模型的評估損失仍在略微下降,表明模型仍在持續改進。
- 降低訓練數據中附加先前提示的概率:最初設置為50%,但發現模型在轉錄包含先前上下文的文本時存在困難,可能是由於解碼器大小的限制。隨後將其降低到20%,這在提高整體性能的同時,仍然起到了數據增強的作用。
- 增加批量大小和學習率:實現了4096個打包片段的更大批量大小,遠大於之前版本使用的256個。還測試了包括餘弦、wsd和scheduler-free optimizers在內的替代學習率調度器,但發現線性方法仍然表現最佳。
我們還修改了片段打包順序,以創建更具邏輯結構的打包片段,並採用了BPE dropout作為正則化方法。我們發現這對短文本轉錄性能略有下降,但對長文本內容的結果有所改善。
該模型在Jean Zay集群上使用64個H100 GPU進行訓練,整個訓練過程耗時三天。Tensorboard訓練日誌可在這裡找到。
訓練數據
我們最初從Common Voice、LibriSpeech、VoxPopuli、TED-LIUM、People's Speech、GigaSpeech、AMI等來源收集了超過196000小時的公共數據,特別是Yodas。這個多樣化的數據集對於確保我們的蒸餾模型在各種音頻分佈和噪聲條件下保持魯棒性至關重要。
我們將收集到的示例打包成大約30秒的片段,每個片段只包含一個說話者。為了保持數據集之間的高質量和一致格式,我們使用Whisper-Large-v3對這些訓練片段進行了偽標記。然後,我們對Whisper偽標記和每個數據集提供的真實標記進行了歸一化處理,並計算了它們之間的詞錯誤率(WER)。我們丟棄了任何WER超過10%的示例,最終得到了大約98000小時的高質量訓練數據。
過濾後的數據可以在這個多語言數據集中找到,以便進行重現和進一步過濾。
重現Distil-Whisper
用於重現Distil-Whisper的訓練和評估代碼可在Distil-Whisper倉庫中找到。
📄 許可證
Distil-Whisper繼承了OpenAI的Whisper模型的MIT許可證。
🔖 引用
如果您使用了這個模型,請考慮引用Distil-Whisper論文:
@misc{gandhi2023distilwhisper,
title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
year={2023},
eprint={2311.00430},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
🙏 致謝
- OpenAI提供了Whisper模型和原始代碼庫。
- Hugging Face 🤗 Transformers實現了模型集成。
- Sanchit Gandhi建立了Distil-Whisper訓練管道,並在這篇論文中進行了詳細研究。
- GENCI在分配號2024 - GC011015467下提供了IDRIS HPC資源的訪問權限。
- Georgi Gerganov實現了Whisper.cpp集成。
- Systran團隊實現了Faster-Whisper集成。
- Joshua Lochner實現了Transformers.js集成。
- Laurent Mazare實現了Candle集成。
- Vaibhav Srivastav負責Distil-Whisper的分發。
- Raghav Sonavane在LibriSpeech數據集上進行了Distil-Whisper的早期迭代。



