模型概述
模型特點
模型能力
使用案例
🚀 Airavata
Airavata 是一個 70 億參數的語言模型,它基於 OpenHathi 模型,並在 IndicInstruct 數據集 上進行了微調。該數據集集合了多種指令數據集,包括 Anudesh、wikiHow、Flan v2、Dolly、Anthropic-HHH、OpenAssistant v1 和 LymSys-Chat。此模型的相關研究在技術報告 Airavata: Introducing Hindi Instruction-tuned LLM 中有詳細介紹,訓練和評估該模型的代碼庫可在 https://github.com/AI4Bharat/IndicInstruct 找到。
🚀 快速開始
環境準備
克隆 https://github.com/AI4Bharat/IndicInstruct 代碼庫,並安裝所需的依賴項。然後將本模型下載或克隆到同一臺機器上。
輸入格式
該模型採用類似於 open-instruct 代碼庫 的對話格式進行訓練(注意換行符):
<|user|>
您的消息內容!
<|assistant|>
為獲得最佳效果,請按照此格式格式化所有輸入。請確保在 <|assistant|>
後包含換行符,這對生成質量有較大影響。
✨ 主要特性
- 多語言支持:支持英語(en)和印地語(hi)等多種語言。
- 指令微調:在 IndicInstruct 數據集上進行了微調,能夠更好地理解和響應指令。
- 基於 LoRA 微調:使用 LoRA 技術對 OpenHathi 基礎模型進行微調,提高了模型的訓練效率和性能。
📦 安裝指南
克隆代碼庫並安裝依賴:
git clone https://github.com/AI4Bharat/IndicInstruct
cd IndicInstruct
pip install -r requirements.txt
下載或克隆模型到本地:
git clone https://huggingface.co/ai4bharat/Airavata
💻 使用示例
基礎用法
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
device = "cuda" if torch.cuda.is_available() else "cpu"
def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
formatted_text = ""
for message in messages:
if message["role"] == "system":
formatted_text += "<|system|>\n" + message["content"] + "\n"
elif message["role"] == "user":
formatted_text += "<|user|>\n" + message["content"] + "\n"
elif message["role"] == "assistant":
formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
else:
raise ValueError(
"Tulu chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
message["role"]
)
)
formatted_text += "<|assistant|>\n"
formatted_text = bos + formatted_text if add_bos else formatted_text
return formatted_text
def inference(input_prompts, model, tokenizer):
input_prompts = [
create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False)
for input_prompt in input_prompts
]
encodings = tokenizer(input_prompts, padding=True, return_tensors="pt")
encodings = encodings.to(device)
with torch.inference_mode():
outputs = model.generate(encodings.input_ids, do_sample=False, max_new_tokens=250)
output_texts = tokenizer.batch_decode(outputs.detach(), skip_special_tokens=True)
input_prompts = [
tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True) for input_prompt in input_prompts
]
output_texts = [output_text[len(input_prompt) :] for input_prompt, output_text in zip(input_prompts, output_texts)]
return output_texts
model_name = "ai4bharat/Airavata"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
input_prompts = [
"मैं अपने समय प्रबंधन कौशल को कैसे सुधार सकता हूँ? मुझे पांच बिंदु बताएं।",
"मैं अपने समय प्रबंधन कौशल को कैसे सुधार सकता हूँ? मुझे पांच बिंदु बताएं और उनका वर्णन करें।",
]
outputs = inference(input_prompts, model, tokenizer)
print(outputs)
📚 詳細文檔
超參數設置
我們使用 LoRA 技術在上述 IndicInstruct 數據集上對 OpenHathi 基礎模型進行微調,LoRA 微調的超參數如下:
參數 | 值 |
---|---|
LoRA Rank | 16 |
LoRA alpha | 32 |
LoRA Dropout | 0.05 |
LoRA Target Modules | ["q_proj", "v_proj", "k_proj", "down_proj", "gate_proj", "up_proj"] |
Epochs | 4 |
Learning rate | 5e-4 |
Batch Size | 128 |
Floating Point Precision | bfloat16 |
更多關於模型訓練、消融實驗和評估結果的詳細信息,請參考 我們的官方博客文章。
引用信息
@article{gala2024airavata,
title = {Airavata: Introducing Hindi Instruction-tuned LLM},
author = {Jay Gala and Thanmay Jayakumar and Jaavid Aktar Husain and Aswanth Kumar M and Mohammed Safi Ur Rahman Khan and Diptesh Kanojia and Ratish Puduppully and Mitesh M. Khapra and Raj Dabre and Rudra Murthy and Anoop Kunchukuttan},
year = {2024},
journal = {arXiv preprint arXiv: 2401.15006}
}
評估結果
詳細的評估結果可在 Open LLM Leaderboard 中查看,具體結果如下:
指標 | 值 |
---|---|
平均 | 45.52 |
AI2 Reasoning Challenge (25-Shot) | 46.50 |
HellaSwag (10-Shot) | 69.26 |
MMLU (5-Shot) | 43.90 |
TruthfulQA (0-shot) | 40.62 |
Winogrande (5-shot) | 68.82 |
GSM8k (5-shot) | 4.02 |
🔧 技術細節
該模型基於 OpenHathi 基礎模型,使用 LoRA 技術在 IndicInstruct 數據集上進行微調。LoRA 技術通過引入低秩自適應矩陣,在不改變原模型參數的情況下,對模型進行高效微調,從而減少了訓練所需的計算資源和時間。在微調過程中,我們使用了特定的超參數設置,以確保模型能夠學習到數據集中的指令信息,提高模型的指令跟隨能力。
📄 許可證
本模型使用 Llama 2 許可證。



