🚀 SEW-D-mid
SEW-D-mid 是基於 16kHz 採樣語音音頻預訓練的基礎模型。該模型由 ASAPP Research 開發,相關信息可參考 SEW-D by ASAPP Research。使用此模型時,需確保輸入的語音也採樣為 16kHz。請注意,該模型需要在下游任務(如自動語音識別、說話人識別、意圖分類、情感識別等)上進行微調。
論文:Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition
作者:Felix Wu、Kwangyoun Kim、Jing Pan、Kyu Han、Kilian Q. Weinberger、Yoav Artzi
摘要
本文研究了自動語音識別(ASR)預訓練模型中的性能 - 效率權衡問題。我們聚焦於 wav2vec 2.0,並對影響模型性能和效率的幾種架構設計進行了形式化。綜合所有觀察結果,我們提出了 SEW(Squeezed and Efficient Wav2vec),這是一種在各種訓練設置下,在性能和效率方面都有顯著改進的預訓練模型架構。例如,在 LibriSpeech 的 100h - 960h 半監督設置下,與 wav2vec 2.0 相比,SEW 的推理速度提高了 1.9 倍,詞錯誤率相對降低了 13.5%。在推理時間相近的情況下,SEW 在不同模型大小下將詞錯誤率降低了 25 - 50%。
原始模型可在 https://github.com/asappresearch/sew#model-checkpoints 找到。
🚀 快速開始
模型信息
屬性 |
詳情 |
模型類型 |
語音處理模型 |
訓練數據 |
LibriSpeech 數據集 |
標籤 |
音頻、語音、自動語音識別、HF 自動語音識別排行榜 |
許可證 |
Apache-2.0 |
示例音頻
評估結果
數據集 |
測試詞錯誤率 (WER) |
LibriSpeech (clean) |
4.94 |
LibriSpeech (other) |
11.51 |
✨ 主要特性
- 基於 16kHz 採樣語音音頻進行預訓練。
- 可應用於多種下游任務,如自動語音識別、說話人識別等。
- 在性能和效率方面有顯著改進。
📦 安裝指南
文檔未提及安裝步驟,暫不提供。
💻 使用示例
基礎用法
以下代碼展示瞭如何將該模型作為獨立的聲學模型來轉錄音頻文件:
from transformers import Wav2Vec2Processor, SEWDForCTC
from datasets import load_dataset
import soundfile as sf
import torch
processor = Wav2Vec2Processor.from_pretrained("asapp/sew-d-mid-400k-ft-ls100h")
model = SEWDForCTC.from_pretrained("asapp/sew-d-mid-400k-ft-ls100h")
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
input_values = processor(ds[0]["audio"]["array"], return_tensors="pt").input_values
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
高級用法
以下代碼展示瞭如何在 LibriSpeech 的 “clean” 和 “other” 測試數據上評估 asapp/sew-d-mid-400k-ft-ls100hh 模型:
from datasets import load_dataset
from transformers import SEWDForCTC, Wav2Vec2Processor
import torch
from jiwer import wer
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")
model = SEWDForCTC.from_pretrained("asapp/sew-d-mid-400k-ft-ls100h").to("cuda")
processor = Wav2Vec2Processor.from_pretrained("asapp/sew-d-mid-400k-ft-ls100h")
def map_to_pred(batch):
input_values = processor(batch["audio"][0]["array"], sampling_rate=16000,
return_tensors="pt", padding="longest").input_values
with torch.no_grad():
logits = model(input_values.to("cuda")).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
batch["transcription"] = transcription
return batch
result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])
print("WER:", wer(result["text"], result["transcription"]))
📚 詳細文檔
注意事項
⚠️ 重要提示
使用該模型時,確保輸入的語音採樣率為 16kHz。同時,該模型需要在下游任務上進行微調。
論文引用
📄 許可證
本項目採用 Apache-2.0 許可證。