# 🚀 VARGPT: Unified Understanding and Generation in a Visual Autoregressive Multimodal Large Language Model
VARGPT (7B + 2B) models understanding and generation as two distinct paradigms within a unified model: next-token prediction for visual understanding, and next-scale prediction for visual generation. The model supports both multimodal understanding and generation, offering a convenient user experience.

We provide a simple generation pipeline for using the model. For more details, please refer to the GitHub repository: VARGPT-v1.
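To make the distinction between the two paradigms concrete, here is a minimal conceptual sketch (not VARGPT's actual implementation; the function names and the `(context, emitted)` bookkeeping are illustrative only) of what each autoregressive step produces under next-token versus next-scale prediction:

```python
# Conceptual sketch only, NOT VARGPT's code: what one autoregressive step
# emits under each paradigm.

def next_token_steps(seq_len):
    """Understanding: at step t, condition on tokens [0, t) and emit 1 token."""
    return [(t, 1) for t in range(seq_len)]  # (context tokens, tokens emitted)

def next_scale_steps(scales):
    """Generation: at step k, condition on all coarser token maps and emit
    the entire next-scale s x s token map at once."""
    steps, context = [], 0
    for s in scales:
        steps.append((context, s * s))  # (context tokens so far, tokens emitted)
        context += s * s
    return steps

print(next_token_steps(4))          # -> [(0, 1), (1, 1), (2, 1), (3, 1)]
print(next_scale_steps([1, 2, 4]))  # -> [(0, 1), (1, 4), (5, 16)]
```

The point of the sketch: understanding emits one token per step over a flat sequence, while generation emits a whole token map per step, coarse-to-fine.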
## 🚀 Quick Start

### Multimodal Understanding

Below is an inference demo for multimodal understanding; you can run the following code:
```python
from PIL import Image
import torch
from transformers import AutoTokenizer
from vargpt_llava.modeling_vargpt_llava import VARGPTLlavaForConditionalGeneration
from vargpt_llava.prepare_vargpt_llava import prepare_vargpt_llava
from vargpt_llava.processing_vargpt_llava import VARGPTLlavaProcessor
from patching_utils.patching import patching

model_id = "VARGPT_LLaVA-v1"
prepare_vargpt_llava(model_id)

# Load the model on GPU 0.
model = VARGPTLlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
).to(0)
patching(model)

tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = VARGPTLlavaProcessor.from_pretrained(model_id)

# A single-turn conversation with one image attached.
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please explain the meme in detail."},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
image_file = "./assets/llava_bench_demo.png"
print(prompt)

raw_image = Image.open(image_file)
inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(0, torch.float32)

output = model.generate(
    **inputs,
    max_new_tokens=2048,
    do_sample=False)
print(processor.decode(output[0], skip_special_tokens=True))
```
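Note that `model.generate` returns the prompt tokens followed by the newly generated ones, so decoding `output[0]` echoes the prompt as well. If you want only the model's response, one common pattern (a hedged variant not shown in the demo above; `response_only` is a hypothetical helper name) is to slice off the prompt length before decoding:

```python
def response_only(output_ids, prompt_len):
    """Drop the echoed prompt tokens, keeping only the generated continuation."""
    return output_ids[prompt_len:]

# Toy example with plain ints standing in for token IDs:
# a 3-token prompt followed by 2 generated tokens.
ids = [101, 102, 103, 7, 8]
print(response_only(ids, 3))  # -> [7, 8]
```

In the demo above this would be `response_only(output[0], inputs["input_ids"].shape[1])`, passed to `processor.decode` in place of `output[0]`.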
### Multimodal Generation

Below is an inference demo for text-to-image generation; you can run the following code:
```python
import torch
from transformers import AutoTokenizer
from vargpt_llava.modeling_vargpt_llava import VARGPTLlavaForConditionalGeneration
from vargpt_llava.prepare_vargpt_llava import prepare_vargpt_llava
from vargpt_llava.processing_vargpt_llava import VARGPTLlavaProcessor
from patching_utils.patching import patching

model_id = "VARGPT_LLaVA-v1"
prepare_vargpt_llava(model_id)

# Load the model on GPU 0.
model = VARGPTLlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
).to(0)
patching(model)

tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = VARGPTLlavaProcessor.from_pretrained(model_id)

# A text-only prompt; the model generates the image itself.
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please design a drawing of a butterfly on a flower."},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
print(prompt)

inputs = processor(text=prompt, return_tensors='pt').to(0, torch.float32)

# The generated image is written to this path.
model._IMAGE_GEN_PATH = "output.png"
output = model.generate(
    **inputs,
    max_new_tokens=2048,
    do_sample=False)
print(processor.decode(output[0], skip_special_tokens=True))
```
## 📄 License

This project is released under the Apache-2.0 license.
## 📦 Model Information

| Property | Details |
| --- | --- |
| Model type | VARGPT (7B + 2B) |
| Training data | VARGPT-family/VARGPT_datasets |
| Evaluation metrics | Accuracy, F1 |
| Task type | Any-to-any |
| Library name | transformers |