🚀 RobustSAM:在退化圖像上實現穩健的任意分割
RobustSAM是一種針對退化圖像進行優化的分割模型。它在保留SAM模型的可提示性和零樣本泛化能力的基礎上,顯著提升了在低質量圖像上的分割性能。同時,該模型只需少量參數增量和計算資源,具有較高的可行性和實用性。
🚀 快速開始
Segment Anything Model (SAM) 在圖像分割領域展現出強大的零樣本分割能力和靈活的提示系統,但在處理低質量圖像時性能受限。為解決這一問題,我們提出了Robust Segment Anything Model (RobustSAM),它在提升SAM在低質量圖像上性能的同時,保留了其可提示性和零樣本泛化能力。
我們的方法基於預訓練的SAM模型,僅需少量參數增量和計算資源。RobustSAM的額外參數可在8個GPU上30小時內完成優化,適合一般研究實驗室使用。此外,我們還引入了包含688K對不同退化圖像 - 掩碼對的Robust - Seg數據集,用於模型的訓練和評估。大量實驗表明,RobustSAM在各種分割任務和數據集上表現出色,尤其是在零樣本條件下,具有廣泛的實際應用潛力。同時,該方法還能有效提升基於SAM的下游任務(如單圖像去霧和去模糊)的性能。

免責聲明:本模型卡片的內容由Hugging Face團隊撰寫,部分內容從原始的 SAM模型卡片 複製粘貼而來。
✨ 主要特性
- 性能提升:顯著增強了SAM在低質量圖像上的分割性能。
- 資源高效:只需少量參數增量和計算資源。
- 數據集豐富:引入Robust - Seg數據集,包含688K對不同退化圖像 - 掩碼對。
- 應用廣泛:可有效提升基於SAM的下游任務性能。
📦 安裝指南
由於文檔未提供具體安裝命令,此部分跳過。
💻 使用示例
基礎用法
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModelForMaskGeneration
processor = AutoProcessor.from_pretrained("jadechoghari/robustsam-vit-base")
model = AutoModelForMaskGeneration.from_pretrained("jadechoghari/robustsam-vit-base")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]
高級用法
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to("cuda")
with torch.no_grad():
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
在生成掩碼時,除了上述代碼中的輸入點,你還可以傳入感興趣對象的二維位置、包圍感興趣對象的邊界框(格式應為邊界框左上角和右下角的x、y座標)、分割掩碼。根據 官方倉庫,目前官方模型不支持將文本作為輸入。更多詳細信息,請參考相關筆記本,其中有可視化示例展示瞭如何使用該模型。
自動掩碼生成
from transformers import pipeline
generator = pipeline("mask-generation", model="jadechoghari/robustsam-vit-base", device=0, points_per_batch=256)
image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch=256)
以下代碼用於在圖像上顯示生成的掩碼:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.show()
視覺對比

📚 詳細文檔
模型細節
RobustSAM模型由以下4個模塊組成:
VisionEncoder
:基於VIT的圖像編碼器。它使用注意力機制對圖像塊進行處理,計算圖像嵌入,並使用相對位置嵌入。
PromptEncoder
:為點和邊界框生成嵌入。
MaskDecoder
:雙向Transformer,在圖像嵌入和點嵌入之間進行交叉注意力計算(->),並在點嵌入和圖像嵌入之間進行交叉注意力計算。輸出結果將被進一步處理。
Neck
:根據MaskDecoder
生成的上下文掩碼預測輸出掩碼。
🔧 技術細節
本方法基於預訓練的SAM模型,通過少量參數增量和計算資源的投入,提升了模型在低質量圖像上的性能。RobustSAM的額外參數可在8個GPU上30小時內完成優化,證明了其在一般研究實驗室中的可行性和實用性。同時,引入的Robust - Seg數據集為模型的訓練和評估提供了豐富的數據支持。
📄 許可證
本項目採用MIT許可證。
📖 引用
如果您覺得本工作有用,請考慮引用我們:
@inproceedings{chen2024robustsam,
title={RobustSAM: Segment Anything Robustly on Degraded Images},
author={Chen, Wei - Ting and Vong, Yu - Jiet and Kuo, Sy - Yen and Ma, Sizhou and Wang, Jian},
journal={CVPR},
year={2024}
}
🙏 致謝
感謝 SAM 的作者,我們的倉庫基於該項目開發。