🚀 Airavata
このモデルは、7BのOpenHathiモデルをIndicInstructデータセットでファインチューニングしたものです。IndicInstructデータセットは、命令データセット(Anudesh、wikiHow、Flan v2、Dolly、Anthropic - HHH、OpenAssistant v1、およびLymSys - Chat)のコレクションです。詳細については、対応するHugging Faceデータセットカードを確認してください。
このモデルは、技術レポート《Airavata: Introducing Hindi Instruction - tuned LLM》の一環としてトレーニングされました。このモデルをトレーニングおよび評価するために使用されたコードベースは、https://github.com/AI4Bharat/IndicInstructで見つけることができます。
🚀 クイックスタート
https://github.com/AI4Bharat/IndicInstructをクローンし、必要な依存関係をインストールします。次に、このモデルを同じマシンにダウンロードまたはクローンします。
✨ 主な機能
- 複数言語対応(英語、ヒンディー語など)
- 命令チューニングによる高性能なテキスト生成
💻 使用例
基本的な使用法
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
device = "cuda" if torch.cuda.is_available() else "cpu"
def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
formatted_text = ""
for message in messages:
if message["role"] == "system":
formatted_text += "<|system|>\n" + message["content"] + "\n"
elif message["role"] == "user":
formatted_text += "<|user|>\n" + message["content"] + "\n"
elif message["role"] == "assistant":
formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
else:
raise ValueError(
"Tulu chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
message["role"]
)
)
formatted_text += "<|assistant|>\n"
formatted_text = bos + formatted_text if add_bos else formatted_text
return formatted_text
def inference(input_prompts, model, tokenizer):
input_prompts = [
create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False)
for input_prompt in input_prompts
]
encodings = tokenizer(input_prompts, padding=True, return_tensors="pt")
encodings = encodings.to(device)
with torch.inference_mode():
outputs = model.generate(encodings.input_ids, do_sample=False, max_new_tokens=250)
output_texts = tokenizer.batch_decode(outputs.detach(), skip_special_tokens=True)
input_prompts = [
tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True) for input_prompt in input_prompts
]
output_texts = [output_text[len(input_prompt) :] for input_prompt, output_text in zip(input_prompts, output_texts)]
return output_texts
model_name = "ai4bharat/Airavata"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
input_prompts = [
"मैं अपने समय प्रबंधन कौशल को कैसे सुधार सकता हूँ? मुझे पांच बिंदु बताएं।",
"मैं अपने समय प्रबंधन कौशल को कैसे सुधार सकता हूँ? मुझे पांच बिंदु बताएं और उनका वर्णन करें।",
]
outputs = inference(input_prompts, model, tokenizer)
print(outputs)
📚 ドキュメント
入力形式
このモデルは、[open - instructコードリポジトリ](https://github.com/allenai/open - instruct)と同様のチャット形式(改行に注意)を使用するようにトレーニングされています。
<|user|>
Your message here!
<|assistant|>
最良の結果を得るためには、すべての入力をこの形式でフォーマットしてください。<|assistant|>
の後に改行を含めることを確認してください。これは生成品質にかなり影響を与える可能性があります。
ハイパーパラメータ
私たちは、前述のIndicInstructデータセットでLoRAを使用してOpenHathiベースモデルをファインチューニングしました。LoRAファインチューニングのハイパーパラメータは以下の通りです。
- LoRA Rank: 16
- LoRA alpha: 32
- LoRA Dropout: 0.05
- LoRA Target Modules: ["q_proj", "v_proj", "k_proj", "down_proj", "gate_proj", "up_proj"]
- Epochs: 4
- Learning rate: 5e - 4
- Batch Size: 128
- Floating Point Precision: bfloat16
モデルのトレーニング、アブレーション、および評価結果の詳細については、公式ブログ記事をご覧ください。
引用
@article{gala2024airavata,
title = {Airavata: Introducing Hindi Instruction-tuned LLM},
author = {Jay Gala and Thanmay Jayakumar and Jaavid Aktar Husain and Aswanth Kumar M and Mohammed Safi Ur Rahman Khan and Diptesh Kanojia and Ratish Puduppully and Mitesh M. Khapra and Raj Dabre and Rudra Murthy and Anoop Kunchukuttan},
year = {2024},
journal = {arXiv preprint arXiv: 2401.15006}
}
詳細な結果はこちらで確認できます。
メトリクス |
値 |
平均 |
45.52 |
AI2 Reasoning Challenge (25 - Shot) |
46.50 |
HellaSwag (10 - Shot) |
69.26 |
MMLU (5 - Shot) |
43.90 |
TruthfulQA (0 - shot) |
40.62 |
Winogrande (5 - shot) |
68.82 |
GSM8k (5 - shot) |
4.02 |
📄 ライセンス
このモデルは、llama2ライセンスの下で提供されています。