# 🚀 VARGPT: Unified Understanding and Generation in a Visual Autoregressive Multimodal Large Language Model
VARGPT (7B + 2B) models understanding and generation as two distinct paradigms within a unified model: next-token prediction for visual understanding and next-scale prediction for visual generation. A single checkpoint therefore handles both multimodal understanding and multimodal generation.

We provide a simple generation pipeline below. For more details, see the GitHub repository: VARGPT-v1.
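As a rough intuition for the two paradigms, the toy sketch below contrasts next-token decoding, which appends one token per step, with next-scale decoding, which emits a whole token map per step at growing resolutions. The function names are hypothetical stand-ins, not the VARGPT API.

```python
import random

def predict_next_token(context):
    # Hypothetical stand-in for the LLM head: one new text token id per step.
    return random.randrange(32000)

def predict_next_scale(prev_maps, side):
    # Hypothetical stand-in for the visual head: a full side x side token map,
    # conditioned on all coarser maps produced so far.
    return [[random.randrange(4096) for _ in range(side)] for _ in range(side)]

# Understanding: autoregression over tokens (one token per step).
text_tokens = [1]  # BOS
for _ in range(8):
    text_tokens.append(predict_next_token(text_tokens))

# Generation: autoregression over scales (one whole map per step, coarse to fine).
token_maps = []
for side in (1, 2, 4, 8):
    token_maps.append(predict_next_scale(token_maps, side))

print(len(text_tokens), [len(m) for m in token_maps])
```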
## 🚀 Quick Start
### Multimodal Understanding

The following demo shows inference for multimodal understanding; you can run this code directly:
```python
import torch
from PIL import Image
from transformers import AutoTokenizer
from vargpt_llava.modeling_vargpt_llava import VARGPTLlavaForConditionalGeneration
from vargpt_llava.prepare_vargpt_llava import prepare_vargpt_llava
from vargpt_llava.processing_vargpt_llava import VARGPTLlavaProcessor
from patching_utils.patching import patching

model_id = "VARGPT_LLaVA-v1"

# Register the VARGPT architecture and prepare the checkpoint.
prepare_vargpt_llava(model_id)

model = VARGPTLlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
).to(0)  # move the model to GPU 0

# Apply VARGPT-specific patches to the loaded model.
patching(model)

tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = VARGPTLlavaProcessor.from_pretrained(model_id)

# Build a chat-style prompt with one text turn and one image slot.
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please explain the meme in detail."},
            {"type": "image"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
print(prompt)

raw_image = Image.open("./assets/llava_bench_demo.png")
inputs = processor(images=raw_image, text=prompt, return_tensors="pt").to(0, torch.float32)

output = model.generate(
    **inputs,
    max_new_tokens=2048,
    do_sample=False,
)
print(processor.decode(output[0], skip_special_tokens=True))
```
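Note that `output[0]` contains the prompt tokens followed by the newly generated ones. To print only the model's answer, you can slice off the prompt first, a standard `transformers` pattern:

```python
# Decode only the newly generated tokens, skipping the prompt portion.
prompt_len = inputs["input_ids"].shape[-1]
print(processor.decode(output[0][prompt_len:], skip_special_tokens=True))
```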
### Multimodal Generation

The following demo shows inference for text-to-image generation; you can run this code directly:
```python
import torch
from transformers import AutoTokenizer
from vargpt_llava.modeling_vargpt_llava import VARGPTLlavaForConditionalGeneration
from vargpt_llava.prepare_vargpt_llava import prepare_vargpt_llava
from vargpt_llava.processing_vargpt_llava import VARGPTLlavaProcessor
from patching_utils.patching import patching

model_id = "VARGPT_LLaVA-v1"

# Register the VARGPT architecture and prepare the checkpoint.
prepare_vargpt_llava(model_id)

model = VARGPTLlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
).to(0)  # move the model to GPU 0

# Apply VARGPT-specific patches to the loaded model.
patching(model)

tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = VARGPTLlavaProcessor.from_pretrained(model_id)

# Text-only prompt: the model emits image tokens during generation.
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Please design a drawing of a butterfly on a flower."},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
print(prompt)

inputs = processor(text=prompt, return_tensors="pt").to(0, torch.float32)

# Path where the generated image will be written.
model._IMAGE_GEN_PATH = "output.png"

output = model.generate(
    **inputs,
    max_new_tokens=2048,
    do_sample=False,
)
print(processor.decode(output[0], skip_special_tokens=True))
```
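After generation completes, the image is written to the path assigned to `model._IMAGE_GEN_PATH`. A quick way to inspect the result (assuming generation succeeded and the file exists):

```python
from PIL import Image

# Load and inspect the image written by the generation step above.
img = Image.open("output.png")
print(img.size)
```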
## 📄 License

This project is released under the Apache-2.0 license.
## 📦 Model Information

| Attribute | Details |
|-----------|---------|
| Model type | VARGPT (7B + 2B) |
| Training data | VARGPT-family/VARGPT_datasets |
| Evaluation metrics | Accuracy, F1 |
| Task type | Any-to-any |
| Library name | transformers |