🚀 Sharif-wav2vec2
Sharif-wav2vec2 是针对波斯语微调后的版本。基础模型经过了微调,使用了 Commonvoice 中时长为 108 小时、采样率为 16kHz 的波斯语样本。之后,我们使用 kenlm 工具包训练了一个 5-gram 语言模型,并将其用于处理器中,这提高了我们在线自动语音识别(ASR)的准确率。
🚀 快速开始
在使用该模型时,请确保语音输入的采样率为 16kHz。在使用之前,你可能需要安装以下依赖项:
pip install pyctcdecode
pip install pypi-kenlm
💻 使用示例
基础用法
你可以使用 Hugging Face 上的托管推理 API 进行测试(提供了来自 Common Voice 的示例)。转录给定语音可能需要一些时间;或者你可以使用以下代码在本地运行:
import tensorflow
import torchaudio
import torch
import numpy as np
from transformers import AutoProcessor, AutoModelForCTC
processor = AutoProcessor.from_pretrained("SLPL/Sharif-wav2vec2")
model = AutoModelForCTC.from_pretrained("SLPL/Sharif-wav2vec2")
speech_array, sampling_rate = torchaudio.load("path/to/your.wav")
speech_array = speech_array.squeeze().numpy()
features = processor(
speech_array,
sampling_rate=processor.feature_extractor.sampling_rate,
return_tensors="pt",
padding=True)
with torch.no_grad():
logits = model(
features.input_values,
attention_mask=features.attention_mask).logits
prediction = processor.batch_decode(logits.numpy()).text
print(prediction[0])
评估
你可以使用以下代码进行评估。请确保你的数据集采用以下形式,以避免冲突:
path |
reference |
path/to/audio_file.wav |
"TRANSCRIPTION" |
同时,请确保在运行之前安装了 pip install jiwer
。
import tensorflow
import torchaudio
import torch
import librosa
from datasets import load_dataset,load_metric
import numpy as np
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from transformers import Wav2Vec2ProcessorWithLM
model = Wav2Vec2ForCTC.from_pretrained("SLPL/Sharif-wav2vec2")
processor = Wav2Vec2ProcessorWithLM.from_pretrained("SLPL/Sharif-wav2vec2")
def speech_file_to_array_fn(batch):
speech_array, sampling_rate = torchaudio.load(batch["path"])
speech_array = speech_array.squeeze().numpy()
speech_array = librosa.resample(
np.asarray(speech_array),
sampling_rate,
processor.feature_extractor.sampling_rate)
batch["speech"] = speech_array
return batch
def predict(batch):
features = processor(
batch["speech"],
sampling_rate=processor.feature_extractor.sampling_rate,
return_tensors="pt",
padding=True
)
with torch.no_grad():
logits = model(
features.input_values,
attention_mask=features.attention_mask).logits
batch["prediction"] = processor.batch_decode(logits.numpy()).text
return batch
dataset = load_dataset(
"csv",
data_files={"test":"dataset.eval.csv"},
delimiter=",")["test"]
dataset = dataset.map(speech_file_to_array_fn)
result = dataset.map(predict, batched=True, batch_size=4)
wer = load_metric("wer")
print("WER: {:.2f}".format(wer.compute(
predictions=result["prediction"],
references=result["reference"])))
在 Common Voice 6.1 上的结果(WER):
📄 许可证
本项目采用 MIT 许可证。
📚 引用
如果你想引用此模型,可以使用以下内容:
?
贡献
感谢 @sarasadeghii 和 @sadrasabouri 添加此模型。