🚀 RobustSAM:在退化图像上实现稳健的任意分割
RobustSAM是一种针对退化图像进行优化的分割模型。它在保留SAM模型的可提示性和零样本泛化能力的基础上,显著提升了在低质量图像上的分割性能。同时,该模型只需少量参数增量和计算资源,具有较高的可行性和实用性。
🚀 快速开始
Segment Anything Model (SAM) 在图像分割领域展现出强大的零样本分割能力和灵活的提示系统,但在处理低质量图像时性能受限。为解决这一问题,我们提出了Robust Segment Anything Model (RobustSAM),它在提升SAM在低质量图像上性能的同时,保留了其可提示性和零样本泛化能力。
我们的方法基于预训练的SAM模型,仅需少量参数增量和计算资源。RobustSAM的额外参数可在8个GPU上30小时内完成优化,适合一般研究实验室使用。此外,我们还引入了包含688K对不同退化图像 - 掩码对的Robust - Seg数据集,用于模型的训练和评估。大量实验表明,RobustSAM在各种分割任务和数据集上表现出色,尤其是在零样本条件下,具有广泛的实际应用潜力。同时,该方法还能有效提升基于SAM的下游任务(如单图像去雾和去模糊)的性能。

免责声明:本模型卡片的内容由Hugging Face团队撰写,部分内容从原始的 SAM模型卡片 复制粘贴而来。
✨ 主要特性
- 性能提升:显著增强了SAM在低质量图像上的分割性能。
- 资源高效:只需少量参数增量和计算资源。
- 数据集丰富:引入Robust - Seg数据集,包含688K对不同退化图像 - 掩码对。
- 应用广泛:可有效提升基于SAM的下游任务性能。
📦 安装指南
由于文档未提供具体安装命令,此部分跳过。
💻 使用示例
基础用法
from PIL import Image
import requests
from transformers import AutoProcessor, AutoModelForMaskGeneration
processor = AutoProcessor.from_pretrained("jadechoghari/robustsam-vit-base")
model = AutoModelForMaskGeneration.from_pretrained("jadechoghari/robustsam-vit-base")
img_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[450, 600]]]
高级用法
inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to("cuda")
with torch.no_grad():
outputs = model(**inputs)
masks = processor.image_processor.post_process_masks(outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
scores = outputs.iou_scores
在生成掩码时,除了上述代码中的输入点,你还可以传入感兴趣对象的二维位置、包围感兴趣对象的边界框(格式应为边界框左上角和右下角的x、y坐标)、分割掩码。根据 官方仓库,目前官方模型不支持将文本作为输入。更多详细信息,请参考相关笔记本,其中有可视化示例展示了如何使用该模型。
自动掩码生成
from transformers import pipeline
generator = pipeline("mask-generation", model="jadechoghari/robustsam-vit-base", device=0, points_per_batch=256)
image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
outputs = generator(image_url, points_per_batch=256)
以下代码用于在图像上显示生成的掩码:
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
plt.imshow(np.array(raw_image))
ax = plt.gca()
for mask in outputs["masks"]:
show_mask(mask, ax=ax, random_color=True)
plt.axis("off")
plt.show()
视觉对比

📚 详细文档
模型细节
RobustSAM模型由以下4个模块组成:
VisionEncoder
:基于VIT的图像编码器。它使用注意力机制对图像块进行处理,计算图像嵌入,并使用相对位置嵌入。
PromptEncoder
:为点和边界框生成嵌入。
MaskDecoder
:双向Transformer,在图像嵌入和点嵌入之间进行交叉注意力计算(->),并在点嵌入和图像嵌入之间进行交叉注意力计算。输出结果将被进一步处理。
Neck
:根据MaskDecoder
生成的上下文掩码预测输出掩码。
🔧 技术细节
本方法基于预训练的SAM模型,通过少量参数增量和计算资源的投入,提升了模型在低质量图像上的性能。RobustSAM的额外参数可在8个GPU上30小时内完成优化,证明了其在一般研究实验室中的可行性和实用性。同时,引入的Robust - Seg数据集为模型的训练和评估提供了丰富的数据支持。
📄 许可证
本项目采用MIT许可证。
📖 引用
如果您觉得本工作有用,请考虑引用我们:
@inproceedings{chen2024robustsam,
title={RobustSAM: Segment Anything Robustly on Degraded Images},
author={Chen, Wei - Ting and Vong, Yu - Jiet and Kuo, Sy - Yen and Ma, Sizhou and Wang, Jian},
journal={CVPR},
year={2024}
}
🙏 致谢
感谢 SAM 的作者,我们的仓库基于该项目开发。