🚀 SlimSAM(SAM压缩版,即“任意分割模型”)模型卡片
SlimSAM是任意分割模型(SAM)的压缩(剪枝)版本,能够根据点或框等输入提示生成高质量的目标掩码。
📚 详细文档
🔖 目录
- 简要概述
- 模型详情
- 使用方法
- 引用信息
⏱️ 简要概述
SlimSAM是任意分割模型(SAM)的压缩(剪枝)版本,能够根据点或框等输入提示生成高质量的目标掩码。
论文摘要指出:
任意分割模型(SAM)庞大的模型规模和高昂的计算需求,使其在资源受限的设备上部署变得困难。现有的SAM压缩方法通常需要从头开始训练一个新的网络,在压缩成本和模型性能之间面临着艰难的权衡。为了解决这个问题,本文提出了SlimSAM,一种新颖的SAM压缩方法,它以极低的训练成本实现了卓越的性能。这是通过一个统一的剪枝 - 蒸馏框架高效重用预训练的SAM来实现的。为了增强从原始SAM继承的知识,我们采用了一种创新的交替瘦身策略,将压缩过程划分为一个渐进的过程。与以往的剪枝技术不同,我们以交替的方式精心剪枝和蒸馏解耦的模型结构。此外,还提出了一种新颖的无标签剪枝准则,使剪枝目标与优化目标保持一致,从而提高剪枝后的蒸馏效果。SlimSAM在性能上有显著提升,同时训练成本比任何现有方法低10倍以上。即使与原始的SAM - H相比,SlimSAM在性能接近的情况下,将参数数量减少到仅0.9%(570万),MACs减少到0.8%(21G),并且只需要SAM训练数据的0.1%(1万)。
原始仓库链接
免责声明:本模型卡片的内容由Hugging Face团队撰写,部分内容从原始的SAM模型卡片复制粘贴而来。
🧐 模型详情
SAM模型由3个模块组成:
VisionEncoder
:基于VIT的图像编码器。它通过对图像块进行注意力计算来生成图像嵌入,使用了相对位置嵌入。
PromptEncoder
:为点和边界框生成嵌入。
MaskDecoder
:一种双向变压器,在图像嵌入和点嵌入之间(->)以及点嵌入和图像嵌入之间执行交叉注意力。输出被输入到下一步。
Neck
:根据MaskDecoder
生成的上下文掩码预测输出掩码。
💻 使用示例
🌟 基础用法
from PIL import Image
import requests
from transformers import SamModel, SamProcessor
model = SamModel.from_pretrained("nielsr/slimsam-77-uniform")
processor = SamProcessor.from_pretrained("nielsr/slimsam-77-uniform")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to("cuda")
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
在生成掩码的其他参数中,您可以传递感兴趣对象大致位置的二维坐标、包围感兴趣对象的边界框(格式应为边界框右上角和左下角的x、y坐标)、分割掩码。在撰写本文时,根据官方仓库,官方模型不支持将文本作为输入。
如需更多详细信息,请参考此笔记本,其中展示了如何使用该模型的详细步骤,并配有可视化示例!
from transformers import pipeline
generator = pipeline(task="mask-generation", model="nielsr/slimsam-77-uniform", device = 0, points_per_batch = 256)
image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch = 256)
现在展示图像:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.show()
📑 引用信息
如果您使用此模型,请使用以下BibTeX条目。
@article{kirillov2023segany,
title={Segment Anything},
author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
journal={arXiv:2304.02643},
year={2023}
}
@misc{chen202301,
title={0.1% Data Makes Segment Anything Slim},
author={Zigeng Chen and Gongfan Fang and Xinyin Ma and Xinchao Wang},
year={2023},
eprint={2312.05284},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
📄 许可证
本模型采用Apache - 2.0许可证。