模型简介
模型特点
模型能力
使用案例
🚀 Airavata
Airavata 是一个 70 亿参数的语言模型,它基于 OpenHathi 模型,并在 IndicInstruct 数据集 上进行了微调。该数据集集合了多种指令数据集,包括 Anudesh、wikiHow、Flan v2、Dolly、Anthropic-HHH、OpenAssistant v1 和 LymSys-Chat。此模型的相关研究在技术报告 Airavata: Introducing Hindi Instruction-tuned LLM 中有详细介绍,训练和评估该模型的代码库可在 https://github.com/AI4Bharat/IndicInstruct 找到。
🚀 快速开始
环境准备
克隆 https://github.com/AI4Bharat/IndicInstruct 代码库,并安装所需的依赖项。然后将本模型下载或克隆到同一台机器上。
输入格式
该模型采用类似于 open-instruct 代码库 的对话格式进行训练(注意换行符):
<|user|>
您的消息内容!
<|assistant|>
为获得最佳效果,请按照此格式格式化所有输入。请确保在 <|assistant|>
后包含换行符,这对生成质量有较大影响。
✨ 主要特性
- 多语言支持:支持英语(en)和印地语(hi)等多种语言。
- 指令微调:在 IndicInstruct 数据集上进行了微调,能够更好地理解和响应指令。
- 基于 LoRA 微调:使用 LoRA 技术对 OpenHathi 基础模型进行微调,提高了模型的训练效率和性能。
📦 安装指南
克隆代码库并安装依赖:
git clone https://github.com/AI4Bharat/IndicInstruct
cd IndicInstruct
pip install -r requirements.txt
下载或克隆模型到本地:
git clone https://huggingface.co/ai4bharat/Airavata
💻 使用示例
基础用法
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
device = "cuda" if torch.cuda.is_available() else "cpu"
def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
formatted_text = ""
for message in messages:
if message["role"] == "system":
formatted_text += "<|system|>\n" + message["content"] + "\n"
elif message["role"] == "user":
formatted_text += "<|user|>\n" + message["content"] + "\n"
elif message["role"] == "assistant":
formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
else:
raise ValueError(
"Tulu chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
message["role"]
)
)
formatted_text += "<|assistant|>\n"
formatted_text = bos + formatted_text if add_bos else formatted_text
return formatted_text
def inference(input_prompts, model, tokenizer):
input_prompts = [
create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False)
for input_prompt in input_prompts
]
encodings = tokenizer(input_prompts, padding=True, return_tensors="pt")
encodings = encodings.to(device)
with torch.inference_mode():
outputs = model.generate(encodings.input_ids, do_sample=False, max_new_tokens=250)
output_texts = tokenizer.batch_decode(outputs.detach(), skip_special_tokens=True)
input_prompts = [
tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True) for input_prompt in input_prompts
]
output_texts = [output_text[len(input_prompt) :] for input_prompt, output_text in zip(input_prompts, output_texts)]
return output_texts
model_name = "ai4bharat/Airavata"
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
input_prompts = [
"मैं अपने समय प्रबंधन कौशल को कैसे सुधार सकता हूँ? मुझे पांच बिंदु बताएं।",
"मैं अपने समय प्रबंधन कौशल को कैसे सुधार सकता हूँ? मुझे पांच बिंदु बताएं और उनका वर्णन करें।",
]
outputs = inference(input_prompts, model, tokenizer)
print(outputs)
📚 详细文档
超参数设置
我们使用 LoRA 技术在上述 IndicInstruct 数据集上对 OpenHathi 基础模型进行微调,LoRA 微调的超参数如下:
参数 | 值 |
---|---|
LoRA Rank | 16 |
LoRA alpha | 32 |
LoRA Dropout | 0.05 |
LoRA Target Modules | ["q_proj", "v_proj", "k_proj", "down_proj", "gate_proj", "up_proj"] |
Epochs | 4 |
Learning rate | 5e-4 |
Batch Size | 128 |
Floating Point Precision | bfloat16 |
更多关于模型训练、消融实验和评估结果的详细信息,请参考 我们的官方博客文章。
引用信息
@article{gala2024airavata,
title = {Airavata: Introducing Hindi Instruction-tuned LLM},
author = {Jay Gala and Thanmay Jayakumar and Jaavid Aktar Husain and Aswanth Kumar M and Mohammed Safi Ur Rahman Khan and Diptesh Kanojia and Ratish Puduppully and Mitesh M. Khapra and Raj Dabre and Rudra Murthy and Anoop Kunchukuttan},
year = {2024},
journal = {arXiv preprint arXiv: 2401.15006}
}
评估结果
详细的评估结果可在 Open LLM Leaderboard 中查看,具体结果如下:
指标 | 值 |
---|---|
平均 | 45.52 |
AI2 Reasoning Challenge (25-Shot) | 46.50 |
HellaSwag (10-Shot) | 69.26 |
MMLU (5-Shot) | 43.90 |
TruthfulQA (0-shot) | 40.62 |
Winogrande (5-shot) | 68.82 |
GSM8k (5-shot) | 4.02 |
🔧 技术细节
该模型基于 OpenHathi 基础模型,使用 LoRA 技术在 IndicInstruct 数据集上进行微调。LoRA 技术通过引入低秩自适应矩阵,在不改变原模型参数的情况下,对模型进行高效微调,从而减少了训练所需的计算资源和时间。在微调过程中,我们使用了特定的超参数设置,以确保模型能够学习到数据集中的指令信息,提高模型的指令跟随能力。
📄 许可证
本模型使用 Llama 2 许可证。



