🚀 SEW-D-mid-k127
SEW-D-mid-k127 是基於 16kHz 採樣語音音頻預訓練的基礎模型。該模型由 ASAPP Research 開發,相關信息可查看 SEW-D by ASAPP Research。使用此模型時,請確保輸入的語音也採樣為 16kHz。需注意,該模型應在下游任務(如自動語音識別、說話人識別、意圖分類、情感識別等)上進行微調。
論文:Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition
作者:Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q. Weinberger, Yoav Artzi
摘要
本文研究了自動語音識別(ASR)預訓練模型中的性能 - 效率權衡問題。我們聚焦於 wav2vec 2.0,並對影響模型性能和效率的幾種架構設計進行了形式化。綜合所有觀察結果,我們推出了 SEW(Squeezed and Efficient Wav2vec),這是一種預訓練模型架構,在各種訓練設置下,其性能和效率都有顯著提升。例如,在 LibriSpeech 的 100h - 960h 半監督設置下,與 wav2vec 2.0 相比,SEW 的推理速度提高了 1.9 倍,單詞錯誤率相對降低了 13.5%。在推理時間相近的情況下,SEW 在不同模型規模下將單詞錯誤率降低了 25 - 50%。
原始模型可在 https://github.com/asappresearch/sew#model-checkpoints 找到。
🚀 快速開始
模型信息
屬性 |
詳情 |
數據集 |
librispeech_asr |
標籤 |
音頻、語音、自動語音識別、hf-asr-leaderboard |
許可證 |
apache-2.0 |
示例音頻
模型評估結果
任務 |
數據集 |
指標 |
值 |
自動語音識別 |
LibriSpeech (clean) |
測試單詞錯誤率 (Test WER) |
4.99 |
自動語音識別 |
LibriSpeech (other) |
測試單詞錯誤率 (Test WER) |
10.95 |
💻 使用示例
基礎用法
以下代碼展示瞭如何將該模型作為獨立的聲學模型來轉錄音頻文件:
from transformers import Wav2Vec2Processor, SEWDForCTC
from datasets import load_dataset
import soundfile as sf
import torch
processor = Wav2Vec2Processor.from_pretrained("asapp/sew-d-mid-k127-400k-ft-ls100h")
model = SEWDForCTC.from_pretrained("asapp/sew-d-mid-k127-400k-ft-ls100h")
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
input_values = processor(ds[0]["audio"]["array"], return_tensors="pt").input_values
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
高級用法
以下代碼展示瞭如何在 LibriSpeech 的 "clean" 和 "other" 測試數據上評估 asapp/sew-d-mid-k127-400k-ft-ls100hh 模型:
from datasets import load_dataset
from transformers import SEWDForCTC, Wav2Vec2Processor
import torch
from jiwer import wer
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")
model = SEWDForCTC.from_pretrained("asapp/sew-d-mid-k127-400k-ft-ls100h").to("cuda")
processor = Wav2Vec2Processor.from_pretrained("asapp/sew-d-mid-k127-400k-ft-ls100h")
def map_to_pred(batch):
input_values = processor(batch["audio"][0]["array"], sampling_rate=16000,
return_tensors="pt", padding="longest").input_values
with torch.no_grad():
logits = model(input_values.to("cuda")).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
batch["transcription"] = transcription
return batch
result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])
print("WER:", wer(result["text"], result["transcription"]))
結果(單詞錯誤率):
"clean" |
"other" |
4.99 |
10.95 |
📄 許可證
本項目採用 apache-2.0 許可證。