🚀 分割一切模型(SAM) - ViT 巨型(ViT - H)版本
分割一切模型(SAM)能夠根據點或框等輸入提示生成高質量的對象掩碼,還可用於為圖像中的所有對象生成掩碼。它在包含 1100 萬張圖像和 11 億個掩碼的數據集上進行訓練,在各種分割任務中具有出色的零樣本性能。
🚀 快速開始
模型概覽
原始倉庫鏈接
該論文的摘要指出:
我們推出了分割一切(SA)項目:一個用於圖像分割的新任務、模型和數據集。通過在數據收集循環中使用高效模型,我們構建了迄今為止最大的分割數據集(遠超以往),在 1100 萬張有許可且尊重隱私的圖像上擁有超過 11 億個掩碼。該模型的設計和訓練具有可提示性,因此它可以零樣本遷移到新的圖像分佈和任務。我們在眾多任務上評估了其能力,發現其零樣本性能令人印象深刻 —— 通常與之前的全監督結果相媲美,甚至更優。我們在 https://segment-anything.com 上發佈了分割一切模型(SAM)以及包含 11 億個掩碼和 1100 萬張圖像的相應數據集(SA - 1B),以促進計算機視覺基礎模型的研究。
免責聲明:此模型卡片的內容由 Hugging Face 團隊編寫,部分內容從原始 SAM 模型卡片 複製粘貼而來。
模型架構
分割一切模型(SAM)的詳細架構。
SAM 模型由 3 個模塊組成:
VisionEncoder
:一個基於 VIT 的圖像編碼器。它使用注意力機制對圖像塊進行計算,以得到圖像嵌入,並使用了相對位置嵌入。
PromptEncoder
:為點和邊界框生成嵌入。
MaskDecoder
:一個雙向變壓器,在圖像嵌入和點嵌入之間進行交叉注意力計算(->),並在點嵌入和圖像嵌入之間進行交叉注意力計算。輸出結果會被進一步處理。
Neck
:根據 MaskDecoder
生成的上下文掩碼預測輸出掩碼。
💻 使用示例
基礎用法 - 提示掩碼生成
from PIL import Image
import requests
from transformers import SamModel, SamProcessor
model = SamModel.from_pretrained("facebook/sam-vit-huge")
processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to("cuda")
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
在生成掩碼的其他參數中,你可以傳入感興趣對象的大致 2D 位置、包裹感興趣對象的邊界框(格式應為邊界框右上角和左下角的 x、y 座標)、分割掩碼。在撰寫本文時,根據 官方倉庫,官方模型不支持將文本作為輸入。
更多詳細信息,請參考這個筆記本,它通過可視化示例展示瞭如何使用該模型!
高級用法 - 自動掩碼生成
該模型可用於以“零樣本”方式為輸入圖像生成分割掩碼。模型會自動使用一個包含 1024
個點的網格進行提示,並將這些點全部輸入到模型中。
以下是自動掩碼生成的管道示例,展示瞭如何輕鬆運行它(可在任何設備上運行!只需傳入適當的 points_per_batch
參數):
from transformers import pipeline
generator = pipeline("mask-generation", device = 0, points_per_batch = 256)
image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch = 256)
現在來顯示圖像:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.show()
這應該會得到如下結果:
📄 許可證
本項目採用 Apache - 2.0 許可證。
📚 引用
如果您使用此模型,請使用以下 BibTeX 條目進行引用。
@article{kirillov2023segany,
title={Segment Anything},
author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
journal={arXiv:2304.02643},
year={2023}
}