🚀 SEW-D-base+
SEW-D-base+ 是由 ASAPP Research 基於 16kHz 採樣的語音音頻進行預訓練的基礎模型。使用該模型時,請確保輸入的語音也採樣為 16kHz。請注意,此模型需要在下游任務(如自動語音識別、說話人識別、意圖分類、情感識別等)上進行微調。
論文信息
- 標題:Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition
- 作者:Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi
- 摘要:本文研究了自動語音識別(ASR)預訓練模型中的性能 - 效率權衡問題。我們聚焦於 wav2vec 2.0,並對影響模型性能和效率的幾種架構設計進行了形式化。綜合所有觀察結果,我們推出了 SEW(Squeezed and Efficient Wav2vec),這是一種預訓練模型架構,在各種訓練設置下,其性能和效率都有顯著提升。例如,在 LibriSpeech 的 100h - 960h 半監督設置下,與 wav2vec 2.0 相比,SEW 的推理速度提高了 1.9 倍,單詞錯誤率相對降低了 13.5%。在推理時間相近的情況下,SEW 在不同模型規模下將單詞錯誤率降低了 25 - 50%。
原始模型可在 https://github.com/asappresearch/sew#model-checkpoints 找到。
🚀 快速開始
模型信息
屬性 |
詳情 |
模型類型 |
語音識別模型 |
訓練數據 |
librispeech_asr |
許可證 |
apache-2.0 |
示例音頻
評估結果
數據集 |
測試集 |
單詞錯誤率(WER) |
LibriSpeech (clean) |
test |
4.34 |
LibriSpeech (other) |
test |
9.45 |
💻 使用示例
基礎用法
from transformers import Wav2Vec2Processor, SEWDForCTC
from datasets import load_dataset
import soundfile as sf
import torch
processor = Wav2Vec2Processor.from_pretrained("asapp/sew-d-base-plus-400k-ft-ls100h")
model = SEWDForCTC.from_pretrained("asapp/sew-d-base-plus-400k-ft-ls100h")
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
input_values = processor(ds[0]["audio"]["array"], return_tensors="pt").input_values
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
評估用法
from datasets import load_dataset
from transformers import SEWDForCTC, Wav2Vec2Processor
import torch
from jiwer import wer
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")
model = SEWDForCTC.from_pretrained("asapp/sew-d-base-plus-400k-ft-ls100h").to("cuda")
processor = Wav2Vec2Processor.from_pretrained("asapp/sew-d-base-plus-400k-ft-ls100h")
def map_to_pred(batch):
input_values = processor(batch["audio"][0]["array"], sampling_rate=16000,
return_tensors="pt", padding="longest").input_values
with torch.no_grad():
logits = model(input_values.to("cuda")).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
batch["transcription"] = transcription
return batch
result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])
print("WER:", wer(result["text"], result["transcription"]))
📄 許可證
本項目採用 apache-2.0 許可證。