🚀 分割一切模型(SAM) - ViT大模型(ViT - L)版本
分割一切模型(SAM) 可以根据点或框等输入提示生成高质量的对象掩码,还能为图像中的所有对象生成掩码。它在包含1100万张图像和11亿个掩码的数据集上进行了训练,在各种分割任务中具有出色的零样本性能。
🚀 快速开始
模型简介
分割一切模型(SAM) 能根据点或框等输入提示生成高质量的对象掩码,还可用于为图像中的所有对象生成掩码。它在一个包含1100万张图像和11亿个掩码的数据集上进行了训练,在各种分割任务中具有强大的零样本性能。
论文摘要指出:
我们推出了分割一切(SA)项目:一个用于图像分割的新任务、模型和数据集。通过在数据收集循环中使用我们高效的模型,我们构建了迄今为止最大的分割数据集(远超以往),在1100万张经过授权且尊重隐私的图像上有超过10亿个掩码。该模型设计并训练为可提示的,因此它可以零样本迁移到新的图像分布和任务。我们在众多任务上评估了其能力,发现其零样本性能令人印象深刻 —— 通常与之前的全监督结果具有竞争力,甚至更优。我们在 https://segment-anything.com 发布了分割一切模型(SAM)以及包含10亿个掩码和1100万张图像的相应数据集(SA - 1B),以促进计算机视觉基础模型的研究。
免责声明:此模型卡片的内容由Hugging Face团队编写,部分内容从原始的 SAM模型卡片 复制粘贴而来。
模型架构图
分割一切模型(SAM)的详细架构。
示例图片
原仓库链接
原仓库链接
✨ 主要特性
- 可根据点或框等输入提示生成高质量的对象掩码。
- 能够为图像中的所有对象生成掩码。
- 在大规模数据集上训练,具有强大的零样本性能。
📚 详细文档
模型细节
SAM模型由3个模块组成:
VisionEncoder
:一个基于VIT的图像编码器。它使用注意力机制对图像块进行计算,得到图像嵌入。使用了相对位置嵌入。
PromptEncoder
:为点和边界框生成嵌入。
MaskDecoder
:一个双向变压器,在图像嵌入和点嵌入之间进行交叉注意力计算(->),并在点嵌入和图像嵌入之间进行交叉注意力计算。输出结果会被进一步处理。
Neck
:根据MaskDecoder
生成的上下文掩码预测输出掩码。
使用方法
提示掩码生成
from PIL import Image
import requests
from transformers import SamModel, SamProcessor
model = SamModel.from_pretrained("facebook/sam-vit-large")
processor = SamProcessor.from_pretrained("facebook/sam-vit-large")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to("cuda")
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
在生成掩码的其他参数中,你可以传递感兴趣对象大致位置的二维坐标、包裹感兴趣对象的边界框(格式应为边界框右上角和左下角的x、y坐标)、分割掩码。在撰写本文时,根据 官方仓库,官方模型不支持将文本作为输入。
更多详细信息,请参考这个笔记本,它通过可视化示例展示了如何使用该模型!
自动掩码生成
该模型可以在给定输入图像的情况下,以“零样本”方式生成分割掩码。模型会自动用一个包含1024
个点的网格进行提示,并将这些点全部输入到模型中。
以下管道用于自动掩码生成。以下代码片段展示了如何轻松运行它(在任何设备上!只需提供适当的points_per_batch
参数)
from transformers import pipeline
generator = pipeline("mask-generation", device = 0, points_per_batch = 256)
image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch = 256)
现在来显示图像:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.show()
引用
如果您使用此模型,请使用以下BibTeX条目进行引用。
@article{kirillov2023segany,
title={Segment Anything},
author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
journal={arXiv:2304.02643},
year={2023}
}
📄 许可证
本项目采用Apache - 2.0许可证。