モデル概要
モデル特徴
モデル能力
使用事例
🚀 PaliGemmaモデルカード
PaliGemmaは、画像とテキストの入力を組み合わせてテキスト出力を生成する、多言語対応の軽量型マルチモーダルビジュアル言語モデル(VLM)です。このモデルはオープンコンポーネントに基づいて構築されており、画像や短い動画の字幕作成、ビジュアルQA、テキスト読み取り、物体検出、物体セグメンテーションなど、様々なビジュアル言語タスクに適用可能です。
🚀 クイックスタート
Hugging Face上でPaliGemmaにアクセスするには、Googleの使用許諾を確認して同意する必要があります。Hugging Faceにログインした状態で、以下のボタンをクリックすると、リクエストがすぐに処理されます。 [許諾を確認](javascript:void(0);)
モデルページ:PaliGemma
Transformers PaliGemma 3Bの重みは、RSVQA-LRデータセットで224*224の入力画像を使用して微調整されています。これらのモデルは研究目的のみで提供されており、float32、bfloat16、float16の形式で利用できます。微調整の設定は、big_visionで確認できます。
リソースと技術文書:
利用規約:規約
作者:Google
✨ 主な機能
モデル情報
モデル概要
PaliGemmaは、PaLI-3にインスパイアされ、SigLIPビジュアルモデルやGemma言語モデルなどのオープンコンポーネントを基に構築されています。画像とテキストの入力を同時に処理し、多言語対応のテキスト出力を生成することができます。このモデルは、幅広いビジュアル言語タスクで高い微調整性能を実現することを目指しています。
モデルアーキテクチャ
PaliGemmaは、TransformerデコーダとビジュアルTransformer画像エンコーダで構成され、合計30億個のパラメータを持っています。テキストデコーダはGemma-2Bから初期化され、画像エンコーダはSigLIP-So400m/14から初期化されます。PaliGemmaはPaLI-3の方法に従って学習されています。
入力と出力
- 入力:画像とテキスト文字列(画像字幕のプロンプトや質問など)
- 出力:入力に基づいて生成されたテキスト(画像字幕、質問の回答、物体のバウンディングボックス座標のリスト、セグメンテーションコードなど)
モデルデータ
事前学習データセット
PaliGemmaは、以下のデータセットの混合で事前学習されています。
- WebLI:WebLI(ウェブ言語画像)は、公共のウェブに基づくウェブ規模の多言語画像テキストデータセットです。モデルの多機能性(ビジュアル意味理解、物体位置特定、ビジュアルコンテキストのテキスト理解、多言語能力など)を獲得するために、複数のWebLI分割が使用されています。
- CC3M-35L:ウェブページから精心選択された英語の画像 - 代替テキストペア(Sharmaら、2018)。Google Cloud Translation APIを使用して、さらに34の言語に翻訳されています。
- VQ²A-CC3M-35L/VQG-CC3M-35L:VQ2A-CC3Mのサブセット(Changpinyoら、2022a)。Google Cloud Translation APIを使用して、CC3M-35Lと同じ34の言語に翻訳されています。
- OpenImages:OpenImagesデータセットに基づく手動ルールによって生成された検出と物体感知QA(Piergiovanniら、2022)。
- WIT:ウィキペディアから収集された画像とテキスト(Srinivasanら、2021)。
データ責任フィルタリング
PaliGemmaをクリーンなデータで学習させるために、WebLIには以下のフィルタリングが適用されています。
- 色情画像フィルタリング:色情的な性質を持つと見なされる画像を削除します。
- テキストセキュリティフィルタリング:不安全なテキスト(児童性虐待材料、色情コンテンツ、下品な言葉、その他の不快な内容を含むと見なされるテキスト)とペアになっている画像を識別し、フィルタリングします。
- テキスト毒性フィルタリング:Perspective APIを使用して、侮辱的、猥褻的、仇恨的、またはその他の毒性のあると見なされるテキストとペアになっている画像を識別し、フィルタリングします。
- テキスト個人情報フィルタリング:Cloud Data Loss Prevention (DLP) APIを使用して、特定の個人情報やその他の機密データをフィルタリングし、個人情報を保護します。社会保障番号などの識別子やその他の機密情報タイプが削除されます。
- その他の方法:コンテンツの品質と安全性に基づいてフィルタリングし、当社のポリシーと実践に合致させます。
📦 インストール
8ビットまたは4ビット精度で自動的に推論を実行するには、bitsandbytes
をインストールする必要があります。
pip install bitsandbytes accelerate
💻 使用例
基本的な使用法
PaliGemmaはシングルラウンドのビジュアル言語モデルであり、対話には適しておらず、特定のユースケースに合わせて微調整すると最適な結果が得られます。
「detect」や「segment」などのタスクプレフィックスを使用することで、モデルが解決するタスクを設定することができます。事前学習モデルは、このような方法で学習され、豊富な機能(QA、字幕作成、セグメンテーションなど)を備えています。ただし、これらは直接使用するものではなく、類似したプロンプト構造を持つ特定のタスクに微調整して転用されます。インタラクティブなテストには、複数のタスクの混合で微調整された「mix」シリーズのモデルを使用することができます。
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import requests
import torch
model_id = "google/paligemma-3b-mix-224"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id).eval()
processor = AutoProcessor.from_pretrained(model_id)
# モデルにスペイン語の字幕を作成するよう指示する
prompt = "caption es"
model_inputs = processor(text=prompt, images=image, return_tensors="pt")
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
出力:Un auto azul estacionado frente a un edificio.
高度な使用法
CUDA上での他の精度での実行
利便性のために、リポジトリにはbfloat16
とfloat16
に変換された重みのバージョンが含まれています。これらを使用すると、ダウンロードサイズを削減し、ローカルコンピュータでの型変換を回避することができます。
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import requests
import torch
model_id = "google/paligemma-3b-mix-224"
device = "cuda:0"
dtype = torch.bfloat16
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw)
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id,
torch_dtype=dtype,
device_map=device,
revision="bfloat16",
).eval()
processor = AutoProcessor.from_pretrained(model_id)
# モデルにスペイン語の字幕を作成するよう指示する
prompt = "caption es"
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
4ビット/8ビットでの読み込み
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
import requests
import torch
model_id = "google/paligemma-3b-mix-224"
device = "cuda:0"
dtype = torch.bfloat16
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
image = Image.open(requests.get(url, stream=True).raw)
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = PaliGemmaForConditionalGeneration.from_pretrained(
model_id, quantization_config=quantization_config
).eval()
processor = AutoProcessor.from_pretrained(model_id)
# モデルにスペイン語の字幕を作成するよう指示する
prompt = "caption es"
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]
with torch.inference_mode():
generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
generation = generation[0][input_len:]
decoded = processor.decode(generation, skip_special_tokens=True)
print(decoded)
🔧 技術詳細
ハードウェア
PaliGemmaは最新世代のテンソル処理ユニット(TPU)ハードウェア(TPUv5e)を使用して学習されています。
ソフトウェア
学習には、JAX、Flax、TFDS、およびbig_vision
が使用されています。
JAXにより、研究員は最新世代のハードウェア(TPUを含む)を活用して、大型モデルをより高速かつ効率的に学習することができます。TFDSはデータセットへのアクセスに使用され、Flaxはモデルアーキテクチャに使用されます。PaliGemmaの微調整コードと推論コードは、big_vision
のGitHubリポジトリで公開されています。
📚 ドキュメント
評価情報
ベンチマーク結果
PaliGemmaの様々な学術タスクへの転用可能性を検証するために、各タスクで事前学習モデルを微調整しました。また、複数の転用タスクの混合で混合モデルを学習させました。異なる解像度での結果を報告し、どのタスクが高解像度から恩恵を受けるかを把握しました。重要なことは、これらのタスクやデータセットは事前学習データの混合に含まれておらず、それらの画像はウェブ規模の事前学習データから明確に除外されているということです。
混合モデル(複数の転用タスクの混合で微調整)
ベンチマーク | 指標(セグメンテーション) | mix-224 | mix-448 |
---|---|---|---|
MMVP | ペア精度 | 46.00 | 45.33 |
POPE | 精度 (ランダム/人気/対抗) |
88.00 86.63 85.67 |
89.37 88.40 87.47 |
GQA | 精度(テスト) | 65.20 | 65.47 |
単一タスク(単一タスクで微調整)
ベンチマーク (学習分割) |
指標 (分割) |
pt-224 | pt-448 | pt-896 |
---|---|---|---|---|
字幕生成 | ||||
COCO captions (train+restval) |
CIDEr(検証) | 141.92 | 144.60 | - |
NoCaps (COCO字幕転用評価) |
CIDEr(検証) | 121.72 | 123.58 | - |
COCO-35L (学習) |
CIDEr開発 (英語/34言語平均/平均) |
139.2 115.8 116.4 |
141.2 118.0 118.6 |
- |
XM3600 (COCO-35L転用評価) |
CIDEr開発 (英語/34言語平均/平均) |
78.1 41.3 42.4 |
80.0 41.9 42.9 |
- |
TextCaps (学習) |
CIDEr(検証) | 127.48 | 153.94 | - |
SciCap (第一文、サブ図なし) (train+val) |
CIDEr/BLEU-4 (テスト) |
162.25 0.192 |
181.49 0.211 |
- |
Screen2words (train+dev) |
CIDEr(テスト) | 117.57 | 119.59 | - |
Widget Captioning (train+dev) |
CIDEr(テスト) | 136.07 | 148.36 | - |
QA | ||||
VQAv2 (train+validation) |
精度 (テストサーバ - 標準) |
83.19 | 85.64 | - |
MMVP (VQAv2転用評価) |
ペア精度 | 47.33 | 45.33 | - |
POPE (VQAv2転用評価) |
精度 (ランダム/人気/ 対抗) |
87.80 85.87 84.27 |
88.23 86.77 85.90 |
- |
OKVQA (学習) |
精度(検証) | 63.54 | 63.15 | - |
A-OKVQA (MC) (train+val) |
精度 (テストサーバ) |
76.37 | 76.90 | - |
A-OKVQA (DA) (train+val) |
精度 (テストサーバ) |
61.85 | 63.22 | - |
GQA (train_balanced+ val_balanced) |
精度 (testdev balanced) |
65.61 | 67.03 | - |
xGQA (GQA転用評価) |
平均精度 (bn, de, en, id, ko, pt, ru, zh) |
58.37 | 59.07 | - |
NLVR2 (train+dev) |
精度(テスト) | 90.02 | 88.93 | - |
MaRVL (NLVR2転用評価) |
平均精度 (テスト) (id, sw, ta, tr, zh) |
80.57 | 76.78 | - |
AI2D (学習) |
精度(テスト) | 72.12 | 73.28 | - |
ScienceQA (画像サブセット、思考チェーンなし) (train+val) |
精度(テスト) | 95.39 | 95.93 | - |
RSVQA-LR (Non numeric) (train+val) |
平均精度 (テスト) |
92.65 | 93.11 | - |
RSVQA-HR (Non numeric) (train+val) |
平均精度 (テスト/テスト2) |
92.61 90.58 |
92.79 90.54 |
- |
ChartQA (human+aug)x(train+val) |
平均緩和精度 (テスト人間、 テスト拡張) |
57.08 | 71.36 | - |
VizWiz VQA (train+val) |
精度 (テストサーバ - 標準) |
73.7 | 75.52 | - |
TallyQA (学習) |
精度 (test_simple/ test_complex) |
81.72 69.56 |
84.86 72.27 |
- |
OCR-VQA (train+val) |
精度(テスト) | 72.32 | 74.61 | 74.93 |
TextVQA (train+val) |
精度 (テストサーバ - 標準) |
55.47 | 73.15 | 76.48 |
DocVQA (train+val) |
ANLS(テストサーバ) | 43.74 | 78.02 | 84.77 |
Infographic VQA (train+val) |
ANLS(テストサーバ) | 28.46 | 40.47 | 47.75 |
SceneText VQA (train+val) |
ANLS(テストサーバ) | 63.29 | 81.82 | 84.40 |
セグメンテーション | ||||
RefCOCO (refcoco, refcoco+, refcocogの組み合わせ、検証 とテスト画像を除外) |
MIoU (検証) refcoco/refcoco+/ refcocog |
73.40 68.32 67.65 |
75.57 69.76 70.17 |
76.94 72.18 72.22 |
ビデオタスク(字幕/QA) | ||||
MSR - (ドキュメントでここに関連する内容は完全に記載されていません) |
📄 ライセンス
このプロジェクトはgemmaライセンスを採用しています。









