🚀 VLRM
This repository contains the weights of the BLIP-2 OPT-2.7B model fine-tuned with the reinforcement learning method introduced in the paper VLRM: Vision-Language Models act as Reward Models for Image Captioning. Compared to the original model, the RL-tuned model generates longer and more comprehensive captions with zero extra computational overhead.
You can find additional details in the GitHub repository (to be completed).
🚀 Quick Start
💻 Usage Examples
Basic Usage
You can load the entire model directly from this repository:
import torch
import requests
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Load the processor and the RL-tuned model from this repository
processor = Blip2Processor.from_pretrained("sashakunitsyn/vlrm-blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("sashakunitsyn/vlrm-blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")

# Download a demo image and generate a caption
img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')

inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
out = model.generate(**inputs, max_new_tokens=60)
processor.decode(out[0], skip_special_tokens=True).strip()
>>> 'a woman in a plaid shirt shaking hands with a yellow labrador retriever sitting on the ground at sunset on a beach in florida'
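BLIP-2 also accepts a text prompt alongside the image. A minimal sketch of prompted generation, reusing raw_image from above; the prompt wording here is an illustrative assumption, not taken from the paper:

# Prompted generation sketch, assuming the standard BLIP-2 "Question: ... Answer:" format
prompt = "Question: what is the woman doing? Answer:"
inputs = processor(raw_image, prompt, return_tensors="pt").to("cuda", torch.float16)
out = model.generate(**inputs, max_new_tokens=60)
print(processor.decode(out[0], skip_special_tokens=True).strip())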
Advanced Usage
Since the fine-tuned layers make up only a small fraction of the whole model, you can first load the original model and then load the RL-tuned weights on top of it.
Step 1. Load the original model
import torch
import requests
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Load the original BLIP-2 checkpoint from Salesforce
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", torch_dtype=torch.float16, device_map="auto")

img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')

inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
out = model.generate(**inputs, max_new_tokens=60)
processor.decode(out[0], skip_special_tokens=True).strip()
>>> 'a woman sitting on the beach with a dog'
Step 2. Load the RL-tuned weights
Available checkpoints:
vlrm-blip2-opt-2.7b.pt (VLRM in the paper)
vlrm-rs-blip2-opt-2.7b.pt (VLRM-RS in the paper)
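Before loading, you can optionally inspect what a checkpoint contains; a quick sketch (hf_hub_download caches the file, so the later loading step will not fetch it twice):

from huggingface_hub import hf_hub_download
import torch

# Optional inspection sketch: see how many tensors the RL fine-tune touches
ckpt_path = hf_hub_download(repo_id="sashakunitsyn/vlrm-blip2-opt-2.7b", filename="vlrm-blip2-opt-2.7b.pt")
state_dict = torch.load(ckpt_path, map_location="cpu")
print(len(state_dict))         # number of fine-tuned parameter tensors
print(sorted(state_dict)[:5])  # a few parameter names

Then load the weights into the model: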
from huggingface_hub import hf_hub_download

finetuned_weights_state_dict = torch.load(hf_hub_download(repo_id="sashakunitsyn/vlrm-blip2-opt-2.7b", filename="vlrm-blip2-opt-2.7b.pt"))
# strict=False: the checkpoint covers only the fine-tuned subset of layers
model.load_state_dict(finetuned_weights_state_dict, strict=False)

out = model.generate(**inputs, max_new_tokens=60)
processor.decode(out[0], skip_special_tokens=True).strip()
>>> 'a woman in a plaid shirt shaking hands with a yellow labrador retriever sitting on the ground at sunset on a beach in florida'
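With strict=False, load_state_dict returns the keys it could not match, which you can use to sanity-check that the checkpoint was actually applied. A small sketch; the same pattern loads the VLRM-RS variant by swapping the filename:

# Sanity-check sketch: missing_keys lists model parameters the checkpoint did not
# cover, unexpected_keys lists checkpoint entries the model has no parameter for
result = model.load_state_dict(finetuned_weights_state_dict, strict=False)
print(f"missing: {len(result.missing_keys)}, unexpected: {len(result.unexpected_keys)}")

# To try the VLRM-RS checkpoint instead, swap the filename
rs_state_dict = torch.load(hf_hub_download(repo_id="sashakunitsyn/vlrm-blip2-opt-2.7b", filename="vlrm-rs-blip2-opt-2.7b.pt"))
model.load_state_dict(rs_state_dict, strict=False)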
📄 License
This project is licensed under the MIT License.