🚀 Whisper Small Fine-Tuned for Japanese Speech Recognition
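This is a Japanese speech recognition model fine-tuned from openai/whisper-small on the Common Voice, JVS, and JSUT datasets. When using this model, make sure your speech input is sampled at 16 kHz.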
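Audio recorded at another rate (for example 44.1 or 48 kHz) has to be resampled to 16 kHz before it reaches the processor. The sketch below is a naive linear-interpolation resampler in NumPy, shown for illustration only. The helper `to_16k` is hypothetical and not part of this project. It applies no anti-aliasing filter, so for real recordings a dedicated resampler such as `librosa.load(path, sr=16_000)`, which the usage example in this README relies on, is the better choice.

```python
import numpy as np

TARGET_SR = 16_000  # sampling rate this model expects

def to_16k(waveform: np.ndarray, orig_sr: int) -> np.ndarray:
    """Hypothetical helper: naively resample a mono waveform to 16 kHz.

    Uses linear interpolation with no anti-aliasing filter (sketch only).
    """
    if orig_sr == TARGET_SR:
        return waveform.astype(np.float32)
    n_out = int(round(len(waveform) * TARGET_SR / orig_sr))
    # Timestamps (in seconds) of the input and output samples
    t_in = np.arange(len(waveform)) / orig_sr
    t_out = np.arange(n_out) / TARGET_SR
    return np.interp(t_out, t_in, waveform).astype(np.float32)

# Usage: one second of a 440 Hz tone recorded at 48 kHz
sr = 48_000
tone = np.sin(2 * np.pi * 440 * np.arange(sr) / sr)
resampled = to_16k(tone, sr)
print(resampled.shape)  # (16000,)
```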
🚀 Quick Start
The model can be used directly as follows:
💻 Usage Example
Basic usage
```python
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from datasets import load_dataset
import librosa

LANG_ID = "ja"
MODEL_ID = "Ivydata/whisper-small-japanese"
SAMPLES = 10

# Load the first SAMPLES clips of the Japanese Common Voice test split
test_dataset = load_dataset("common_voice", LANG_ID, split=f"test[:{SAMPLES}]")

processor = WhisperProcessor.from_pretrained(MODEL_ID)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID)
# Force Japanese transcription (not translation) and disable token suppression
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language="ja", task="transcribe"
)
model.config.suppress_tokens = []

def speech_file_to_array_fn(batch):
    # Decode the audio file, resampling it to the 16 kHz rate the model expects
    speech_array, sampling_rate = librosa.load(batch["path"], sr=16_000)
    batch["speech"] = speech_array
    batch["sentence"] = batch["sentence"].upper()
    batch["sampling_rate"] = sampling_rate
    return batch

test_dataset = test_dataset.map(speech_file_to_array_fn)

sample = test_dataset[0]
# Convert the waveform into the log-Mel input features Whisper consumes
input_features = processor(
    sample["speech"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
).input_features

predicted_ids = model.generate(input_features)
# Raw output, including special tokens such as <|ja|> and <|transcribe|>
transcription_raw = processor.batch_decode(predicted_ids, skip_special_tokens=False)
# Clean transcription text
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
print(transcription[0])
```
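📚 Documentation
Evaluation results
The table below compares the character error rate (CER) of this model with two wav2vec2-based Japanese models on the TEDxJP-10K dataset (lower is better):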
| Model | Character error rate (CER) |
|---|---|
| Ivydata/whisper-small-japanese | 23.10% |
| Ivydata/wav2vec2-large-xlsr-53-japanese | 27.87% |
| jonatasgrosman/wav2vec2-large-xlsr-53-japanese | 34.18% |
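CER is the minimum number of character substitutions, deletions, and insertions needed to turn the model output into the reference transcript, divided by the number of reference characters. A CER of 23.10% therefore means roughly 23 character errors per 100 reference characters. The pure-Python sketch below computes corpus-level CER with a standard Levenshtein distance. The helper names `levenshtein` and `cer` are hypothetical, and whatever text normalization was applied before scoring the numbers above (for example punctuation removal) is not specified here and is not reproduced.

```python
def levenshtein(ref: str, hyp: str) -> int:
    """Character-level edit distance (substitutions, deletions, insertions)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            curr.append(min(
                prev[j] + 1,             # delete a reference character
                curr[j - 1] + 1,         # insert a hypothesis character
                prev[j - 1] + (r != h),  # substitute (free if equal)
            ))
        prev = curr
    return prev[-1]

def cer(references: list[str], hypotheses: list[str]) -> float:
    """Corpus-level CER: total edit distance over total reference length."""
    if len(references) != len(hypotheses):
        raise ValueError("references and hypotheses must have the same length")
    total = sum(len(r) for r in references)
    if total == 0:
        raise ValueError("references are empty")
    errors = sum(levenshtein(r, h) for r, h in zip(references, hypotheses))
    return errors / total

# Usage: one substitution (す→し) and one insertion (た) over 7 reference characters
print(f"{cer(['今日は晴れです'], ['今日は晴れでした']):.2%}")  # 28.57%
```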
📄 License
This project is licensed under the Apache-2.0 license.