🚀 MambaVision:混合Mamba - 变压器视觉主干网络
MambaVision是首个融合Mamba和Transformer优势的计算机视觉混合模型,重新设计Mamba公式以高效建模视觉特征,通过实验验证在Mamba架构末层添加自注意力块可提升长程空间依赖建模能力,推出分层架构的模型家族,在图像分类和特征提取上表现出色。
🚀 快速开始
强烈建议通过运行以下命令来安装MambaVision的依赖项:
pip install mambavision
✨ 主要特性
- 创新架构:开发了首个结合Mamba和Transformer优势的计算机视觉混合模型。
- 高效特征建模:重新设计Mamba公式,增强其对视觉特征的高效建模能力。
- 实验验证:通过全面的消融实验,验证了在Mamba架构的最后几层配备自注意力块能显著提高捕捉长程空间依赖的建模能力。
- 性能卓越:在Top - 1准确率和吞吐量方面达到了新的SOTA帕累托前沿。
📦 安装指南
运行以下命令安装MambaVision的依赖项:
pip install mambavision
💻 使用示例
基础用法
图像分类
以下示例展示了如何使用MambaVision进行图像分类。以COCO数据集验证集中的图像为输入:
使用以下代码片段进行图像分类:
```Python
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-T2-1K", trust_remote_code=True)
eval mode for inference
model.cuda().eval()
prepare image for the model
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224) # MambaVision supports any input resolutions
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
model inference
outputs = model(inputs)
logits = outputs['logits']
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
预测标签为```brown bear, bruin, Ursus arctos.```
#### 特征提取
MambaVision也可用作通用特征提取器。可以提取模型每个阶段(共4个阶段)的输出以及最终的平均池化特征。
```Python
from transformers import AutoModel
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModel.from_pretrained("nvidia/MambaVision-T2-1K", trust_remote_code=True)
# eval mode for inference
model.cuda().eval()
# prepare image for the model
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224) # MambaVision supports any input resolutions
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
# model inference
out_avg_pool, features = model(inputs)
print("Size of the averaged pool features:", out_avg_pool.size()) # torch.Size([1, 640])
print("Number of stages in extracted features:", len(features)) # 4 stages
print("Size of extracted features in stage 1:", features[0].size()) # torch.Size([1, 80, 56, 56])
print("Size of extracted features in stage 4:", features[3].size()) # torch.Size([1, 640, 7, 7])
📚 详细文档
模型概述
我们开发了首个用于计算机视觉的混合模型,该模型充分利用了Mamba和Transformer的优势。具体而言,我们的核心贡献包括重新设计Mamba公式,以增强其对视觉特征进行高效建模的能力。此外,我们对将视觉Transformer(ViT)与Mamba集成的可行性进行了全面的消融研究。结果表明,在Mamba架构的最后几层配备几个自注意力块,可以大大提高捕捉长程空间依赖的建模能力。基于这些发现,我们推出了具有分层架构的MambaVision模型家族,以满足各种设计标准。
模型性能
MambaVision表现出色,在Top - 1准确率和吞吐量方面达到了新的SOTA帕累托前沿。
📄 许可证
NVIDIA源代码许可协议 - 非商业用途