🚀 VARGPT: 視覚自己回帰型マルチモーダル大規模言語モデルにおける統一的理解と生成
VARGPT (7B+2B) は、統一モデル内で理解と生成を2つの異なるパラダイムとしてモデリングしています。具体的には、視覚理解のために次のトークンを予測し、視覚生成のために次のスケールを予測します。
当社のモデルを使用するための簡単な生成プロセスを提供しています。詳細については、Github: VARGPT-v1 を参照してください。
🚀 クイックスタート
当社のモデルの使用方法について、以下で説明します。
✨ 主な機能
VARGPTは、視覚自己回帰型マルチモーダル大規模言語モデルで、理解と生成を統一的に扱うことができます。具体的には、マルチモーダル理解とマルチモーダル生成の2つの機能を備えています。
💻 使用例
基本的な使用法
マルチモーダル理解
マルチモーダル理解の推論デモです。以下のコードを実行できます。
import requests
from PIL import Image
import torch
from transformers import AutoProcessor, AutoTokenizer
from vargpt_llava.modeling_vargpt_llava import VARGPTLlavaForConditionalGeneration
from vargpt_llava.prepare_vargpt_llava import prepare_vargpt_llava
from vargpt_llava.processing_vargpt_llava import VARGPTLlavaProcessor
from patching_utils.patching import patching
model_id = "VARGPT_LLaVA-v1"
prepare_vargpt_llava(model_id)
model = VARGPTLlavaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float32,
low_cpu_mem_usage=True,
).to(0)
patching(model)
tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = VARGPTLlavaProcessor.from_pretrained(model_id)
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "Please explain the meme in detail."},
{"type": "image"},
],
},
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
image_file = "./assets/llava_bench_demo.png"
print(prompt)
raw_image = Image.open(image_file)
inputs = processor(images=raw_image, text=prompt, return_tensors='pt').to(0, torch.float32)
output = model.generate(
**inputs,
max_new_tokens=2048,
do_sample=False)
print(processor.decode(output[0], skip_special_tokens=True))
マルチモーダル生成
テキストから画像への生成の推論デモです。以下のコードを実行できます。
import requests
from PIL import Image
import torch
from transformers import AutoProcessor, AutoTokenizer
from vargpt_llava.modeling_vargpt_llava import VARGPTLlavaForConditionalGeneration
from vargpt_llava.prepare_vargpt_llava import prepare_vargpt_llava
from vargpt_llava.processing_vargpt_llava import VARGPTLlavaProcessor
from patching_utils.patching import patching
model_id = "VARGPT_LLaVA-v1"
prepare_vargpt_llava(model_id)
model = VARGPTLlavaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=torch.float32,
low_cpu_mem_usage=True,
).to(0)
patching(model)
tokenizer = AutoTokenizer.from_pretrained(model_id)
processor = VARGPTLlavaProcessor.from_pretrained(model_id)
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "Please design a drawing of a butterfly on a flower."},
],
},
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
print(prompt)
inputs = processor(text=prompt, return_tensors='pt').to(0, torch.float32)
model._IMAGE_GEN_PATH = "output.png"
output = model.generate(
**inputs,
max_new_tokens=2048,
do_sample=False)
print(processor.decode(output[0], skip_special_tokens=True))
📄 ライセンス
このプロジェクトは、Apache-2.0ライセンスの下で提供されています。
📚 ドキュメント
属性 |
详情 |
モデルタイプ |
VARGPT (7B+2B) |
訓練データ |
VARGPT-family/VARGPT_datasets |
評価指標 |
正解率、F1スコア |
パイプラインタグ |
任意から任意 |
ライブラリ名 |
transformers |