🚀 Sharif-wav2vec2
Sharif-wav2vec2 是針對波斯語微調後的版本。基礎模型經過了微調,使用了 Commonvoice 中時長為 108 小時、採樣率為 16kHz 的波斯語樣本。之後,我們使用 kenlm 工具包訓練了一個 5-gram 語言模型,並將其用於處理器中,這提高了我們在線自動語音識別(ASR)的準確率。
🚀 快速開始
在使用該模型時,請確保語音輸入的採樣率為 16kHz。在使用之前,你可能需要安裝以下依賴項:
pip install pyctcdecode
pip install pypi-kenlm
💻 使用示例
基礎用法
你可以使用 Hugging Face 上的託管推理 API 進行測試(提供了來自 Common Voice 的示例)。轉錄給定語音可能需要一些時間;或者你可以使用以下代碼在本地運行:
import tensorflow
import torchaudio
import torch
import numpy as np
from transformers import AutoProcessor, AutoModelForCTC
processor = AutoProcessor.from_pretrained("SLPL/Sharif-wav2vec2")
model = AutoModelForCTC.from_pretrained("SLPL/Sharif-wav2vec2")
speech_array, sampling_rate = torchaudio.load("path/to/your.wav")
speech_array = speech_array.squeeze().numpy()
features = processor(
speech_array,
sampling_rate=processor.feature_extractor.sampling_rate,
return_tensors="pt",
padding=True)
with torch.no_grad():
logits = model(
features.input_values,
attention_mask=features.attention_mask).logits
prediction = processor.batch_decode(logits.numpy()).text
print(prediction[0])
評估
你可以使用以下代碼進行評估。請確保你的數據集採用以下形式,以避免衝突:
path |
reference |
path/to/audio_file.wav |
"TRANSCRIPTION" |
同時,請確保在運行之前安裝了 pip install jiwer
。
import tensorflow
import torchaudio
import torch
import librosa
from datasets import load_dataset,load_metric
import numpy as np
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from transformers import Wav2Vec2ProcessorWithLM
model = Wav2Vec2ForCTC.from_pretrained("SLPL/Sharif-wav2vec2")
processor = Wav2Vec2ProcessorWithLM.from_pretrained("SLPL/Sharif-wav2vec2")
def speech_file_to_array_fn(batch):
speech_array, sampling_rate = torchaudio.load(batch["path"])
speech_array = speech_array.squeeze().numpy()
speech_array = librosa.resample(
np.asarray(speech_array),
sampling_rate,
processor.feature_extractor.sampling_rate)
batch["speech"] = speech_array
return batch
def predict(batch):
features = processor(
batch["speech"],
sampling_rate=processor.feature_extractor.sampling_rate,
return_tensors="pt",
padding=True
)
with torch.no_grad():
logits = model(
features.input_values,
attention_mask=features.attention_mask).logits
batch["prediction"] = processor.batch_decode(logits.numpy()).text
return batch
dataset = load_dataset(
"csv",
data_files={"test":"dataset.eval.csv"},
delimiter=",")["test"]
dataset = dataset.map(speech_file_to_array_fn)
result = dataset.map(predict, batched=True, batch_size=4)
wer = load_metric("wer")
print("WER: {:.2f}".format(wer.compute(
predictions=result["prediction"],
references=result["reference"])))
在 Common Voice 6.1 上的結果(WER):
📄 許可證
本項目採用 MIT 許可證。
📚 引用
如果你想引用此模型,可以使用以下內容:
?
貢獻
感謝 @sarasadeghii 和 @sadrasabouri 添加此模型。