🚀 MambaVision: A Hybrid Mamba-Transformer Vision Backbone
This project presents the first hybrid model for computer vision that combines the advantages of Mamba and Transformers, addressing the need for more efficient visual feature modeling and the capture of long-range spatial dependencies in computer vision tasks.
🚀 Quick Start
To quickly get started with MambaVision, install the required package:

```bash
pip install mambavision
```
✨ Features
Model Overview
We have developed the first hybrid model for computer vision that leverages the strengths of Mamba and Transformers. Specifically, our core contribution is a redesign of the Mamba formulation to enhance its capability for efficient modeling of visual features. In addition, we conducted a comprehensive ablation study on the feasibility of integrating Vision Transformers (ViT) with Mamba. Our results demonstrate that equipping the Mamba architecture with several self-attention blocks at the final layers greatly improves its capacity to capture long-range spatial dependencies. Based on these findings, we introduce a family of MambaVision models with a hierarchical architecture to meet various design criteria.
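For intuition, here is a minimal PyTorch sketch of that hybrid stage layout: Mamba-style mixer blocks followed by self-attention blocks at the end of the stage. The `MixerBlock` below is a hypothetical MLP stand-in (the real model uses a redesigned Mamba/SSM block), and none of these class names come from the MambaVision codebase.

```python
import torch
import torch.nn as nn

class AttentionBlock(nn.Module):
    """Pre-norm multi-head self-attention with a residual connection."""
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x):                      # x: (batch, tokens, dim)
        h = self.norm(x)
        return x + self.attn(h, h, h)[0]

class MixerBlock(nn.Module):
    """Stand-in for a MambaVision mixer block; the real model uses a
    redesigned Mamba (SSM) formulation, which is omitted here."""
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(),
                                 nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.mlp(self.norm(x))

class HybridStage(nn.Module):
    """A stage whose final blocks are self-attention, mirroring the finding
    that attention at the end of a stage helps long-range dependencies."""
    def __init__(self, dim, depth=8, num_attn=2):
        super().__init__()
        self.blocks = nn.ModuleList(
            [MixerBlock(dim) for _ in range(depth - num_attn)]
            + [AttentionBlock(dim) for _ in range(num_attn)]
        )

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x

# Toy usage: 196 tokens of width 256 through one hybrid stage
stage = HybridStage(dim=256)
print(stage(torch.randn(2, 196, 256)).shape)   # torch.Size([2, 196, 256])
```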
Model Performance
MambaVision-L2-512-21K is pretrained on the ImageNet-21K dataset and fine-tuned on ImageNet-1K at 512 x 512 resolution.
| Property | Details |
|---|---|
| Model Type | Image Classification |
| Training Data | ILSVRC/imagenet-21K |
The following table shows the performance metrics:
| Name | Acc@1(%) | Acc@5(%) | #Params(M) | FLOPs(G) | Resolution |
|---|---|---|---|---|---|
| MambaVision-L2-512-21K | 87.3 | 98.4 | 241.5 | 196.3 | 512x512 |
In addition, the MambaVision models achieve a new SOTA Pareto front for Top-1 accuracy versus throughput.
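Throughput in such comparisons is typically measured as images per second at a fixed batch size. The sketch below shows one rough way to measure it; the batch size, resolution, and loop counts are illustrative assumptions, not the paper's exact benchmarking protocol.

```python
import time
import torch

@torch.no_grad()
def measure_throughput(model, resolution=512, batch_size=32, iters=50, warmup=10):
    """Rough images/sec estimate on GPU; illustrative protocol only."""
    model.cuda().eval()
    x = torch.randn(batch_size, 3, resolution, resolution, device="cuda")
    for _ in range(warmup):          # warm up CUDA kernels and allocator
        model(x)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        model(x)
    torch.cuda.synchronize()         # wait for all queued work to finish
    return batch_size * iters / (time.time() - start)

# Usage: measure_throughput(model) with a model loaded as in the examples below
```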

📦 Installation
To install MambaVision, run the following command:

```bash
pip install mambavision
```
💻 Usage Examples
Basic Usage
Image Classification
In the following example, we demonstrate how MambaVision can be used for image classification, given an image from the COCO validation set as input:

The following Python code can be used for image classification:
```python
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests

# Load the model with its custom code from the Hugging Face Hub
model = AutoModelForImageClassification.from_pretrained(
    "nvidia/MambaVision-L2-512-21K", trust_remote_code=True
)
model.cuda().eval()

# Fetch a sample image from the COCO validation set
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)

# Build the evaluation transform from the model's preprocessing config
input_resolution = (3, 512, 512)
transform = create_transform(input_size=input_resolution,
                             is_training=False,
                             mean=model.config.mean,
                             std=model.config.std,
                             crop_mode=model.config.crop_mode,
                             crop_pct=model.config.crop_pct)

# Run inference and map the top logit to its class label
inputs = transform(image).unsqueeze(0).cuda()
outputs = model(inputs)
logits = outputs['logits']
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
```
The predicted label is `brown bear, bruin, Ursus arctos`.
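As a small extension of the example above (not part of the original snippet), the top-5 classes and their probabilities can be read off the same logits with a softmax:

```python
import torch

# Top-5 predictions with probabilities, reusing `logits` and `model` from above
probs = torch.softmax(logits, dim=-1)
top5 = probs.topk(5, dim=-1)
for p, idx in zip(top5.values[0].tolist(), top5.indices[0].tolist()):
    print(f"{model.config.id2label[idx]}: {p:.3f}")
```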
Feature Extraction
MambaVision can also be used as a generic feature extractor. Specifically, we can extract the outputs of each of the model's four stages, as well as the final average-pooled and flattened features.
The following Python code can be used for feature extraction:
```python
from transformers import AutoModel
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests

# Load the backbone (no classification head) with its custom code
model = AutoModel.from_pretrained(
    "nvidia/MambaVision-L2-512-21K", trust_remote_code=True
)
model.cuda().eval()

# Fetch a sample image from the COCO validation set
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)

# Build the evaluation transform from the model's preprocessing config
input_resolution = (3, 512, 512)
transform = create_transform(input_size=input_resolution,
                             is_training=False,
                             mean=model.config.mean,
                             std=model.config.std,
                             crop_mode=model.config.crop_mode,
                             crop_pct=model.config.crop_pct)

# Extract the average-pooled features and the per-stage feature maps
inputs = transform(image).unsqueeze(0).cuda()
out_avg_pool, features = model(inputs)
print("Size of the averaged pool features:", out_avg_pool.size())
print("Number of stages in extracted features:", len(features))
print("Size of extracted features in stage 1:", features[0].size())
print("Size of extracted features in stage 4:", features[3].size())
```
📄 License
This project is licensed under the NVIDIA Source Code License - NC.