🚀 Wav2Vec2-Large-Tedlium
このモデルは、TEDLIUMコーパスでファインチューニングされたWav2Vec2の大規模モデルです。TEDの講演音声を高精度に文字起こしすることができます。
🚀 クイックスタート
このモデルは、LibriVoxプロジェクトの60,000時間のオーディオブックで事前学習されたFacebookのWav2Vec2 large LV - 60kチェックポイントで初期化されています。そして、TEDLIUMコーパス(リリース3)の452時間のTED講演でファインチューニングされています。モデルを使用する際には、音声入力が16Khzでサンプリングされていることを確認してください。
このモデルは、開発セットで8.4%、テストセットで8.2%の単語誤り率(WER)を達成しています。トレーニングログには、50kステップのファインチューニングにおけるトレーニングと評価の進捗が記録されています。
このモデルがどのようにファインチューニングされたかについての詳細は、このノートブックを参照してください。
✨ 主な機能
- TEDの講演音声を高精度に文字起こしできます。
- 事前学習された大規模モデルをベースに、TEDLIUMコーパスでファインチューニングされています。
💻 使用例
基本的な使用法
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from datasets import load_dataset
import torch
processor = Wav2Vec2Processor.from_pretrained("sanchit-gandhi/wav2vec2-large-tedlium")
model = Wav2Vec2ForCTC.from_pretrained("sanchit-gandhi/wav2vec2-large-tedlium")
ds = load_dataset("sanchit-gandhi/tedlium_dummy", split="validation")
input_values = processor(ds[0]["audio"]["array"], return_tensors="pt", padding="longest").input_values
logits = model(input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
print("Target: ", ds["text"][0])
print("Transcription: ", transcription[0])
高度な使用法
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch
from jiwer import wer
tedlium_eval = load_dataset("LIUM/tedlium", "release3", split="test")
model = Wav2Vec2ForCTC.from_pretrained("sanchit-gandhi/wav2vec2-large-tedlium").to("cuda")
processor = Wav2Vec2Processor.from_pretrained("sanchit-gandhi/wav2vec2-large-tedlium")
def map_to_pred(batch):
input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest").input_values
with torch.no_grad():
logits = model(input_values.to("cuda")).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
batch["transcription"] = transcription
return batch
result = tedlium_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["speech"])
print("WER:", wer(result["text"], result["transcription"]))
📄 ライセンス
このモデルは、Apache 2.0ライセンスの下で提供されています。