🚀 MambaVision:混合Mamba-Transformer视觉主干网络
MambaVision是一种专门为视觉应用设计的新型混合Mamba-Transformer主干网络。它重新设计了Mamba公式,增强了对视觉特征的高效建模能力,在图像分类和下游任务中表现出色。
🚀 快速开始
安装
强烈建议通过运行以下命令来安装MambaVision的依赖项:
pip install mambavision
✨ 主要特性
- 提出了一种新颖的混合Mamba-Transformer主干网络MambaVision,专为视觉应用量身定制。
- 重新设计Mamba公式,增强其对视觉特征的高效建模能力。
- 对Vision Transformers (ViT) 与Mamba集成的可行性进行了全面的消融研究。
- 在Mamba架构的最后几层配备几个自注意力块,大大提高了捕捉长距离空间依赖的建模能力。
- 引入了具有分层架构的MambaVision模型家族,以满足各种设计标准。
- 在ImageNet - 1K数据集的图像分类任务中,MambaVision模型变体在Top - 1准确率和图像吞吐量方面达到了新的最优性能。
- 在MS COCO和ADE20K数据集的目标检测、实例分割和语义分割等下游任务中,MambaVision优于同等规模的主干网络。
📚 详细文档
模型描述
我们提出了一种新颖的混合Mamba - Transformer主干网络,称为MambaVision,专门用于视觉应用。我们的核心贡献包括重新设计Mamba公式,以增强其对视觉特征的高效建模能力。此外,我们对Vision Transformers (ViT) 与Mamba集成的可行性进行了全面的消融研究。结果表明,在Mamba架构的最后几层配备几个自注意力块,大大提高了捕捉长距离空间依赖的建模能力。基于这些发现,我们引入了具有分层架构的MambaVision模型家族,以满足各种设计标准。在ImageNet - 1K数据集的图像分类任务中,MambaVision模型变体在Top - 1准确率和图像吞吐量方面达到了新的最优 (SOTA) 性能。在MS COCO和ADE20K数据集的目标检测、实例分割和语义分割等下游任务中,MambaVision优于同等规模的主干网络,表现更出色。代码链接:https://github.com/NVlabs/MambaVision 。
模型性能
MambaVision - L - 21K在ImageNet - 21K数据集上进行预训练,并在ImageNet - 1K上进行微调。
名称 |
准确率@1(%) |
准确率@5(%) |
参数数量(M) |
浮点运算次数(G) |
分辨率 |
MambaVision - L - 21K |
86.1 |
97.9 |
227.9 |
34.9 |
224x224 |
此外,MambaVision模型在Top - 1准确率和吞吐量方面达到了新的SOTA帕累托前沿,表现强劲。
模型使用
基础用法
图像分类
以下示例展示了如何使用MambaVision进行图像分类。给定来自COCO数据集验证集的图像作为输入:
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-L-21K", trust_remote_code=True)
model.cuda().eval()
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224)
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
outputs = model(inputs)
logits = outputs['logits']
predicted_class_idx = logits.argmax(-1).item()
print("预测类别:", model.config.id2label[predicted_class_idx])
预测标签为brown bear, bruin, Ursus arctos.
特征提取
MambaVision也可用作通用特征提取器。具体来说,我们可以提取模型每个阶段(4个阶段)的输出以及最终的平均池化特征(已展平)。
from transformers import AutoModel
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModel.from_pretrained("nvidia/MambaVision-L-21K", trust_remote_code=True)
model.cuda().eval()
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224)
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
out_avg_pool, features = model(inputs)
print("平均池化特征的大小:", out_avg_pool.size())
print("提取特征的阶段数:", len(features))
print("第1阶段提取特征的大小:", features[0].size())
print("第4阶段提取特征的大小:", features[3].size())
许可证
NVIDIA源代码许可协议 - 非商业用途
信息表格
属性 |
详情 |
数据集 |
ILSVRC/imagenet - 21k |
许可证 |
其他(NVIDIA Source Code License - NC) |
许可证名称 |
nvclv1 |
许可证链接 |
LICENSE |
任务类型 |
图像分类 |
库名称 |
transformers |