🚀 MambaVision:混合Mamba - 變壓器視覺主幹網絡
MambaVision是首個融合Mamba和Transformer優勢的計算機視覺混合模型,重新設計Mamba公式以高效建模視覺特徵,通過實驗驗證在Mamba架構末層添加自注意力塊可提升長程空間依賴建模能力,推出分層架構的模型家族,在圖像分類和特徵提取上表現出色。
🚀 快速開始
強烈建議通過運行以下命令來安裝MambaVision的依賴項:
pip install mambavision
✨ 主要特性
- 創新架構:開發了首個結合Mamba和Transformer優勢的計算機視覺混合模型。
- 高效特徵建模:重新設計Mamba公式,增強其對視覺特徵的高效建模能力。
- 實驗驗證:通過全面的消融實驗,驗證了在Mamba架構的最後幾層配備自注意力塊能顯著提高捕捉長程空間依賴的建模能力。
- 性能卓越:在Top - 1準確率和吞吐量方面達到了新的SOTA帕累託前沿。
📦 安裝指南
運行以下命令安裝MambaVision的依賴項:
pip install mambavision
💻 使用示例
基礎用法
圖像分類
以下示例展示瞭如何使用MambaVision進行圖像分類。以COCO數據集驗證集中的圖像為輸入:
使用以下代碼片段進行圖像分類:
```Python
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-T2-1K", trust_remote_code=True)
eval mode for inference
model.cuda().eval()
prepare image for the model
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224) # MambaVision supports any input resolutions
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
model inference
outputs = model(inputs)
logits = outputs['logits']
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
預測標籤為```brown bear, bruin, Ursus arctos.```
#### 特徵提取
MambaVision也可用作通用特徵提取器。可以提取模型每個階段(共4個階段)的輸出以及最終的平均池化特徵。
```Python
from transformers import AutoModel
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModel.from_pretrained("nvidia/MambaVision-T2-1K", trust_remote_code=True)
# eval mode for inference
model.cuda().eval()
# prepare image for the model
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224) # MambaVision supports any input resolutions
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
# model inference
out_avg_pool, features = model(inputs)
print("Size of the averaged pool features:", out_avg_pool.size()) # torch.Size([1, 640])
print("Number of stages in extracted features:", len(features)) # 4 stages
print("Size of extracted features in stage 1:", features[0].size()) # torch.Size([1, 80, 56, 56])
print("Size of extracted features in stage 4:", features[3].size()) # torch.Size([1, 640, 7, 7])
📚 詳細文檔
模型概述
我們開發了首個用於計算機視覺的混合模型,該模型充分利用了Mamba和Transformer的優勢。具體而言,我們的核心貢獻包括重新設計Mamba公式,以增強其對視覺特徵進行高效建模的能力。此外,我們對將視覺Transformer(ViT)與Mamba集成的可行性進行了全面的消融研究。結果表明,在Mamba架構的最後幾層配備幾個自注意力塊,可以大大提高捕捉長程空間依賴的建模能力。基於這些發現,我們推出了具有分層架構的MambaVision模型家族,以滿足各種設計標準。
模型性能
MambaVision表現出色,在Top - 1準確率和吞吐量方面達到了新的SOTA帕累託前沿。
📄 許可證
NVIDIA源代碼許可協議 - 非商業用途