モデル概要
モデル特徴
モデル能力
使用事例
library_name: transformers pipeline_tag: mask-generation license: apache-2.0 tags:
- vision
高品質セグメンテーション・アニシングモデル(SAM-HQ)のモデルカード
オリジナルSAMモデルと比較したSAM-HQのアーキテクチャ。HQ出力トークンとグローバル-ローカル特徴融合コンポーネントを示しています。
目次
概要
SAM-HQ(高品質セグメント・アニシング)は、ポイントやボックスなどの入力プロンプトから高品質なオブジェクトマスクを生成するSegment Anything Model(SAM)の拡張版です。SAMは1100万枚の画像と11億のマスクでトレーニングされましたが、そのマスク予測品質は多くの場合、特に複雑な構造を持つオブジェクトを扱う際に不十分でした。SAM-HQは、最小限の追加パラメータと計算コストでこれらの制限に対処します。
このモデルは、複雑な境界や細い構造を持つオブジェクトに対しても高品質なセグメンテーションマスクを生成するのに優れており、オリジナルのSAMモデルが苦戦する場面でも高い性能を発揮します。SAM-HQはSAMの元のプロンプト可能な設計、効率性、ゼロショット汎化性を維持しながら、マスク品質を大幅に向上させています。
モデル詳細
SAM-HQは、SAMの事前学習済み重みを保持しながら、オリジナルのSAMアーキテクチャに2つの主要な革新を加えています:
-
高品質出力トークン: SAMのマスクデコーダに注入される学習可能なトークンで、高品質なマスクを予測する役割を担います。SAMの元の出力トークンとは異なり、このトークンとそれに関連するMLP層は、高精度なセグメンテーションマスクを生成するために特別にトレーニングされています。
-
グローバル-ローカル特徴融合: HQ出力トークンをマスクデコーダの特徴にのみ適用するのではなく、SAM-HQはまずこれらの特徴を初期および最終ViT特徴と融合させ、マスクの詳細を改善します。これにより、高レベルの意味的コンテキストと低レベルの境界情報の両方が組み合わされ、より正確なセグメンテーションが可能になります。
SAM-HQは、極めて正確なアノテーションを持つ複数のソースから収集された44,000の細かいマスク(HQSeg-44K)の慎重にキュレーションされたデータセットでトレーニングされました。トレーニングプロセスは8GPUでわずか4時間しかかからず、オリジナルのSAMモデルと比較して0.5%未満の追加パラメータしか導入しません。
このモデルは、さまざまな下流タスクにわたる10の多様なセグメンテーションデータセットのスイートで評価され、そのうち8つはゼロショット転移プロトコルで評価されました。結果は、SAM-HQがオリジナルのSAMモデルよりも大幅に優れたマスクを生成できる一方で、そのゼロショット汎化能力を維持していることを示しています。
SAM-HQは、オリジナルのSAMモデルの2つの主要な問題に対処します:
- 粗いマスク境界、特に細いオブジェクト構造を無視しがちな問題
- 困難なケースでの誤った予測、壊れたマスク、または大きなエラー
これらの改善により、SAM-HQは自動アノテーションや画像/動画編集タスクなど、高度に正確な画像マスクを必要とするアプリケーションに特に価値があります。
使用方法
プロンプト付きマスク生成
from PIL import Image
import requests
from transformers import SamHQModel, SamHQProcessor
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-huge")
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-huge")
img_url = "https://raw.githubusercontent.com/SysCV/sam-hq/refs/heads/main/demo/input_imgs/example1.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_boxes = [[[306, 132, 925, 893]]] # 画像のバウンディングボックス
inputs = processor(raw_image, input_boxes=input_boxes, return_tensors="pt").to("cuda")
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
マスク生成のための他の引数の中には、関心のあるオブジェクトのおおよその位置にある2D位置、関心のあるオブジェクトを囲むバウンディングボックス(形式はバウンディングボックスの右上と左下の点のx、y座標である必要があります)、セグメンテーションマスクを渡すことができます。この記事の執筆時点では、公式リポジトリによると、公式モデルではテキストを入力として渡すことはサポートされていません。詳細については、モデルの使用方法と視覚的な例を示すこのノートブックを参照してください!
自動マスク生成
このモデルは、入力画像が与えられると、「ゼロショット」方式でセグメンテーションマスクを生成するために使用できます。モデルには1024
ポイントのグリッドが自動的にプロンプトとして与えられ、すべてがモデルにフィードされます。
パイプラインは自動マスク生成のために作られています。次のスニペットは、それを実行するのがどれほど簡単かを示しています(任意のデバイスで!適切なpoints_per_batch
引数を渡すだけです)
from transformers import pipeline
generator = pipeline("mask-generation", model="syscv-community/sam-hq-vit-huge", device=0, points_per_batch=256)
image_url = "https://raw.githubusercontent.com/SysCV/sam-hq/refs/heads/main/demo/input_imgs/example1.png"
outputs = generator(image_url, points_per_batch=256)
画像を表示するには:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.show()
可視化付き完全な例
import numpy as np
import matplotlib.pyplot as plt
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30/255, 144/255, 255/255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))
def show_boxes_on_image(raw_image, boxes):
plt.figure(figsize=(10,10))
plt.imshow(raw_image)
for box in boxes:
show_box(box, plt.gca())
plt.axis('on')
plt.show()
def show_points_on_image(raw_image, input_points, input_labels=None):
plt.figure(figsize=(10,10))
plt.imshow(raw_image)
input_points = np.array(input_points)
if input_labels is None:
labels = np.ones_like(input_points[:, 0])
else:
labels = np.array(input_labels)
show_points(input_points, labels, plt.gca())
plt.axis('on')
plt.show()
def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None):
plt.figure(figsize=(10,10))
plt.imshow(raw_image)
input_points = np.array(input_points)
if input_labels is None:
labels = np.ones_like(input_points[:, 0])
else:
labels = np.array(input_labels)
show_points(input_points, labels, plt.gca())
for box in boxes:
show_box(box, plt.gca())
plt.axis('on')
plt.show()
def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None):
plt.figure(figsize=(10,10))
plt.imshow(raw_image)
input_points = np.array(input_points)
if input_labels is None:
labels = np.ones_like(input_points[:, 0])
else:
labels = np.array(input_labels)
show_points(input_points, labels, plt.gca())
for box in boxes:
show_box(box, plt.gca())
plt.axis('on')
plt.show()
def show_points(coords, labels, ax, marker_size=375):
pos_points = coords[labels==1]
neg_points = coords[labels==0]
ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
def show_masks_on_image(raw_image, masks, scores):
if len(masks.shape) == 4:
masks = masks.squeeze()
if scores.shape[0] == 1:
scores = scores.squeeze()
nb_predictions = scores.shape[-1]
fig, axes = plt.subplots(1, nb_predictions, figsize=(15, 15))
for i, (mask, score) in enumerate(zip(masks, scores)):
mask = mask.cpu().detach()
axes[i].imshow(np.array(raw_image))
show_mask(mask, axes[i])
axes[i].title.set_text(f"マスク {i+1}, スコア: {score.item():.3f}")
axes[i].axis("off")
plt.show()
def show_masks_on_single_image(raw_image, masks, scores):
if len(masks.shape) == 4:
masks = masks.squeeze()
if scores.shape[0] == 1:
scores = scores.squeeze()
# 画像がまだNumPy配列でない場合は変換
image_np = np.array(raw_image)
# 図を作成
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(image_np)
# 同じ画像にすべてのマスクをオーバーレイ
for i, (mask, score) in enumerate(zip(masks, scores)):
mask = mask.cpu().detach().numpy() # NumPyに変換
show_mask(mask, ax) # `show_mask`がマスクを適切にオーバーレイすると仮定
ax.set_title(f"スコア付きオーバーレイマスク")
ax.axis("off")
plt.show()
import torch
from transformers import SamHQModel, SamHQProcessor
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SamHQModel.from_pretrained("syscv-community/sam-hq-vit-huge").to(device)
processor = SamHQProcessor.from_pretrained("syscv-community/sam-hq-vit-huge")
from PIL import Image
import requests
img_url = "https://raw.githubusercontent.com/SysCV/sam-hq/refs/heads/main/demo/input_imgs/example1.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
plt.imshow(raw_image)
inputs = processor(raw_image, return_tensors="pt").to(device)
image_embeddings, intermediate_embeddings = model.get_image_embeddings(inputs["pixel_values"])
input_boxes = [[[306, 132, 925, 893]]]
show_boxes_on_image(raw_image, input_boxes[0])
inputs.pop("pixel_values", None)
inputs.update({"image_embeddings": image_embeddings})
inputs.update({"intermediate_embeddings": intermediate_embeddings})
with torch.no_grad():
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
show_masks_on_single_image(raw_image, masks[0], scores)
show_masks_on_image(raw_image, masks[0], scores)
引用
@misc{ke2023segmenthighquality,
title={Segment Anything in High Quality},
author={Lei Ke and Mingqiao Ye and Martin Danelljan and Yifan Liu and Yu-Wing Tai and Chi-Keung Tang and Fisher Yu},
year={2023},
eprint={2306.01567},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2306.01567},
}











