🚀 MambaVision:混合Mamba-Transformer視覺主幹網絡
MambaVision是一種專門為視覺應用設計的新型混合Mamba-Transformer主幹網絡。它重新設計了Mamba公式,增強了對視覺特徵的高效建模能力,在圖像分類和下游任務中表現出色。
🚀 快速開始
安裝
強烈建議通過運行以下命令來安裝MambaVision的依賴項:
pip install mambavision
✨ 主要特性
- 提出了一種新穎的混合Mamba-Transformer主幹網絡MambaVision,專為視覺應用量身定製。
- 重新設計Mamba公式,增強其對視覺特徵的高效建模能力。
- 對Vision Transformers (ViT) 與Mamba集成的可行性進行了全面的消融研究。
- 在Mamba架構的最後幾層配備幾個自注意力塊,大大提高了捕捉長距離空間依賴的建模能力。
- 引入了具有分層架構的MambaVision模型家族,以滿足各種設計標準。
- 在ImageNet - 1K數據集的圖像分類任務中,MambaVision模型變體在Top - 1準確率和圖像吞吐量方面達到了新的最優性能。
- 在MS COCO和ADE20K數據集的目標檢測、實例分割和語義分割等下游任務中,MambaVision優於同等規模的主幹網絡。
📚 詳細文檔
模型描述
我們提出了一種新穎的混合Mamba - Transformer主幹網絡,稱為MambaVision,專門用於視覺應用。我們的核心貢獻包括重新設計Mamba公式,以增強其對視覺特徵的高效建模能力。此外,我們對Vision Transformers (ViT) 與Mamba集成的可行性進行了全面的消融研究。結果表明,在Mamba架構的最後幾層配備幾個自注意力塊,大大提高了捕捉長距離空間依賴的建模能力。基於這些發現,我們引入了具有分層架構的MambaVision模型家族,以滿足各種設計標準。在ImageNet - 1K數據集的圖像分類任務中,MambaVision模型變體在Top - 1準確率和圖像吞吐量方面達到了新的最優 (SOTA) 性能。在MS COCO和ADE20K數據集的目標檢測、實例分割和語義分割等下游任務中,MambaVision優於同等規模的主幹網絡,表現更出色。代碼鏈接:https://github.com/NVlabs/MambaVision 。
模型性能
MambaVision - L - 21K在ImageNet - 21K數據集上進行預訓練,並在ImageNet - 1K上進行微調。
名稱 |
準確率@1(%) |
準確率@5(%) |
參數數量(M) |
浮點運算次數(G) |
分辨率 |
MambaVision - L - 21K |
86.1 |
97.9 |
227.9 |
34.9 |
224x224 |
此外,MambaVision模型在Top - 1準確率和吞吐量方面達到了新的SOTA帕累託前沿,表現強勁。
模型使用
基礎用法
圖像分類
以下示例展示瞭如何使用MambaVision進行圖像分類。給定來自COCO數據集驗證集的圖像作為輸入:
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-L-21K", trust_remote_code=True)
model.cuda().eval()
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224)
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
outputs = model(inputs)
logits = outputs['logits']
predicted_class_idx = logits.argmax(-1).item()
print("預測類別:", model.config.id2label[predicted_class_idx])
預測標籤為brown bear, bruin, Ursus arctos.
特徵提取
MambaVision也可用作通用特徵提取器。具體來說,我們可以提取模型每個階段(4個階段)的輸出以及最終的平均池化特徵(已展平)。
from transformers import AutoModel
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModel.from_pretrained("nvidia/MambaVision-L-21K", trust_remote_code=True)
model.cuda().eval()
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224)
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
out_avg_pool, features = model(inputs)
print("平均池化特徵的大小:", out_avg_pool.size())
print("提取特徵的階段數:", len(features))
print("第1階段提取特徵的大小:", features[0].size())
print("第4階段提取特徵的大小:", features[3].size())
許可證
NVIDIA源代碼許可協議 - 非商業用途
信息表格
屬性 |
詳情 |
數據集 |
ILSVRC/imagenet - 21k |
許可證 |
其他(NVIDIA Source Code License - NC) |
許可證名稱 |
nvclv1 |
許可證鏈接 |
LICENSE |
任務類型 |
圖像分類 |
庫名稱 |
transformers |