🚀 SEW-tiny
SEW-tiny是基於16kHz採樣語音音頻預訓練的基礎模型。該模型可用於自動語音識別、說話人識別、意圖分類、情感識別等下游任務。使用模型時,請確保輸入的語音也是16kHz採樣的。
🚀 快速開始
本模型基於 SEW by ASAPP Research 開發。相關論文為 Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition,作者包括 Felix Wu、Kwangyoun Kim、Jing Pan、Kyu Han、Kilian Q. Weinberger、Yoav Artzi。
摘要:本文研究了自動語音識別(ASR)預訓練模型中的性能 - 效率權衡問題。聚焦於wav2vec 2.0,我們對影響模型性能和效率的幾種架構設計進行了形式化。綜合所有觀察結果,我們推出了SEW(Squeezed and Efficient Wav2vec),這是一種預訓練模型架構,在各種訓練設置下,在性能和效率方面都有顯著提升。例如,在LibriSpeech的100h - 960h半監督設置下,與wav2vec 2.0相比,SEW的推理速度提高了1.9倍,單詞錯誤率相對降低了13.5%。在推理時間相近的情況下,SEW在不同模型規模下將單詞錯誤率降低了25 - 50%。
原始模型可在 https://github.com/asappresearch/sew#model-checkpoints 找到。
✨ 主要特性
- 適用音頻格式:適用於16kHz採樣的語音音頻。
- 下游任務廣泛:可用於自動語音識別、說話人識別、意圖分類、情感識別等下游任務。
- 性能與效率提升:在性能和效率方面相較於wav2vec 2.0有顯著提升。
📦 安裝指南
文檔中未提及具體安裝步驟,可參考相關依賴庫的安裝方式,如 transformers
、datasets
、soundfile
、torch
、jiwer
等。
💻 使用示例
基礎用法
以下代碼展示瞭如何將該模型作為獨立的聲學模型來轉錄音頻文件:
from transformers import Wav2Vec2Processor, SEWForCTC
from datasets import load_dataset
import soundfile as sf
import torch
processor = Wav2Vec2Processor.from_pretrained("asapp/sew-tiny-100k-ft-ls100h")
model = SEWForCTC.from_pretrained("asapp/sew-tiny-100k-ft-ls100h")
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
input_values = processor(ds[0]["audio"]["array"], return_tensors="pt").input_values
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
高級用法
以下代碼展示瞭如何在LibriSpeech的“clean”和“other”測試數據上評估 asapp/sew-tiny-100k-ft-ls100h 模型:
from datasets import load_dataset
from transformers import SEWForCTC, Wav2Vec2Processor
import torch
from jiwer import wer
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")
model = SEWForCTC.from_pretrained("asapp/sew-tiny-100k-ft-ls100h").to("cuda")
processor = Wav2Vec2Processor.from_pretrained("asapp/sew-tiny-100k-ft-ls100h")
def map_to_pred(batch):
input_values = processor(batch["audio"][0]["array"], sampling_rate=16000,
return_tensors="pt", padding="longest").input_values
with torch.no_grad():
logits = model(input_values.to("cuda")).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
batch["transcription"] = transcription
return batch
result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])
print("WER:", wer(result["text"], result["transcription"]))
📚 詳細文檔
- 數據集:使用了
librispeech_asr
數據集。
- 評估指標:使用單詞錯誤率(WER)進行評估。
🔧 技術細節
本文聚焦於wav2vec 2.0,對影響模型性能和效率的幾種架構設計進行了形式化。推出的SEW模型在各種訓練設置下,在性能和效率方面都有顯著提升。例如,在LibriSpeech的100h - 960h半監督設置下,與wav2vec 2.0相比,SEW的推理速度提高了1.9倍,單詞錯誤率相對降低了13.5%。在推理時間相近的情況下,SEW在不同模型規模下將單詞錯誤率降低了25 - 50%。
📄 許可證
本項目採用 apache-2.0
許可證。
相關信息表格
屬性 |
詳情 |
模型類型 |
SEW-tiny |
訓練數據 |
librispeech_asr |
標籤 |
音頻、語音、自動語音識別、hf-asr-leaderboard |
許可證 |
apache-2.0 |
模型評估結果
數據集 |
測試WER |
LibriSpeech (clean) |
10.61 |
LibriSpeech (other) |
23.74 |