🚀 RobustSAM: 劣化画像に対する高精度セグメンテーションモデル (CVPR 2024 Highlight)
RobustSAMは、劣化画像に対しても強力なセグメンテーション能力を発揮するモデルです。SAMの性能を向上させ、低品質画像でも高精度なセグメンテーションを実現します。

RobustSAMの公式リポジトリです。劣化画像に対しても強力なセグメンテーション能力を発揮します。
プロジェクトページ | 論文 | データセット
🚀 クイックスタート
Segment Anything Model (SAM) は画像セグメンテーションにおいて画期的なアプローチとして登場し、強力なゼロショットセグメンテーション能力と柔軟なプロンプトシステムで評価されています。しかし、画質が劣化した画像に対しては性能が低下するという課題があります。この制限を解消するために、我々はRobust Segment Anything Model (RobustSAM) を提案します。このモデルは、低品質画像に対するSAMの性能を向上させると同時に、プロンプト性とゼロショット汎化能力を維持します。
我々の手法は、事前学習されたSAMモデルを活用し、わずかなパラメータの増加と計算コストで実現されます。RobustSAMの追加パラメータは、8台のGPUで30時間以内に最適化できるため、一般的な研究室でも実用的です。また、我々はRobust-Segデータセットを導入しました。これは、様々な劣化を持つ688Kの画像-マスクペアのコレクションで、モデルの訓練と評価に最適です。様々なセグメンテーションタスクとデータセットに対する広範な実験により、RobustSAMの優れた性能が確認されており、特にゼロショット条件下での性能が高く、実世界での広範な応用が期待されます。さらに、我々の手法は、単一画像のヘイズ除去やブラー除去などのSAMベースの下流タスクの性能を効果的に向上させることが示されています。
免責事項: このモデルカードの内容はHugging Faceチームによって作成され、一部は元のSAMモデルカードからコピーされています。
✨ 主な機能
- 低品質画像での高性能化:SAMの性能を向上させ、低品質画像に対しても高精度なセグメンテーションを実現します。
- 少ないパラメータ増加:事前学習されたSAMモデルを活用し、わずかなパラメータの増加で実現されます。
- 実用的な訓練時間:追加パラメータは8台のGPUで30時間以内に最適化できます。
- Robust-Segデータセット:様々な劣化を持つ画像-マスクペアのコレクションで、モデルの訓練と評価に最適です。
- 広範な応用可能性:様々なセグメンテーションタスクやSAMベースの下流タスクでの性能向上が期待されます。
📚 ドキュメント
モデルの詳細
RobustSAMモデルは3つのモジュールで構成されています。
VisionEncoder
:VITベースの画像エンコーダです。画像のパッチに対するアテンションを使用して画像埋め込みを計算します。相対位置埋め込みが使用されています。
PromptEncoder
:ポイントとバウンディングボックスの埋め込みを生成します。
MaskDecoder
:双方向トランスフォーマーで、画像埋め込みとポイント埋め込みの間、およびポイント埋め込みと画像埋め込みの間でクロスアテンションを実行します。出力はNeck
に入力されます。
Neck
:MaskDecoder
によって生成された文脈化されたマスクに基づいて出力マスクを予測します。
💻 使用例
基本的な使用法
プロンプト付きマスク生成
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModelForMaskGeneration
processor = AutoProcessor.from_pretrained("jadechoghari/robustsam-vit-base")
model = AutoModelForMaskGeneration.from_pretrained("jadechoghari/robustsam-vit-base")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to("cuda")
with torch.no_grad():
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
マスクを生成するための他の引数の中で、関心のあるオブジェクトの概略位置の2D座標、関心のあるオブジェクトを囲むバウンディングボックス(バウンディングボックスの右上と左下の点のx, y座標の形式)、セグメンテーションマスクを渡すことができます。執筆時点では、公式モデルではテキストを入力として渡すことはサポートされていません。詳細については、このノートブックを参照してください。これは、モデルの使用方法を視覚的な例とともに説明しています。
自動マスク生成
モデルは、入力画像を与えることで「ゼロショット」方式でセグメンテーションマスクを生成するために使用できます。モデルには自動的に1024
のポイントのグリッドがプロンプトとして与えられ、すべてがモデルに入力されます。
以下のコードは、マスク生成のパイプラインを初期化し、画像からマスクを生成する方法を示しています。
from transformers import pipeline
generator = pipeline("mask-generation", model="jadechoghari/robustsam-vit-base", device=0, points_per_batch=256)
image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch=256)
生成されたマスクを画像上に表示するには、以下のコードを使用します。
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.show()
視覚的な比較
🔧 引用
この研究が役に立った場合は、以下のように引用してください。
@inproceedings{chen2024robustsam,
title={RobustSAM: Segment Anything Robustly on Degraded Images},
author={Chen, Wei-Ting and Vong, Yu-Jiet and Kuo, Sy-Yen and Ma, Sizhou and Wang, Jian},
journal={CVPR},
year={2024}
}
謝辞
我々のリポジトリは、SAM の著者のおかげで可能になりました。彼らに感謝します。
📄 ライセンス
このプロジェクトはMITライセンスの下で公開されています。