模型简介
模型特点
模型能力
使用案例
🚀 PaliGemma模型卡片
PaliGemma是一款多功能轻量级视觉语言模型(VLM),它以图像和文本作为输入,生成文本输出,支持多种语言。该模型可用于图像和短视频字幕、视觉问答、文本阅读、目标检测和目标分割等多种视觉语言任务。
🚀 快速开始
PaliGemma是单轮视觉语言模型,不适用于对话场景,在针对特定用例进行微调时效果最佳。你可以通过任务前缀(如 “detect” 或 “segment”)来配置模型要解决的任务。预训练模型以这种方式进行训练,以赋予它们丰富的能力(问答、字幕、分割等)。不过,它们并非设计用于直接使用,而是通过微调转移到使用类似提示结构的特定任务。对于交互式测试,你可以使用 “mix” 系列模型,这些模型已针对多种任务进行了微调。要查看模型 google/paligemma-3b-mix-448 的实际运行情况,请查看 使用Transformers代码库的Space。
请参考使用和限制部分了解预期用例,或访问博客文章获取更多详细信息和示例。
✨ 主要特性
- 多功能性:支持图像和文本输入,能处理图像和短视频字幕、视觉问答、文本阅读、目标检测和目标分割等多种视觉语言任务。
- 多语言支持:可以处理多种语言的输入和输出。
- 轻量级:模型参数相对较少,便于在不同设备上进行部署和微调。
📦 安装指南
你需要安装 bitsandbytes
以自动使用8位或4位精度运行推理:
pip install bitsandbytes accelerate
💻 使用示例
基础用法
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import requests
import torch
model_id = "google/paligemma-3b-mix-224"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained(model_id)
# Instruct the model to create a caption in Spanish
prompt = "caption es"
model_inputs = processor(text=prompt, images=image, return_tensors="pt")
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
输出: Un auto azul estacionado frente a un edificio.
高级用法
在CUDA上运行其他精度
为了方便起见,仓库中包含已转换为 bfloat16
和 float16
的权重版本,因此你可以使用它们来减小下载大小,并避免在本地计算机上进行类型转换。
以下是如何在NVIDIA CUDA卡上运行 bfloat16
的示例:
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import requests
import torch
model_id = "google/paligemma-3b-mix-224"
device = "cuda:0"
dtype = torch.bfloat16
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw)
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=dtype,
device_map=device,
revision="bfloat16",
).eval()
processor = AutoProcessor.from_pretrained(model_id)
# Instruct the model to create a caption in Spanish
prompt = "caption es"
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
以4位/8位加载模型
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig
from PIL import Image
import requests
import torch
model_id = "google/paligemma-3b-mix-224"
device = "cuda:0"
dtype = torch.bfloat16
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw)
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id, quantization_config=quantization_config
).eval()
processor = AutoProcessor.from_pretrained(model_id)
# Instruct the model to create a caption in Spanish
prompt = "caption es"
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
📚 详细文档
模型信息
模型概述
PaliGemma是受PaLI - 3启发,基于SigLIP视觉模型和Gemma语言模型等开放组件构建的多功能轻量级视觉语言模型(VLM)。它以图像和文本作为输入,生成文本输出,支持多种语言。该模型专为在图像和短视频字幕、视觉问答、文本阅读、目标检测和目标分割等广泛的视觉语言任务中实现一流的微调性能而设计。
模型架构
PaliGemma由Transformer解码器和视觉Transformer图像编码器组成,总参数为30亿。文本解码器从Gemma - 2B初始化,图像编码器从SigLIP - So400m/14初始化。PaliGemma按照PaLI - 3的方法进行训练。
输入和输出
- 输入:图像和文本字符串,如图像字幕提示或问题。
- 输出:针对输入生成的文本,如图像字幕、问题答案、目标边界框坐标列表或分割代码字。
模型数据
预训练数据集
PaliGemma在以下数据集的混合上进行预训练:
- WebLI:WebLI(Web语言图像)是一个基于公共网络构建的网络规模多语言图像 - 文本数据集。使用了多种WebLI分割来获得通用的模型能力,如视觉语义理解、目标定位、视觉情境文本理解、多语言能力等。
- CC3M - 35L:从网页中精心挑选的英语图像 - 替代文本对(Sharma等人,2018)。使用Google Cloud Translation API将其翻译成另外34种语言。
- VQ²A - CC3M - 35L/VQG - CC3M - 35L:VQ2A - CC3M的一个子集(Changpinyo等人,2022a),使用Google Cloud Translation API翻译成与CC3M - 35L相同的另外34种语言。
- OpenImages:在[OpenImages数据集]上通过手工规则生成的检测和目标感知问题及答案(Piergiovanni等人,2022)。
- WIT:从维基百科收集的图像和文本(Srinivasan等人,2021)。
数据责任过滤
为了在干净的数据上训练PaliGemma,对WebLI应用了以下过滤:
- 色情图像过滤:此过滤器去除被认为具有色情性质的图像。
- 文本安全过滤:识别并过滤掉与不安全文本配对的图像。不安全文本是指任何被认为包含或涉及儿童性虐待材料、色情内容、粗俗语言或其他冒犯性内容的文本。
- 文本毒性过滤:进一步使用Perspective API识别并过滤掉与被认为具有侮辱性、淫秽性、仇恨性或其他毒性的文本配对的图像。
- 文本个人信息过滤:使用Cloud Data Loss Prevention (DLP) API过滤某些个人信息和其他敏感数据,以保护个人隐私。去除了如社会安全号码和[其他敏感信息类型]等标识符。
- 其他方法:根据内容质量和安全性进行过滤,符合我们的政策和实践。
实现信息
硬件
PaliGemma使用最新一代的张量处理单元(TPU)硬件(TPUv5e)进行训练。
软件
训练使用了JAX、Flax、TFDS和big_vision
。
JAX允许研究人员利用最新一代的硬件(包括TPU),以更快、更高效地训练大型模型。TFDS用于访问数据集,Flax用于模型架构。PaliGemma的微调代码和推理代码在big_vision
GitHub仓库中发布。
评估信息
基准测试结果
为了验证PaliGemma对各种学术任务的可迁移性,我们在每个任务上对预训练模型进行微调。此外,我们还使用迁移任务的混合训练了混合模型。我们报告了不同分辨率下的结果,以了解哪些任务从更高的分辨率中受益。重要的是,这些任务或数据集都不是预训练数据混合的一部分,并且它们的图像已从网络规模的预训练数据中明确移除。
单任务(在单任务上微调)
基准测试(训练分割) | 指标(分割) | pt - 224 | pt - 448 | pt - 896 |
---|---|---|---|---|
字幕生成 | ||||
COCO captions(train + restval) | CIDEr(val) | 141.92 | 144.60 | |
NoCaps(Eval of COCO captions transfer) | CIDEr(val) | 121.72 | 123.58 | |
COCO - 35L(train) | CIDEr dev(en/avg - 34/avg) | 139.2 115.8 116.4 |
141.2 118.0 118.6 |
|
XM3600(Eval of COCO - 35L transfer) | CIDEr dev(en/avg - 34/avg) | 78.1 41.3 42.4 |
80.0 41.9 42.9 |
|
TextCaps(train) | CIDEr(val) | 127.48 | 153.94 | |
SciCap(first sentence, no subfigure)(train + val) | CIDEr/BLEU - 4(test) | 162.25 0.192 |
181.49 0.211 |
|
Screen2words(train + dev) | CIDEr(test) | 117.57 | 119.59 | |
Widget Captioning(train + dev) | CIDEr(test) | 136.07 | 148.36 | |
问答 | ||||
VQAv2(train + validation) | 准确率(Test server - std) | 83.19 | 85.64 | |
MMVP(Eval of VQAv2 transfer) | 配对准确率 | 47.33 | 45.33 | |
POPE(Eval of VQAv2 transfer) | 准确率(random/popular/ adversarial) |
87.80 85.87 84.27 |
88.23 86.77 85.90 |
|
OKVQA(train) | 准确率(val) | 63.54 | 63.15 | |
[A - OKVQA](https://allenai.org/project/a - okvqa/home) (MC)(train + val) | 准确率(Test server) | 76.37 | 76.90 | |
[A - OKVQA](https://allenai.org/project/a - okvqa/home) (DA)(train + val) | 准确率(Test server) | 61.85 | 63.22 | |
GQA(train_balanced + val_balanced) |
准确率(testdev balanced) | 65.61 | 67.03 | |
[xGQA](https://aclanthology.org/2022.findings - acl.196/)(Eval of GQA transfer) | 平均准确率(bn, de, en, id, ko, pt, ru, zh) |
58.37 | 59.07 | |
NLVR2(train + dev) | 准确率(test) | 90.02 | 88.93 | |
[MaRVL](https://marvl - challenge.github.io/)(Eval of NLVR2 transfer) | 平均准确率(test)(id, sw, ta, tr, zh) | 80.57 | 76.78 | |
AI2D(train) | 准确率(test) | 72.12 | 73.28 | |
ScienceQA(Img subset, no CoT)(train + val) | 准确率(test) | 95.39 | 95.93 | |
RSVQA - LR(Non numeric)(train + val) | 平均准确率(test) | 92.65 | 93.11 | |
RSVQA - HR(Non numeric)(train + val) | 平均准确率(test/test2) | 92.61 90.58 |
92.79 90.54 |
|
ChartQA(human + aug)x(train + val) | 平均宽松准确率(test_human, test_aug) |
57.08 | 71.36 | |
[VizWiz VQA](https://vizwiz.org/tasks - and - datasets/vqa/)(train + val) | 准确率(Test server - std) | 73.7 | 75.52 | |
TallyQA(train) | 准确率(test_simple/ test_complex) |
81.72 69.56 |
84.86 72.27 |
|
[OCR - VQA](https://ocr - vqa.github.io/)(train + val) | 准确率(test) | 72.32 | 74.61 | 74.93 |
TextVQA(train + val) | 准确率(Test server - std) | 55.47 | 73.15 | 76.48 |
DocVQA(train + val) | ANLS(Test server) | 43.74 | 78.02 | 84.77 |
Infographic VQA(train + val) | ANLS(Test server) | 28.46 | 40.47 | 47.75 |
SceneText VQA(train + val) | ANLS(Test server) | 63.29 | 81.82 | 84.40 |
分割 | ||||
RefCOCO(combined refcoco, refcoco+, refcocog excluding val and test images) |
MIoU(validation) refcoco/refcoco+/ refcocog |
73.40 68.32 67.65 |
75.57 69.76 70.17 |
76.94 72.18 72.22 |
视频任务(字幕/问答) | ||||
MSR - VTT(Captioning) | CIDEr(test) | 70.54 | ||
MSR - VTT(QA) | 准确率(test) | 50.09 | ||
ActivityNet(Captioning) | CIDEr(test) | 34.62 | ||
ActivityNet(QA) | 准确率(test) | 50.78 | ||
VATEX(Captioning) | CIDEr(test) | 79.73 | ||
MSVD(QA) | 准确率(test) | 60.22 |
混合模型(在迁移任务的混合上微调)
基准测试 | 指标(分割) | mix - 224 | mix - 448 |
---|---|---|---|
MMVP | 配对准确率 | 46.00 | 45.33 |
POPE | 准确率(random/popular/adversarial) | 88.00 86.63 85.67 |
89.37 88.40 87.47 |
伦理与安全
评估方法
我们的评估方法包括结构化评估和对相关内容政策的内部红队测试。红队测试由多个不同的团队进行,每个团队有不同的目标和人工评估指标。这些模型针对与伦理和安全相关的多个不同类别进行了评估,包括:
- 对涵盖儿童安全、内容安全和代表性危害的提示进行人工评估。有关评估方法的更多详细信息,请参阅Gemma模型卡片,但采用图像字幕和视觉问答设置。
- 图像到文本基准评估:针对相关学术数据集(如FairFace数据集(Karkkainen等人,2021))进行基准测试。
评估结果
- 伦理和安全评估的人工评估结果在符合[内部政策](https://storage.googleapis.com/gweb - uniblog - publish - prod/documents/2023_Google_AI_Principles_Progress_Update.pdf#page=11)的可接受阈值范围内,这些类别包括儿童安全、内容安全和代表性危害。
- 除了强大的内部评估外,我们还使用Perspective API(阈值为0.8)来测量从FairFace数据集中获取的图像生成字幕中的毒性、亵渎和其他潜在问题。我们报告了每个感知性别、种族和年龄属性的子组中观察到的最大值和中值。
指标 | 感知性别 | 种族 | 年龄组 | |||
---|---|---|---|---|---|---|
最大值 | 中值 | 最大值 | 中值 | 最大值 | 中值 | |
毒性 | 0.04% | 0.03% | 0.08% | 0.00% | 0.09% | 0.00% |
身份攻击 | 0.00% | 0.00% | 0.00% | 0.00% | 0.00% | 0.00% |
侮辱 | 0.06% | 0.04% | 0.09% | 0.07% | 0.16% | 0.00% |
威胁 | 0.06% | 0.05% | 0.14% | 0.05% | 0.17% | 0.00% |
亵渎 | 0.00% | 0.00% | 0.00% | 0.00% | 0.00% | 0.00% |
使用和限制
预期用途
开放视觉语言模型(VLM)在各个行业和领域都有广泛的应用。以下潜在用途列表并不全面。此列表的目的是提供有关模型创建者在模型训练和开发过程中考虑的可能用例的上下文信息。
针对特定视觉语言任务进行微调
- 预训练模型可针对广泛的视觉语言任务进行微调,如图像字幕、短视频字幕、视觉问答、文本阅读、目标检测和目标分割。
- 预训练模型可针对特定领域进行微调,如遥感问答、盲人视觉问题、科学问答、描述UI元素功能。
- 预训练模型可针对具有非文本输出(如边界框或分割掩码)的任务进行微调。
视觉语言研究
- 预训练模型和微调模型可以为研究人员提供基础,用于试验VLM技术、开发算法,并为该领域的发展做出贡献。
伦理考虑和风险
视觉语言模型(VLM)的开发引发了一些伦理问题。在创建开放模型时,我们仔细考虑了以下方面:
- 偏差和公平性:在大规模真实世界图像 - 文本数据上训练的VLM可能反映训练材料中嵌入的社会文化偏差。这些模型经过了仔细审查,输入数据的预处理在本卡片中进行了描述,并报告了后续评估结果。
- 错误信息和滥用:VLM可能被滥用来生成虚假、误导或有害的文本。我们提供了负责任使用模型的指南,请参阅Responsible Generative AI Toolkit。
- 透明度和问责制:本模型卡片总结了模型的架构、能力、限制和评估过程的详细信息。一个负责任开发的开放模型为通过使VLM技术在整个AI生态系统中可供开发者和研究人员使用来分享创新提供了机会。
已识别的风险和缓解措施
- 偏差的延续:鼓励在模型训练、微调及其他用例中使用评估指标和人工审查进行持续监测,并探索去偏技术。
- 有害内容的生成:内容安全机制和指南至关重要。鼓励开发者谨慎行事,并根据其特定的产品政策和应用用例实施适当的内容安全保障措施。
- 用于恶意目的的滥用:技术限制以及对开发者和最终用户的教育有助于减轻大语言模型的恶意应用。我们提供了教育资源和用户举报滥用行为的机制。Gemma模型的禁止使用情况在Gemma Prohibited Use Policy中进行了概述。
- 隐私侵犯:模型在经过过滤以去除某些个人信息和敏感数据的数据上进行训练。鼓励开发者遵守隐私法规并使用保护隐私的技术。
限制
- 大多数继承自基础Gemma模型的限制仍然适用:
- VLM在可以用明确提示和说明构建的任务中表现更好。开放式或高度复杂的任务可能具有挑战性。
- 自然语言本质上是复杂的。VLM可能难以理解微妙的细微差别、讽刺或比喻语言。
- VLM根据从训练数据集中学到的信息生成响应,但它们不是知识库。它们可能生成不正确或过时的事实陈述。
- VLM依赖于语言和图像中的统计模式。它们可能在某些情况下缺乏应用常识推理的能力。
- PaliGemma首先是作为用于转移到专门任务的通用预训练模型而设计的。因此,其“开箱即用”或“零样本”性能可能落后于专门为此设计的模型。
- PaliGemma不是多轮聊天机器人。它设计用于单轮图像和文本输入。
🔧 技术细节
模型类型
PaliGemma是由Transformer解码器和视觉Transformer图像编码器组成的视觉语言模型。
训练数据
预训练数据包括WebLI、CC3M - 35L、VQ²A - CC3M - 35L/VQG - CC3M - 35L、OpenImages和WIT等数据集。同时,对WebLI数据进行了色情图像过滤、文本安全过滤、文本毒性过滤、文本个人信息过滤等处理,以确保训练数据的质量和安全性。
训练环境
硬件方面使用了TPUv5e;软件方面使用了JAX、Flax、TFDS和big_vision
。
📄 许可证
许可证为gemma。
📖 引用
@article{beyer2024paligemma,
title={{PaliGemma: A versatile 3B VLM for transfer}},
author={Lucas Beyer* and Andreas Steiner* and André Susano Pinto* and Alexander Kolesnikov* and Xiao Wang* and Daniel Salz and Maxim Neumann and Ibrahim Alabdulmohsin and Michael Tschannen and Emanuele Bugliarello and Thomas Unterthiner and Daniel Keysers and Skanda Koppula and Fangyu Liu and Adam Grycner and Alexey Gritsenko and Neil Houlsby and Manoj Kumar and Keran Rong and Julian Eisenschlos and Rishabh Kabra and Matthias Bauer and Matko Bošnjak and Xi Chen and Matthias Minderer and Paul Voigtlaender and Ioana Bica and Ivana Balazevic and Joan Puigcerver and Pinelopi Papalampidi and Olivier Henaff and Xi Xiong and Radu Soricut and Jeremiah Harmsen and Xiaohua Zhai*},
year={2024},
journal={arXiv preprint arXiv:2407.07726}
}
查看论文请点击此处。








