🚀 MambaVision:混合Mamba-Transformer視覺骨幹網絡
MambaVision是首個結合Mamba和Transformer優勢的計算機視覺混合模型,重新設計Mamba公式以高效建模視覺特徵,在圖像分類任務中表現出色。
🚀 快速開始
安裝依賴
我們提供了一個 Docker文件。此外,假設已經安裝了最新的 PyTorch 包,可以通過運行以下命令來安裝依賴項:
pip install -r requirements.txt
也可以直接運行以下命令安裝MambaVision所需的依賴:
pip install mambavision
✨ 主要特性
我們開發了首個用於計算機視覺的混合模型,充分利用了Mamba和Transformer的優勢。具體而言,我們的核心貢獻包括重新設計Mamba公式,以增強其對視覺特徵進行高效建模的能力。此外,我們對將視覺Transformer(ViT)與Mamba集成的可行性進行了全面的消融研究。結果表明,在Mamba架構的最後幾層配備幾個自注意力塊,可以大大提高其捕捉長距離空間依賴關係的建模能力。基於這些發現,我們推出了一系列具有分層架構的MambaVision模型,以滿足各種設計標準。
💻 使用示例
基礎用法
圖像分類
在以下示例中,我們展示瞭如何使用MambaVision進行圖像分類。以 COCO數據集 驗證集中的一張圖像作為輸入:
可以使用以下代碼片段進行圖像分類:
```python
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-L3-512-21K", trust_remote_code=True)
eval mode for inference
model.cuda().eval()
prepare image for the model
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 512, 512) # MambaVision supports any input resolutions
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
model inference
outputs = model(inputs)
logits = outputs['logits']
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
預測的標籤是 ```brown bear, bruin, Ursus arctos.```
#### 特徵提取
MambaVision還可以用作通用特徵提取器。具體來說,我們可以提取模型每個階段(共4個階段)的輸出以及最終的平均池化特徵(已展平)。可以使用以下代碼片段進行特徵提取:
```python
from transformers import AutoModel
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModel.from_pretrained("nvidia/MambaVision-L3-512-21K", trust_remote_code=True)
# eval mode for inference
model.cuda().eval()
# prepare image for the model
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 512, 512) # MambaVision supports any input resolutions
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
# model inference
out_avg_pool, features = model(inputs)
print("Size of the averaged pool features:", out_avg_pool.size()) # torch.Size([1, 1568])
print("Number of stages in extracted features:", len(features)) # 4 stages
print("Size of extracted features in stage 1:", features[0].size()) # torch.Size([1, 196, 128, 128])
print("Size of extracted features in stage 4:", features[3].size()) # torch.Size([1, 1568, 16, 16])
📚 詳細文檔
模型性能
MambaVision-L3-512-21K在ImageNet-21K數據集上進行預訓練,並在512 x 512分辨率的ImageNet-1K數據集上進行微調。
名稱 |
準確率@1(%) |
準確率@5(%) |
參數數量(M) |
浮點運算次數(G) |
分辨率 |
MambaVision-L3-512-21K |
88.1 |
98.6 |
739.6 |
489.1 |
512x512 |
此外,MambaVision模型在Top-1準確率和吞吐量方面達到了新的SOTA帕累託前沿,表現出色。
預訓練模型結果
ImageNet-21K
名稱 |
準確率@1(%) |
準確率@5(%) |
參數數量(M) |
浮點運算次數(G) |
分辨率 |
Hugging Face鏈接 |
下載鏈接 |
MambaVision-B-21K |
84.9 |
97.5 |
97.7 |
15.0 |
224x224 |
鏈接 |
模型 |
MambaVision-L-21K |
86.1 |
97.9 |
227.9 |
34.9 |
224x224 |
鏈接 |
模型 |
MambaVision-L2-512-21K |
87.3 |
98.4 |
241.5 |
196.3 |
512x512 |
鏈接 |
模型 |
MambaVision-L3-256-21K |
87.3 |
98.3 |
739.6 |
122.3 |
256x256 |
鏈接 |
模型 |
MambaVision-L3-512-21K |
88.1 |
98.6 |
739.6 |
489.1 |
512x512 |
鏈接 |
模型 |
ImageNet-1K
名稱 |
準確率@1(%) |
準確率@5(%) |
吞吐量(圖像/秒) |
分辨率 |
參數數量(M) |
浮點運算次數(G) |
Hugging Face鏈接 |
下載鏈接 |
MambaVision-T |
82.3 |
96.2 |
6298 |
224x224 |
31.8 |
4.4 |
鏈接 |
模型 |
MambaVision-T2 |
82.7 |
96.3 |
5990 |
224x224 |
35.1 |
5.1 |
鏈接 |
模型 |
MambaVision-S |
83.3 |
96.5 |
4700 |
224x224 |
50.1 |
7.5 |
鏈接 |
模型 |
MambaVision-B |
84.2 |
96.9 |
3670 |
224x224 |
97.7 |
15.0 |
鏈接 |
模型 |
MambaVision-L |
85.0 |
97.1 |
2190 |
224x224 |
227.9 |
34.9 |
鏈接 |
模型 |
MambaVision-L2 |
85.3 |
97.2 |
1021 |
224x224 |
241.5 |
37.5 |
鏈接 |
模型 |
📄 許可證
英偉達源代碼許可協議 - 非商業用途
參考資料