🚀 粵語版小模型Whisper - Alvin
本模型是基於粵語對 openai/whisper-small 進行微調的版本。在 Common Voice 16.0 數據集上,其字符錯誤率(CER)在無標點時為 7.93%,有標點時為 9.72%。
✨ 主要特性
- 基於預訓練模型
openai/whisper-small
進行粵語微調。
- 在多個粵語數據集上進行訓練和評估,具有較好的粵語語音識別性能。
- 支持多種推理加速方法,如 Flash Attention 和 Speculative Decoding。
📦 安裝指南
文檔未提及安裝步驟,暫不提供。
💻 使用示例
基礎用法
import librosa
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor
y, sr = librosa.load('audio.mp3', sr=16000)
MODEL_NAME = "alvanlii/whisper-small-cantonese"
processor = WhisperProcessor.from_pretrained(MODEL_NAME)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_NAME)
processed_in = processor(y, sampling_rate=sr, return_tensors="pt")
gout = model.generate(
input_features=processed_in.input_features,
output_scores=True, return_dict_in_generate=True
)
transcription = processor.batch_decode(gout.sequences, skip_special_tokens=True)[0]
print(transcription)
高級用法
使用 huggingface pipelines 進行推理:
from transformers import pipeline
MODEL_NAME = "alvanlii/whisper-small-cantonese"
lang = "zh"
device = 0
pipe = pipeline(
task="automatic-speech-recognition",
model=MODEL_NAME,
chunk_length_s=30,
device=device,
)
pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=lang, task="transcribe")
text = pipe('audio.mp3')["text"]
📚 詳細文檔
訓練和評估數據
訓練數據
- CantoMap:Winterstein, Grégoire, Tang, Carmen 和 Lai, Regine (2020) "CantoMap: a Hong Kong Cantonese MapTask Corpus",發表於 The 12th Language Resources and Evaluation Conference 會議論文集,Marseille: European Language Resources Association, p. 2899 - 2906。
- Cantonse - ASR:Yu, Tiezheng, Frieske, Rita, Xu, Peng, Cahyawijaya, Samuel, Yiu, Cheuk Tung, Lovenia, Holy, Dai, Wenliang, Barezi, Elham, Chen, Qifeng, Ma, Xiaojuan, Shi, Bertram, Fung, Pascale (2022) "Automatic Speech Recognition Datasets in Cantonese: A Survey and New Dataset",2022 年。鏈接:https://arxiv.org/pdf/2201.02419.pdf
名稱 |
時長(小時) |
Common Voice 16.0 zh - HK Train |
138 |
Common Voice 16.0 yue Train |
85 |
Common Voice 17.0 yue Train |
178 |
Cantonese - ASR |
72 |
CantoMap |
23 |
Pseudo - Labelled YouTube Data |
438 |
評估數據
使用 Common Voice 16.0 yue 測試集進行評估。
評估結果
- 字符錯誤率(CER,越低越好):
- 無標點:0.0793
- 有標點:0.0972,較之前版本的 0.1073 和 0.1581 有所下降
- GPU 推理(使用 Fast Attention,示例如下):每個樣本 0.055 秒
- 注意:所有 GPU 評估均在 RTX 3090 GPU 上進行
- GPU 推理:每個樣本 0.308 秒
- CPU 推理:每個樣本 2.57 秒
- GPU 顯存佔用:約 1.5 GB
模型加速
只需添加 attn_implementation="sdpa"
即可使用 Flash Attention 進行加速。
from transformers import AutoModelForSpeechSeq2Seq
import torch
torch_dtype = torch.float16
model = AutoModelForSpeechSeq2Seq.from_pretrained(
"alvanlii/whisper-small-cantonese",
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
attn_implementation="sdpa",
)
使用 Flash Attention 後,每個樣本的推理時間從 0.308 秒減少到 0.055 秒。
推測解碼
可以使用更大的模型,然後使用 alvanlii/whisper-small-cantonese
加速推理,且基本不損失準確性。
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
import torch
torch_dtype = torch.float16
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "simonl0909/whisper-large-v2-cantonese"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
attn_implementation="sdpa",
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
assistant_model_id = "alvanlii/whisper-small-cantonese"
assistant_model = AutoModelForSpeechSeq2Seq.from_pretrained(
assistant_model_id,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
use_safetensors=True,
attn_implementation="sdpa",
)
assistant_model.to(device)
inputs = processor(...)
model.generate(**inputs, use_cache=True, assistant_model=assistant_model)
原始的 simonl0909/whisper-large-v2-cantonese
模型每個樣本推理時間為 0.714 秒,CER 為 7.65%。使用 alvanlii/whisper-small-cantonese
進行推測解碼後,每個樣本推理時間為 0.137 秒,CER 為 7.67%,速度大幅提升。
Whisper.cpp
截至 2024 年 6 月,已上傳用於 Whisper cpp 的 GGML 二進制文件。可以從 這裡 下載二進制文件,並在 這裡 進行測試。
Whisper CT2
若要在 WhisperX 或 FasterWhisper 中使用,需要 CT2 文件。轉換後的模型文件位於 這裡。
訓練超參數
屬性 |
詳情 |
學習率 |
5e - 5 |
訓練批次大小 |
25(在 1 塊 3090 GPU 上) |
評估批次大小 |
8 |
梯度累積步數 |
4 |
總訓練批次大小 |
25 x 4 = 100 |
優化器 |
Adam,beta=(0.9, 0.999),epsilon = 1e - 08 |
學習率調度器類型 |
線性 |
學習率調度器熱身步數 |
500 |
訓練步數 |
15000 |
數據增強 |
無 |
🔧 技術細節
文檔未提供足夠詳細的技術實現細節,暫不提供。
📄 許可證
本模型遵循 Apache - 2.0 許可證。