模型简介
模型特点
模型能力
使用案例
🚀 TinyLLaVA:小规模大多模态模型框架
TinyLLaVA是一个专注于小规模大多模态模型的框架,能够在减少参数的情况下实现高性能表现,为多模态模型的研究和应用提供了新的思路和方法。
🚀 快速开始
环境准备
我们推荐按照以下要求准备环境:
- 克隆仓库并进入LLaVA文件夹
git clone https://github.com/DLCV-BUAA/TinyLLaVABench.git
cd TinyLLaVABench
- 安装基础包
conda create -n tinyllava python=3.10 -y
conda activate tinyllava
pip install --upgrade pip # 启用PEP 660支持
pip install -e .
- 为训练场景安装额外的包
pip install -e ".[train]"
pip install flash-attn --no-build-isolation
代码更新
git pull
pip install -e .
# 如果你在更新时遇到导入错误,请尝试运行以下命令(去掉#)
# pip install flash-attn --no-build-isolation --no-cache-dir
✨ 主要特性
⚡ 高性能与轻量级参数
我们表现最优的模型TinyLLaVA - 3.1B,相较于现有的7B模型(如LLaVA - 1.5和Qwen - VL),在整体性能上表现更优。
📈 多数据集支持
在论文中,使用了[LLaVA数据集](https://github.com/haotian - liu/LLaVA?tab=readme - ov - file#pretrain - feature - alignment)和[ShareGPT4V数据集](https://github.com/InternLM/InternLM - XComposer/blob/main/projects/ShareGPT4V/docs/Data.md),并对比了它们的差异。
🎨 多模态支持
模型支持多图像和多提示生成,使用时需遵循正确的提示模板 (USER: <image>xxx\nASSISTANT:
),其中 <image>
是图像嵌入的占位特殊令牌。
📦 安装指南
克隆仓库
git clone https://github.com/DLCV-BUAA/TinyLLaVABench.git
cd TinyLLaVABench
创建并激活虚拟环境
conda create -n tinyllava python=3.10 -y
conda activate tinyllava
安装必要的包
pip install --upgrade pip
pip install -e .
安装训练所需的额外包
pip install -e ".[train]"
pip install flash-attn --no-build-isolation
💻 使用示例
基础用法
Gradio Web演示
运行以下命令启动本地Web演示:
python tinyllava/serve/app.py --model-path bczhou/TinyLLaVA-3.1B --model-name TinyLLaVA-3.1B
CLI推理
我们也支持使用CLI进行推理。运行以下命令使用我们的模型:
python -m tinyllava.serve.cli \
--model-path bczhou/TinyLLaVA-3.1B \
--image-file "./tinyllava/serve/examples/extreme_ironing.jpg"
加载模型
from tinyllava.model.builder import load_pretrained_model
from tinyllava.mm_utils import get_model_name_from_path
from tinyllava.eval.run_tiny_llava import eval_model
model_path = "bczhou/TinyLLaVA-3.1B"
tokenizer, model, image_processor, context_len = load_pretrained_model(
model_path=model_path,
model_base=None,
model_name=get_model_name_from_path(model_path)
)
高级用法
运行推理
以下是使用 TinyLLaVA - 3.1B 进行推理的示例:
from tinyllava.model.builder import load_pretrained_model
from tinyllava.mm_utils import get_model_name_from_path
from tinyllava.eval.run_tiny_llava import eval_model
model_path = "bczhou/TinyLLaVA-3.1B"
prompt = "What are the things I should be cautious about when I visit here?"
image_file = "https://llava-vl.github.io/static/images/view.jpg"
args = type('Args', (), {
"model_path": model_path,
"model_base": None,
"model_name": get_model_name_from_path(model_path),
"query": prompt,
"conv_mode": "phi",
"image_file": image_file,
"sep": ",",
"temperature": 0,
"top_p": None,
"num_beams": 1,
"max_new_tokens": 512
})()
eval_model(args)
重要提示
不同的模型使用不同的 conv_mode
。请根据以下表格替换 args
中的 conv_mode
:
模型 | conv_mode |
---|---|
TinyLLaVA - 3.1B | phi |
TinyLLaVA - 2.0B | phi |
TinyLLaVA - 1.5B | v1 |
使用 pipeline
进行推理
from transformers import pipeline
from PIL import Image
import requests
model_id = "bczhou/tiny-llava-v1-hf"
pipe = pipeline("image-to-text", model=model_id)
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "USER: <image>\nWhat does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud\nASSISTANT:"
outputs = pipe(image, prompt=prompt, generate_kwargs={"max_new_tokens": 200})
print(outputs[0])
使用纯 transformers
进行推理
import requests
from PIL import Image
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
model_id = "bczhou/tiny-llava-v1-hf"
prompt = "USER: <image>\nWhat are these?\nASSISTANT:"
image_file = "http://images.cocodataset.org/val2017/000000039769.jpg"
model = LlavaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
).to(0)
processor = AutoProcessor.from_pretrained(model_id)
raw_image = Image.open(requests.get(image_file, stream=True).raw)
inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)
output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
print(processor.decode(output[0][2:], skip_special_tokens=True))
📚 详细文档
模型库
旧版模型
预训练模型
模型详情
名称 | 大语言模型 | 检查点 | LLaVA - Bench - Wild | MME | MMBench | MM - Vet | SQA - image | VQA - v2 | GQA | TextVQA |
---|---|---|---|---|---|---|---|---|---|---|
TinyLLaVA - 3.1B | Phi - 2 | TinyLLaVA - 3.1B | 75.8 | 1464.9 | 66.9 | 32.0 | 69.1 | 79.9 | 62.0 | 59.1 |
TinyLLaVA - 2.0B | StableLM - 2 - 1.6B | TinyLLaVA - 2.0B | 66.4 | 1433.8 | 63.3 | 32.6 | 64.7 | 78.9 | 61.9 | 56.4 |
TinyLLaVA - 1.5B | TinyLlama | TinyLLaVA - 1.5B | 60.8 | 1276.5 | 55.2 | 25.8 | 60.3 | 76.9 | 60.3 | 51.7 |
评估
为确保可重复性,我们使用贪心解码对模型进行评估。详情见 Evaluation.md。
数据准备
预训练图像
- LLaVA:预训练图像来自LAION - CC - SBU数据集的558K子集。
- ShareGPT4V:预训练图像是558K LAION - CC - SBU子集、SAM数据集和COCO数据集的混合。
预训练注释
SFT图像和注释
两个SFT数据集的大部分数据相同,不同之处在于LLaVA - 1.5 - SFT中的23K详细描述数据被从 100K ShareGPT4V数据 中随机采样的详细字幕替换。
数据下载
- 下载相关图像
- LAION - CC - SBU - 558K:images.zip
- COCO:该数据集来自 COCO2017挑战。下载:train2017
- WebData:该数据集由 ShareGPT4V项目 整理。下载:[images](https://drive.google.com/drive/folders/1tCUQ - sq6vdshZVkF0ZeF3K4eztkXJgax?usp=sharing)。仅用于学术用途。
- SAM:该数据集由 Meta 收集。下载:images。目前我们仅使用000000 ~ 000050.tar。如果你只想将ShareGPT4V用于SFT,可以从 此处 快速下载9K图像。
- GQA:GQA项目页面。下载:images
- OCR - VQA:[OCR - VQA项目页面](https://ocr - vqa.github.io/)。下载:下载脚本。我们将所有文件保存为
.jpg
- TextVQA:TextVQA项目页面。下载:trainvalimages
- VisualGenome:VisualGenome项目页面。下载:part1,part2
- 下载相关注释
- LLaVA的预训练注释:blip_laion_cc_sbu_558k.json
- LLaVA的SFT注释:llava_v1_5_mix665k.json
- ShareGPT4V的预训练注释:share - captioner_coco_lcs_sam_1246k_1107.json
- ShareGPT4V的SFT注释:sharegpt4v_mix665k_cap23k_coco - ap9k_lcs3k_sam9k_div2k.json
数据组织
在 path/to/your/data
中按以下方式组织图像文件和注释文件:
data
├── llava
│ ├── llava_pretrain
│ │ ├── images
│ │ ├── blip_laion_cc_sbu_558k.json
├── coco
│ ├── train2017
├── sam
│ ├── images
├── gqa
│ ├── images
├── ocr_vqa
│ ├── images
├── textvqa
│ ├── train_images
├── vg
│ ├── VG_100K
│ ├── VG_100K_2
├── share_textvqa
│ ├── images
├── web-celebrity
│ ├── images
├── web-landmark
│ ├── images
├── wikiart
│ ├── images
├── text_files
│ ├── llava_v1_5_mix665k.json
│ ├── share-captioner_coco_lcs_sam_1246k_1107.json
│ ├── sharegpt4v_mix665k_cap23k_coco-ap9k_lcs3k_sam9k_div2k.json
训练
超参数
以下是预训练和微调中使用的超参数:
- 预训练 | 超参数 | 全局批量大小 | 学习率 | 轮数 | 最大长度 | 权重衰减 | |----------------| ---: | ---: | ---: |-----------:| ---: | | TinyLLaVA - 3.1B | 256 | 1e - 3 | 1 | 3072 | 0 |
- 微调 | 超参数 | 全局批量大小 | 学习率 | 轮数 | 最大长度 | 权重衰减 | |----------------| ---: | ---: | ---: |-----------:| ---: | | TinyLLaVA - 3.1B | 128 | 2e - 5 | 1 | 3072 | 0 |
预训练
使用DeepSpeed ZeRO - 2的训练脚本:pretrain.sh
。请将路径替换为你自己的路径。
微调
使用DeepSpeed ZeRO - 3的训练脚本:finetune.sh
。请将路径替换为你自己的路径。
自定义微调
查看我们使用LoRA进行的自定义微调 此处。
🔧 技术细节
模型架构
TinyLLaVA基于不同的大语言模型(如Phi - 2、StableLM - 2 - 1.6B、TinyLlama)构建,通过特定的训练方法和数据进行优化,以实现多模态处理能力。
训练方法
使用了预训练和微调两个阶段。预训练阶段在大规模数据集上进行特征对齐,微调阶段则根据具体任务进行进一步优化。
推理机制
推理时,模型根据输入的图像和文本提示,通过特定的解码方式生成输出结果。不同的模型使用不同的 conv_mode
进行推理。
📄 许可证
本项目采用Apache 2.0许可证。
📢 最新消息
- [2024.03.10] 基础配方发布!
- [2024.03.10] 微调脚本发布!
- [2024.02.25] 更新评估脚本和文档!
- [2024.02.25] 数据描述发布。发布TinyLLaVA - 1.5B和TinyLLaVA - 2.0B!
- [2024.02.24] 添加推理和模型加载的示例代码!
- [2024.02.23] 评估代码和脚本发布!
- [2024.02.21] 在GitHub上创建 TinyLLaVABench 仓库!
- [2024.02.21] 我们的论文:TinyLLaVA: A Framework of Small - scale Large Multimodal Models 发布!
- [2024.01.11] 我们的第一个模型 TinyLLaVA - 1.4B 发布!
📋 待办事项
- [ ] 添加对Ollama和llama.cpp的支持。
- [x] 开发者指南/如何在本地构建演示。
- [x] 训练和自定义微调文档。
- [x] 模型库描述。
- [x] 示例和推理。
- [x] 发布训练代码。
- [x] 添加评估描述。
- [x] 添加数据准备描述。
- [x] 发布TinyLLaVA - 1.5B和TinyLLaVA - 2.0B。
- [x] 发布TinyLLaVA - 3.1B。
- [x] 今天(2024.2.23)发布评估代码和权重。
📖 引用
如果您觉得我们的论文和代码对您的研究有帮助,请考虑给个星 ⭐ 并引用:
@misc{zhou2024tinyllava,
title={TinyLLaVA: A Framework of Small-scale Large Multimodal Models},
author={Baichuan Zhou and Ying Hu and Xi Weng and Junlong Jia and Jie Luo and Xien Liu and Ji Wu and Lei Huang},
year={2024},
eprint={2402.14289},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
👏 社区贡献
- 我们的代码基于 [LLaVA](https://github.com/haotian - liu/LLaVA) 项目构建,感谢他们的出色工作!
- 我们的项目使用了 [ShareGPT4V](https://github.com/InternLM/InternLM - XComposer/tree/main/projects/ShareGPT4V) 项目的数据,感谢他们的贡献!








