🚀 SlimSAM(SAM壓縮版,即“任意分割模型”)模型卡片
SlimSAM是任意分割模型(SAM)的壓縮(剪枝)版本,能夠根據點或框等輸入提示生成高質量的目標掩碼。
📚 詳細文檔
🔖 目錄
- 簡要概述
- 模型詳情
- 使用方法
- 引用信息
⏱️ 簡要概述
SlimSAM是任意分割模型(SAM)的壓縮(剪枝)版本,能夠根據點或框等輸入提示生成高質量的目標掩碼。
論文摘要指出:
任意分割模型(SAM)龐大的模型規模和高昂的計算需求,使其在資源受限的設備上部署變得困難。現有的SAM壓縮方法通常需要從頭開始訓練一個新的網絡,在壓縮成本和模型性能之間面臨著艱難的權衡。為了解決這個問題,本文提出了SlimSAM,一種新穎的SAM壓縮方法,它以極低的訓練成本實現了卓越的性能。這是通過一個統一的剪枝 - 蒸餾框架高效重用預訓練的SAM來實現的。為了增強從原始SAM繼承的知識,我們採用了一種創新的交替瘦身策略,將壓縮過程劃分為一個漸進的過程。與以往的剪枝技術不同,我們以交替的方式精心剪枝和蒸餾解耦的模型結構。此外,還提出了一種新穎的無標籤剪枝準則,使剪枝目標與優化目標保持一致,從而提高剪枝後的蒸餾效果。SlimSAM在性能上有顯著提升,同時訓練成本比任何現有方法低10倍以上。即使與原始的SAM - H相比,SlimSAM在性能接近的情況下,將參數數量減少到僅0.9%(570萬),MACs減少到0.8%(21G),並且只需要SAM訓練數據的0.1%(1萬)。
原始倉庫鏈接
免責聲明:本模型卡片的內容由Hugging Face團隊撰寫,部分內容從原始的SAM模型卡片複製粘貼而來。
🧐 模型詳情
SAM模型由3個模塊組成:
VisionEncoder
:基於VIT的圖像編碼器。它通過對圖像塊進行注意力計算來生成圖像嵌入,使用了相對位置嵌入。
PromptEncoder
:為點和邊界框生成嵌入。
MaskDecoder
:一種雙向變壓器,在圖像嵌入和點嵌入之間(->)以及點嵌入和圖像嵌入之間執行交叉注意力。輸出被輸入到下一步。
Neck
:根據MaskDecoder
生成的上下文掩碼預測輸出掩碼。
💻 使用示例
🌟 基礎用法
from PIL import Image
import requests
from transformers import SamModel, SamProcessor
model = SamModel.from_pretrained("nielsr/slimsam-77-uniform")
processor = SamProcessor.from_pretrained("nielsr/slimsam-77-uniform")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to("cuda")
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
在生成掩碼的其他參數中,您可以傳遞感興趣對象大致位置的二維座標、包圍感興趣對象的邊界框(格式應為邊界框右上角和左下角的x、y座標)、分割掩碼。在撰寫本文時,根據官方倉庫,官方模型不支持將文本作為輸入。
如需更多詳細信息,請參考此筆記本,其中展示瞭如何使用該模型的詳細步驟,並配有可視化示例!
from transformers import pipeline
generator = pipeline(task="mask-generation", model="nielsr/slimsam-77-uniform", device = 0, points_per_batch = 256)
image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch = 256)
現在展示圖像:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.show()
📑 引用信息
如果您使用此模型,請使用以下BibTeX條目。
@article{kirillov2023segany,
title={Segment Anything},
author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
journal={arXiv:2304.02643},
year={2023}
}
@misc{chen202301,
title={0.1% Data Makes Segment Anything Slim},
author={Zigeng Chen and Gongfan Fang and Xinyin Ma and Xinchao Wang},
year={2023},
eprint={2312.05284},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
📄 許可證
本模型採用Apache - 2.0許可證。