🚀 MambaVision:混合Mamba-Transformer视觉骨干网络
MambaVision是首个用于计算机视觉的混合模型,结合了Mamba和Transformer的优势,重新设计Mamba公式以高效建模视觉特征,还研究了与ViT集成的可行性,推出分层架构的模型家族,满足不同设计需求。
🚀 快速开始
安装
强烈建议通过运行以下命令来安装MambaVision所需的依赖:
pip install mambavision
使用示例
基础用法
MambaVision可用于图像分类和特征提取,以下是具体示例:
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-L2-512-21K", trust_remote_code=True)
model.cuda().eval()
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 512, 512)
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
outputs = model(inputs)
logits = outputs['logits']
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
from transformers import AutoModel
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModel.from_pretrained("nvidia/MambaVision-L2-512-21K", trust_remote_code=True)
model.cuda().eval()
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 512, 512)
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
out_avg_pool, features = model(inputs)
print("Size of the averaged pool features:", out_avg_pool.size())
print("Number of stages in extracted features:", len(features))
print("Size of extracted features in stage 1:", features[0].size())
print("Size of extracted features in stage 4:", features[3].size())
✨ 主要特性
- 开发了首个结合Mamba和Transformer优势的计算机视觉混合模型。
- 重新设计Mamba公式,增强其对视觉特征的高效建模能力。
- 对Vision Transformers (ViT) 与Mamba集成的可行性进行了全面的消融研究。
- 提出了具有分层架构的MambaVision模型家族,以满足各种设计标准。
📦 安装指南
通过以下命令安装MambaVision:
pip install mambavision
📚 详细文档
模型概述
我们开发了首个用于计算机视觉的混合模型,该模型充分利用了Mamba和Transformer的优势。具体而言,我们的核心贡献包括重新设计Mamba公式,以增强其对视觉特征进行高效建模的能力。此外,我们对Vision Transformers (ViT) 与Mamba集成的可行性进行了全面的消融研究。结果表明,在Mamba架构的最后几层配备几个自注意力块,可以大大提高其捕捉长距离空间依赖关系的建模能力。基于这些发现,我们推出了具有分层架构的MambaVision模型家族,以满足各种设计标准。
模型性能
MambaVision-L2-512-21K在ImageNet-21K数据集上进行预训练,并在512 x 512分辨率的ImageNet-1K上进行微调。
名称 |
准确率@1(%) |
准确率@5(%) |
参数数量(M) |
浮点运算数(G) |
分辨率 |
MambaVision-L2-512-21K |
87.3 |
98.4 |
241.5 |
196.3 |
512x512 |
此外,MambaVision模型在Top-1准确率和吞吐量方面达到了新的SOTA Pareto前沿,表现出色。

模型使用
MambaVision可用于图像分类和特征提取,具体使用方法见上文的使用示例。
许可证
NVIDIA源代码许可协议 - 非商业用途
🔧 技术细节
- 提出了一种混合模型,结合了Mamba和Transformer的优势,用于计算机视觉任务。
- 重新设计了Mamba公式,以提高其对视觉特征的建模能力。
- 研究了Vision Transformers (ViT) 与Mamba集成的可行性,并通过实验证明了在Mamba架构的最后几层添加自注意力块可以提高其捕捉长距离空间依赖关系的能力。
- 推出了具有分层架构的MambaVision模型家族,以满足不同的设计需求。
📄 许可证
本项目采用 NVIDIA源代码许可协议 - 非商业用途。
信息表格