🚀 R1 - AQA --- 強化学習が教師あり微調整を上回る:音声質問応答のケーススタディ
R1 - AQAは、Qwen2 - Audio - 7B - Instruct
をベースとした音声質問応答(AQA)モデルです。グループ相対ポリシー最適化(GRPO)アルゴリズムを用いた強化学習によって最適化されています。この実装では、わずか38kの事後トレーニングサンプルで、MMAUベンチマークにおいて最先端の性能を達成しています。詳細については、Githubと技術レポートを参照してください。
🚀 クイックスタート
R1 - AQAは、音声質問応答タスクにおいて、強化学習によって最適化されたモデルです。このモデルは、Qwen2 - Audio - 7B - Instruct
をベースに、GRPOアルゴリズムを用いて学習されています。
✨ 主な機能
- GRPOアルゴリズムを音声モダリティに直接適用でき、8.2Bパラメータの
Qwen2 - Audio - 7B - Instruct
にも有効です。
- わずか38kの事後トレーニングサンプルで、強化学習が教師あり微調整を上回り、大規模なデータセットがなくてもRLベースのアプローチが有効であることを示しています。
- 明示的な推論プロセスはAQAタスクに大きな恩恵をもたらさず、「深い思考」や段階的な推論を効率的に活用する方法は今後の研究課題です。
- 大規模音声言語モデル(LALMs)は依然として人間の聴覚 - 言語推論能力に大きく劣っており、RLベースのアプローチのさらなる探索が必要です。
📚 ドキュメント
導入
R1 - AQAは、Qwen2 - Audio - 7B - Instruct
をベースとした音声質問応答(AQA)モデルで、グループ相対ポリシー最適化(GRPO)アルゴリズムを用いた強化学習によって最適化されています。この実装では、わずか38kの事後トレーニングサンプルで、MMAUベンチマークにおいて最先端の性能を達成しています。
主な発見は以下の通りです:
- GRPOアルゴリズムは、音声モダリティに直接かつ効果的に適用でき、8.2Bパラメータの
Qwen2 - Audio - 7B - Instruct
にも有効です。
- わずか38kの事後トレーニングサンプルで、強化学習が教師あり微調整を上回り、大規模なデータセットがなくてもRLベースのアプローチが有効であることを示しています。
- 明示的な推論プロセスはAQAタスクに大きな恩恵をもたらさず、「深い思考」や段階的な推論を効率的に活用する方法は今後の研究課題です。
- 大規模音声言語モデル(LALMs)は依然として人間の聴覚 - 言語推論能力に大きく劣っており、RLベースのアプローチのさらなる探索が必要です。
追加の注意事項:
- AVQAトレーニングセットはもともと約40kのサンプルで構成されていましたが、一部のデータソースが無効になったため、約38kのサンプルのみを使用しています。YouTubeソースを使用する他のデータセット(AudioSetなど)も同様の問題に直面しています。欠落した2kのサンプルはトレーニング結果に大きな影響を与えないと考えています。
- 8.2Bパラメータに関する記述は、Qwen2 - Audio Technical Reportに基づいています。
表:MMAUベンチマークにおける精度(%)
モデル |
方法 |
Test - mini |
Test |
Test - mini |
Test |
Test - mini |
Test |
Test - mini |
Test |
- |
人間* |
86.31 |
- |
78.22 |
- |
82.17 |
- |
82.23 |
- |
Gemini Pro 2.0 Flash |
直接推論* |
56.46 |
61.73 |
58.68 |
56.53 |
51.65 |
61.53 |
55.60 |
59.93 |
Audio Flamingo 2 |
直接推論* |
61.56 |
65.10 |
73.95 |
72.90 |
30.93 |
40.26 |
55.48 |
59.42 |
GPT4o + 強力なキャプション |
直接推論* |
57.35 |
55.83 |
49.70 |
51.73 |
64.86 |
68.66 |
57.30 |
58.74 |
Llama - 3 - 8B - Instruct + 強力なキャプション |
直接推論* |
50.75 |
49.10 |
48.93 |
48.93 |
55.25 |
62.70 |
52.10 |
53.57 |
Qwen2 - Audio - 7B - Instruct |
直接推論* |
54.95 |
45.90 |
50.98 |
53.26 |
42.04 |
45.90 |
49.20 |
52.50 |
SALAMONN |
直接推論* |
41.00 |
40.30 |
34.80 |
33.76 |
25.50 |
24.24 |
33.70 |
32.77 |
Qwen2 - Audio - 7B - Instruct |
CoTA [1] |
60.06 |
- |
64.30 |
- |
60.70 |
- |
61.71 |
- |
Qwen2 - Audio - 7B - Instruct |
Zero - Shot - CoT [2] |
61.86 |
- |
56.29 |
- |
55.26 |
- |
57.80 |
- |
Qwen2 - Audio - 7B - Instruct |
GRPO (Ours) 1️⃣ |
69.37 |
- |
66.77 |
- |
57.36 |
- |
64.50 |
- |
Qwen2 - Audio - 7B - Instruct |
GRPO (Ours) 2️⃣ |
68.77 |
69.76 |
64.37 |
61.40 |
63.66 |
62.70 |
65.60 |
64.36 |
注釈
* データはMMAUリーダーボードから取得しています。
[1] Xie, Zhifei, et al. "Audio - Reasoner: Improving Reasoning Capability in Large Audio Language Models." arXiv preprint arXiv:2503.02318 (2025).
[2] Ma, Ziyang, et al. "Audio - CoT: Exploring Chain - of - Thought Reasoning in Large Audio Language Model." arXiv preprint arXiv:2501.07246 (2025).
1️⃣ これは元のモデルで、Hugging Face上のものと同じで、技術レポートに記載されています。
2️⃣ これはMMAUリーダーボードに提出されたモデルで、バランスの取れた結果を得るために複数回トレーニングされています。
💻 使用例
基本的な使用法
import torch
import torchaudio
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor
model_name = "mispeech/r1-aqa"
processor = AutoProcessor.from_pretrained(model_name)
model = Qwen2AudioForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
wav_path = "test-mini-audios/3fe64f3d-282c-4bc8-a753-68f8f6c35652.wav"
waveform, sampling_rate = torchaudio.load(wav_path)
if sampling_rate != 16000:
waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)(waveform)
audios = [waveform[0].numpy()]
question = "Based on the given audio, identify the source of the speaking voice."
options = ["Man", "Woman", "Child", "Robot"]
prompt = f"{question} Please choose the answer from the following options: {str(options)}. Output the final answer in <answer> </answer>."
message = [
{"role": "user", "content": [
{"type": "audio", "audio_url": wav_path},
{"type": "text", "text": prompt}
]}
]
texts = processor.apply_chat_template(message, add_generation_prompt=True, tokenize=False)
inputs = processor(text=texts, audios=audios, sampling_rate=16000, return_tensors="pt", padding=True).to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=256)
generated_ids = generated_ids[:, inputs.input_ids.size(1):]
response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
print(response)
📄 ライセンス
このプロジェクトは、Apache 2.0ライセンスの下で公開されています。
引用
@article{li2025reinforcement,
title={Reinforcement Learning Outperforms Supervised Fine-Tuning: A Case Study on Audio Question Answering},
author={Li, Gang and Liu, Jizhong and Dinkel, Heinrich and Niu, Yadong and Zhang, Junbo and Luan, Jian},
journal={arXiv preprint arXiv:2503.11197},
year={2025},
url={https://github.com/xiaomi-research/r1-aqa; https://huggingface.co/mispeech/r1-aqa}
}