🚀 MambaVision-L2-512-21K
MambaVisionは、コンピュータビジョンのためのハイブリッドモデルで、MambaとTransformerの強みを活用した画像分類モデルです。
🚀 クイックスタート
このセクションでは、MambaVision-L2-512-21Kモデルの概要、性能、使用方法、ライセンスについて説明します。
✨ 主な機能
- ハイブリッドモデル:MambaとTransformerの強みを組み合わせた、コンピュータビジョン用の初のハイブリッドモデルを開発しました。
- Mambaの改良:Mambaの定式化を再設計し、視覚的特徴を効率的にモデリングする能力を向上させました。
- ViTとの統合:Vision Transformers (ViT) とMambaの統合の実現可能性について包括的なアブレーション研究を行いました。
- 長距離依存関係のキャプチャ:Mambaアーキテクチャの最終層にいくつかの自己注意ブロックを備えることで、長距離の空間的依存関係をキャプチャするモデリング能力が大幅に向上します。
- 階層的アーキテクチャ:様々な設計基準を満たすために、階層的アーキテクチャを持つMambaVisionモデルファミリーを導入しました。
📦 インストール
MambaVisionの必要なパッケージをインストールするには、以下のコマンドを実行してください。
pip install mambavision
💻 使用例
基本的な使用法
画像分類
以下のコードは、MambaVisionを使用して画像分類を行う例です。
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-L2-512-21K", trust_remote_code=True)
model.cuda().eval()
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 512, 512)
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
outputs = model(inputs)
logits = outputs['logits']
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
特徴抽出
MambaVisionは、一般的な特徴抽出器としても使用できます。以下のコードは、特徴抽出を行う例です。
from transformers import AutoModel
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModel.from_pretrained("nvidia/MambaVision-L2-512-21K", trust_remote_code=True)
model.cuda().eval()
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 512, 512)
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
out_avg_pool, features = model(inputs)
print("Size of the averaged pool features:", out_avg_pool.size())
print("Number of stages in extracted features:", len(features))
print("Size of extracted features in stage 1:", features[0].size())
print("Size of extracted features in stage 4:", features[3].size())
📚 ドキュメント
モデル概要
私たちは、MambaとTransformerの強みを活用した、コンピュータビジョン用の初のハイブリッドモデルを開発しました。具体的には、Mambaの定式化を再設計し、視覚的特徴を効率的にモデリングする能力を向上させました。また、Vision Transformers (ViT) とMambaの統合の実現可能性について包括的なアブレーション研究を行いました。結果は、Mambaアーキテクチャの最終層にいくつかの自己注意ブロックを備えることで、長距離の空間的依存関係をキャプチャするモデリング能力が大幅に向上することを示しています。これらの知見に基づいて、様々な設計基準を満たすために、階層的アーキテクチャを持つMambaVisionモデルファミリーを導入しました。
モデル性能
MambaVision-L2-512-21Kは、ImageNet-21Kデータセットで事前学習され、512 x 512の解像度でImageNet-1Kに微調整されています。
名前 |
1位精度(%) |
5位精度(%) |
パラメータ数(M) |
FLOPs(G) |
解像度 |
MambaVision-L2-512-21K |
87.3 |
98.4 |
241.5 |
196.3 |
512x512 |
また、MambaVisionモデルは、Top-1精度とスループットの面で新しいSOTAパレートフロントを達成することで、強力な性能を示しています。
📄 ライセンス
このモデルは、NVIDIA Source Code License-NCの下で提供されています。