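🚀 Fine-tuned Japanese Whisper Model for Speech Recognition
This model is openai/whisper-base fine-tuned for Japanese speech recognition on the Common Voice, JVS, and JSUT datasets. When using it, make sure your speech input is sampled at 16 kHz.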
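Before transcribing your own recordings, it can help to confirm their sampling rate. Below is a minimal sketch that uses only Python's standard `wave` module. It assumes uncompressed PCM WAV input. The helper names (`wav_sample_rate`, `is_16khz`, `make_silent_wav`) are hypothetical and are not part of this model or of `transformers`. For other formats, or to resample, `librosa.load(path, sr=16_000)` (used in the quick-start code) loads and resamples in one step.

```python
import io
import wave

TARGET_SAMPLE_RATE = 16_000  # input rate this model expects, per the model card


def wav_sample_rate(source) -> int:
    """Return the sampling rate of a PCM WAV file (path or binary file object)."""
    with wave.open(source, "rb") as wav:
        return wav.getframerate()


def is_16khz(source) -> bool:
    """True if the WAV file is already sampled at 16 kHz."""
    return wav_sample_rate(source) == TARGET_SAMPLE_RATE


def make_silent_wav(rate: int, seconds: float = 0.1) -> io.BytesIO:
    """Build an in-memory mono 16-bit WAV of silence (demonstration only)."""
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)  # 2 bytes = 16-bit samples
        wav.setframerate(rate)
        wav.writeframes(b"\x00\x00" * int(rate * seconds))
    buf.seek(0)  # rewind so the buffer can be read back
    return buf


if __name__ == "__main__":
    print(is_16khz(make_silent_wav(16_000)))  # True
    print(is_16khz(make_silent_wav(44_100)))  # False
```

Files that fail this check should be resampled before being passed to the processor rather than relabeled, because the processor's `sampling_rate` argument does not convert the audio.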
🚀 Quick Start
The model can be used directly as follows:
```python
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from datasets import load_dataset
import librosa
import torch

LANG_ID = "ja"
MODEL_ID = "Ivydata/whisper-base-japanese"
SAMPLES = 10

# First SAMPLES utterances of the Japanese Common Voice test split
# (the legacy "common_voice" loader may require an older `datasets` release)
test_dataset = load_dataset("common_voice", LANG_ID, split=f"test[:{SAMPLES}]")

# Processor (feature extractor + tokenizer) from the base model,
# weights from the fine-tuned checkpoint
processor = WhisperProcessor.from_pretrained("openai/whisper-base")
model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID)

# Force Japanese transcription (not translation) and disable token suppression
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language="ja", task="transcribe"
)
model.config.suppress_tokens = []


def speech_file_to_array_fn(batch):
    # Load the audio and resample it to the 16 kHz rate the model expects
    speech_array, sampling_rate = librosa.load(batch["path"], sr=16_000)
    batch["speech"] = speech_array
    # Upper-casing only affects Latin letters; kana and kanji are unchanged
    batch["sentence"] = batch["sentence"].upper()
    batch["sampling_rate"] = sampling_rate
    return batch


test_dataset = test_dataset.map(speech_file_to_array_fn)

sample = test_dataset[0]
input_features = processor(
    sample["speech"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
).input_features

# Generate token IDs without tracking gradients
with torch.no_grad():
    predicted_ids = model.generate(input_features)

# skip_special_tokens=True drops the language/task prompt tokens from the text
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
print(transcription[0])
```
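The snippet above transcribes a single short clip. Whisper's feature extractor works on a fixed 30-second window (480,000 samples at 16 kHz) and, by default, truncates anything longer, so longer recordings need to be split first. Below is a minimal sketch that assumes the audio is already a 16 kHz sequence of samples, such as the `speech` array produced above. `split_into_chunks` is a hypothetical helper, not a `transformers` function.

```python
from typing import List, Sequence

SAMPLE_RATE = 16_000                         # model input rate
CHUNK_SECONDS = 30                           # Whisper's fixed input window
CHUNK_SAMPLES = SAMPLE_RATE * CHUNK_SECONDS  # 480,000 samples


def split_into_chunks(
    samples: Sequence[float], chunk_samples: int = CHUNK_SAMPLES
) -> List[Sequence[float]]:
    """Split a 16 kHz waveform into consecutive, non-overlapping windows.

    Every window holds at most `chunk_samples` samples; the last one may be
    shorter (the feature extractor pads it to 30 s).
    """
    if chunk_samples <= 0:
        raise ValueError("chunk_samples must be positive")
    return [samples[i:i + chunk_samples] for i in range(0, len(samples), chunk_samples)]


if __name__ == "__main__":
    seventy_seconds = [0.0] * (70 * SAMPLE_RATE)
    print([len(c) for c in split_into_chunks(seventy_seconds)])  # [480000, 480000, 160000]
```

Each chunk can then go through `processor(...)` and `model.generate(...)` as in the quick-start code. Hard cuts can split a word at a chunk boundary. As an alternative, the `transformers` automatic-speech-recognition `pipeline` supports overlapping chunking through its `chunk_length_s` and `stride_length_s` arguments.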
📚 Detailed Documentation
Test Results
The table below lists character error rates (CER; lower is better) measured on the TEDxJP-10K dataset:
| Model | Character error rate (CER) |
|---|---|
| Ivydata/whisper-small-japanese | 27.25% |
| Ivydata/wav2vec2-large-xlsr-53-japanese | 27.87% |
| jonatasgrosman/wav2vec2-large-xlsr-53-japanese | 34.18% |
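CER is the number of character-level substitutions S, deletions D, and insertions I needed to turn a model's output into the reference transcript, divided by the number of reference characters N: CER = (S + D + I) / N. The sketch below computes a corpus-level CER with a plain Levenshtein distance. It is an illustration only, not the evaluation script behind the numbers above. The card does not state how the TEDxJP-10K transcripts were normalized (punctuation, spaces, full-width vs. half-width characters). This sketch applies no normalization, which can shift the result noticeably.

```python
from typing import Sequence


def edit_distance(ref: str, hyp: str) -> int:
    """Levenshtein distance between two strings, counted in characters."""
    # prev[j]: distance between the reference prefix processed so far and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # deletion: reference char missing from hypothesis
                curr[j - 1] + 1,         # insertion: extra char in hypothesis
                prev[j - 1] + (r != h),  # substitution (cost 0 when chars match)
            )
        prev = curr
    return prev[-1]


def cer(references: Sequence[str], hypotheses: Sequence[str]) -> float:
    """Corpus-level CER: total edits divided by total reference characters."""
    if len(references) != len(hypotheses):
        raise ValueError("references and hypotheses must have the same length")
    total_chars = sum(len(r) for r in references)
    if total_chars == 0:
        raise ValueError("references contain no characters")
    total_edits = sum(edit_distance(r, h) for r, h in zip(references, hypotheses))
    return total_edits / total_chars


if __name__ == "__main__":
    # One substitution (晴→雨) plus one deletion (れ) over 7 reference characters
    print(round(cer(["今日は晴れです"], ["今日は雨です"]), 4))  # 0.2857
```

Summing edits and characters over the whole corpus, rather than averaging per-utterance CERs, keeps short utterances from dominating the score. Libraries such as `jiwer` also provide a ready-made `cer` function.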
📄 License
This project is licensed under the Apache-2.0 license.