🚀 VLRM
This repository contains the weights of BLIP-2 OPT-2.7B fine-tuned with the reinforcement-learning method introduced in the paper VLRM: Vision-Language Models act as Reward Models for Image Captioning. Compared to the original model, the RL-tuned model generates longer and more comprehensive captions at zero extra computational cost.
Further details can be found in the GitHub repository (TBD).
🚀 Quick Start
💻 Usage Examples
Basic Usage
You can load the entire model directly from this repository:
```python
import torch
import requests
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

processor = Blip2Processor.from_pretrained("sashakunitsyn/vlrm-blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "sashakunitsyn/vlrm-blip2-opt-2.7b",
    torch_dtype=torch.float16,
    device_map="auto",
)

img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')

inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
out = model.generate(**inputs, max_new_tokens=60)
processor.decode(out[0], skip_special_tokens=True).strip()
>>> 'a woman in a plaid shirt shaking hands with a yellow labrador retriever sitting on the ground at sunset on a beach in florida'
```
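BLIP-2 also supports prompted generation through the same processor, which can be used to steer the caption or ask a question about the image. A minimal sketch reusing the objects defined above (the prompt text is only illustrative):

```python
# Pass a text prompt together with the image; the model continues the prompt.
prompt = "Question: what is the dog doing? Answer:"
inputs = processor(raw_image, prompt, return_tensors="pt").to("cuda", torch.float16)
out = model.generate(**inputs, max_new_tokens=30)
print(processor.decode(out[0], skip_special_tokens=True).strip())
```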
Advanced Usage
Since the fine-tuned layers account for only a small fraction of the full model, you can first load the original model and then apply the RL-tuned weights on top of it.
Step 1. Load the original model
```python
import torch
import requests
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    torch_dtype=torch.float16,
    device_map="auto",
)

img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')

inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
out = model.generate(**inputs, max_new_tokens=60)
processor.decode(out[0], skip_special_tokens=True).strip()
>>> 'a woman sitting on the beach with a dog'
```
Step 2. Load the RL-tuned weights
Available checkpoints:
- `vlrm-blip2-opt-2.7b.pt` (VLRM in the paper)
- `vlrm-rs-blip2-opt-2.7b.pt` (VLRM-RS in the paper)
```python
from huggingface_hub import hf_hub_download

finetuned_weights_state_dict = torch.load(
    hf_hub_download(repo_id="sashakunitsyn/vlrm-blip2-opt-2.7b", filename="vlrm-blip2-opt-2.7b.pt")
)
model.load_state_dict(finetuned_weights_state_dict, strict=False)

out = model.generate(**inputs, max_new_tokens=60)
processor.decode(out[0], skip_special_tokens=True).strip()
>>> 'a woman in a plaid shirt shaking hands with a yellow labrador retriever sitting on the ground at sunset on a beach in florida'
```
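To use the VLRM-RS variant instead, download the other checkpoint listed above and apply it the same way. A minimal sketch (the generated caption will differ from the one above); `strict=False` is needed because the checkpoint intentionally covers only the fine-tuned subset of layers, and the value returned by `load_state_dict` shows which parameters it leaves untouched:

```python
from huggingface_hub import hf_hub_download

# Download and apply the VLRM-RS checkpoint (filename from the list above).
rs_state_dict = torch.load(
    hf_hub_download(repo_id="sashakunitsyn/vlrm-blip2-opt-2.7b", filename="vlrm-rs-blip2-opt-2.7b.pt")
)
result = model.load_state_dict(rs_state_dict, strict=False)

# Only a small subset of layers was fine-tuned, so a long missing_keys list is expected,
# while unexpected_keys should be empty.
print(len(result.missing_keys), len(result.unexpected_keys))

out = model.generate(**inputs, max_new_tokens=60)
print(processor.decode(out[0], skip_special_tokens=True).strip())
```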
📄 License
This project is released under the MIT license.