🚀 wav2vec 2.0 XLS - R 1B + TEVRトークン + 5 - gram LMによるドイツ語音声認識パイプライン
このパイプラインは、新しいwav2vec 2.0 XLS - R 1B TEVRアーキテクチャを用いた音響モデルと5 - gram KenLM言語モデルから構成される、完全に学習されたドイツ語音声認識パイプラインです。CommonVoiceドイツ語データセットにおいて、非常に競争力のあるパフォーマンスを示します。
🚀 クイックスタート
このパイプラインの概要や評価方法、引用方法などの詳細を以下に説明します。
✨ 主な機能
- 新しいwav2vec 2.0 XLS - R 1B TEVRアーキテクチャを用いた音響モデル。
- 5 - gram KenLM言語モデルを組み合わせた音声認識パイプライン。
- CommonVoiceドイツ語データセットでの低い単語誤り率(WER)と文字誤り率(CER)。
📚 ドキュメント
概要
このフォルダには、新しいwav2vec 2.0 XLS - R 1B TEVRアーキテクチャを用いた音響モデルと5 - gram KenLM言語モデルから構成される、完全に学習されたドイツ語音声認識パイプラインが含まれています。TEVRの強化点とその動機についての説明は、以下の論文を参照してください。
TEVR: Improving Speech Recognition by Token Entropy Variance Reduction
このパイプラインは、CommonVoiceドイツ語データセットで(2022年6月時点で)非常に競争力のある**単語誤り率3.64%**を達成しています。文字誤り率は1.54%でした。
引用
この音声認識パイプラインを研究に使用する場合は、以下を引用してください。
@misc{https://doi.org/10.48550/arxiv.2206.12693,
doi = {10.48550/ARXIV.2206.12693},
url = {https://arxiv.org/abs/2206.12693},
author = {Krabbenhöft, Hajo Nils and Barth, Erhardt},
keywords = {Computation and Language (cs.CL), Sound (cs.SD), Audio and Speech Processing (eess.AS), FOS: Computer and information sciences, FOS: Computer and information sciences, FOS: Electrical engineering, electronic engineering, information engineering, FOS: Electrical engineering, electronic engineering, information engineering, F.2.1; I.2.6; I.2.7},
title = {TEVR: Improving Speech Recognition by Token Entropy Variance Reduction},
publisher = {arXiv},
year = {2022},
copyright = {Creative Commons Attribution 4.0 International}
}
TEVRトークナイザの作成とテスト
以下のリンクを参照してください。
評価
このパイプラインを自分で評価する場合や、独自のデータで評価する場合は、HF Eval Script.ipynb
のJupyter Notebookを参照するか、以下のPythonスクリプトを使用してください。
💻 使用例
基本的な使用法
!pip install --quiet --root-user-action=ignore --upgrade pip
!pip install --quiet --root-user-action=ignore "datasets>=1.18.3" "transformers==4.11.3" librosa jiwer huggingface_hub
!pip install --quiet --root-user-action=ignore https://github.com/kpu/kenlm/archive/master.zip pyctcdecode
!pip install --quiet --root-user-action=ignore --upgrade transformers
!pip install --quiet --root-user-action=ignore torch_audiomentations audiomentations
高度な使用法
from datasets import load_dataset, Audio, load_metric
from transformers import AutoModelForCTC, Wav2Vec2ProcessorWithLM
import torchaudio.transforms as T
import torch
import unicodedata
import numpy as np
import re
testing_dataset = load_dataset("common_voice", "de", split="test")
allchars = list(set([c for t in testing_dataset['sentence'] for c in list(t)]))
map_to_space = [c for c in allchars if unicodedata.category(c)[0] in 'PSZ' and c not in 'ʻ-']
replacements = ''.maketrans(''.join(map_to_space), ''.join(' ' for i in range(len(map_to_space))), '\'ʻ')
def text_fix(text):
text = text.replace('ß','ss')
text = text.replace('-',' ').replace(' ',' ').replace(' ',' ')
text = text.lower()
text = text.translate(replacements).strip()
text = re.sub("[âşěýňעảנźțãòàǔł̇æồאắîשðșęūāñë生בøúıśžçćńřğ]+","?",text)
text = ' '.join([w for w in text.split(' ') if w != ''])
return text
model = AutoModelForCTC.from_pretrained("fxtentacle/wav2vec2-xls-r-1b-tevr")
model.to('cuda')
class HajoProcessor(Wav2Vec2ProcessorWithLM):
@staticmethod
def get_missing_alphabet_tokens(decoder, tokenizer):
return []
processor = HajoProcessor.from_pretrained("fxtentacle/wav2vec2-xls-r-1b-tevr")
def predict_single_audio(batch, image=False):
audio = batch['audio']['array']
if batch['audio']['sampling_rate'] != 16000:
audio = T.Resample(orig_freq=batch['audio']['sampling_rate'], new_freq=16000)(torch.from_numpy(audio)).numpy()
audio = (audio - audio.mean()) / np.sqrt(audio.var() + 1e-7)
input_values = processor(audio, return_tensors="pt", sampling_rate=16_000).input_values
with torch.no_grad():
logits = model(input_values.to('cuda')).logits.cpu().numpy()[0]
decoded = processor.decode(logits, beam_width=500)
return { 'groundtruth': text_fix(batch['sentence']), 'prediction': decoded.text }
all_predictions = testing_dataset.map(predict_single_audio, remove_columns=testing_dataset.column_names)
print('WER', load_metric("wer").compute(predictions=all_predictions['prediction'], references=all_predictions['groundtruth'])*100.0, '%')
print('CER', load_metric("cer").compute(predictions=all_predictions['prediction'], references=all_predictions['groundtruth'])*100.0, '%')
WER 3.6433399042523233 %
CER 1.5398893560981173 %
📄 ライセンス
このモデルはApache - 2.0ライセンスの下で提供されています。