🚀 distil-whisper-large-v3-es
This repository hosts a distilled version of the Whisper large-v3 model, trained on the Mozilla Common Voice dataset v16.1. The model was made possible through a collaboration between SandboxAI and the Universidad Nacional de Rio Negro.
🚀 Quick Start
Distil-Whisper is supported in Hugging Face 🤗 Transformers from version 4.35 onwards. To run the model, first install the latest version of the Transformers library. For this example, we will also install 🤗 Datasets to load a toy audio dataset from the Hugging Face Hub:
```bash
pip install --upgrade pip
pip install --upgrade transformers accelerate datasets[audio]
```
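To confirm that the installed version meets the 4.35 requirement mentioned above, a quick sanity check:

```python
# Distil-Whisper requires Transformers >= 4.35
import transformers

print(transformers.__version__)
```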
✨ Key Features
- Multiple transcription modes: supports both short-form (< 30 seconds) and long-form (> 30 seconds) transcription, as well as speculative decoding.
- Efficient long-form transcription: Distil-Whisper uses a chunked algorithm for long-form audio that is 9x faster than the sequential algorithm proposed by OpenAI in the Whisper paper.
- Assistant model for speculative decoding: Distil-Whisper can serve as an assistant model to Whisper for speculative decoding, delivering a 2x speed-up while guaranteeing outputs identical to Whisper's.
💻 Usage Examples
Basic Usage
Short-Form Transcription
The model can be used with the `pipeline` class to transcribe short audio files (< 30 seconds) as follows:
```python
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "marianbasti/distil-whisper-large-v3-es"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    torch_dtype=torch_dtype,
    device=device,
)

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]

result = pipe(sample)
print(result["text"])
```
To transcribe a local audio file, simply pass the path to the audio file when calling the pipeline:
```diff
- result = pipe(sample)
+ result = pipe("audio.mp3")
```
Advanced Usage
Long-Form Transcription
Distil-Whisper uses a chunked algorithm to transcribe long audio files (> 30 seconds). In practice, this chunked long-form algorithm is 9x faster than the sequential algorithm proposed by OpenAI in the Whisper paper (see Table 7 of the Distil-Whisper paper).
To enable chunking, pass the `chunk_length_s` parameter to the `pipeline`. For Distil-Whisper, a chunk length of 15 seconds is optimal. To enable batching, also pass the `batch_size` argument:
```python
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "marianbasti/distil-whisper-large-v3-es"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,
    batch_size=16,
    torch_dtype=torch_dtype,
    device=device,
)

dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]

result = pipe(sample)
print(result["text"])
```
Speculative Decoding
Distil-Whisper can be used as an assistant model to Whisper for speculative decoding. Speculative decoding mathematically guarantees the same outputs as Whisper while being 2x faster, which makes it a perfect drop-in replacement for existing Whisper pipelines. In the following code snippet, we load the assistant Distil-Whisper model standalone alongside the main Whisper pipeline, then specify it as the "assistant model" for generation:
```python
from transformers import pipeline, AutoModelForCausalLM, AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
from datasets import load_dataset

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

assistant_model_id = "marianbasti/distil-whisper-large-v3-es"
assistant_model = AutoModelForCausalLM.from_pretrained(
    assistant_model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
assistant_model.to(device)

model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)

pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    generate_kwargs={"assistant_model": assistant_model},
    torch_dtype=torch_dtype,
    device=device,
)

dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = dataset[0]["audio"]

result = pipe(sample)
print(result["text"])
```
🔧 Technical Details
The model was trained on a single RTX 3090 for 60,000 optimization steps (around 1.47 epochs), taking roughly 60 hours, with the following training parameters:
```bash
--teacher_model_name_or_path "openai/whisper-large-v3"
--train_dataset_name "mozilla-foundation/common_voice_16_1"
--train_dataset_config_name "es"
--train_split_name "train"
--text_column_name "sentence"
--eval_dataset_name "mozilla-foundation/common_voice_16_1"
--eval_dataset_config_name "es"
--eval_split_name "validation"
--eval_text_column_name "sentence"
--eval_steps 10000
--save_steps 10000
--warmup_steps 500
--learning_rate 1e-4
--lr_scheduler_type "linear"
--logging_steps 25
--save_total_limit 1
--max_steps 60000
--wer_threshold 10
--per_device_train_batch_size 8
--per_device_eval_batch_size 8
--dataloader_num_workers 12
--preprocessing_num_workers 12
--output_dir "./"
--do_train
--do_eval
--gradient_checkpointing
--predict_with_generate
--overwrite_output_dir
--use_pseudo_labels "false"
--freeze_encoder
--streaming False
```
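These flags appear to match the `run_distillation.py` training script from the huggingface/distil-whisper repository, the reference implementation of the Distil-Whisper distillation recipe.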
📚 Documentation
Training Results
The distilled model achieves a word error rate (WER) of 5.11% (orthographic WER: 10.15%).
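For reference, WER figures of this kind can be computed with the 🤗 `evaluate` library. A generic sketch with toy strings (not the actual evaluation pipeline used for this model):

```python
# pip install evaluate jiwer
import evaluate

wer_metric = evaluate.load("wer")
# Toy example; a real evaluation would compare model predictions
# against Common Voice reference transcriptions.
wer = wer_metric.compute(
    references=["hola, ¿cómo estás?"],
    predictions=["hola como estas"],
)
print(f"WER: {wer:.2%}")
```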
📄 License
Distil-Whisper inherits the MIT license from OpenAI's Whisper model.
Citation
If you use this model, please consider citing the Distil-Whisper paper:
```bibtex
@misc{gandhi2023distilwhisper,
      title={Distil-Whisper: Robust Knowledge Distillation via Large-Scale Pseudo Labelling},
      author={Sanchit Gandhi and Patrick von Platen and Alexander M. Rush},
      year={2023},
      eprint={2311.00430},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```



