🚀 MambaVision: ハイブリッドMamba-Transformerビジョンバックボーン
コンピュータビジョンのためのハイブリッドモデルで、MambaとTransformerの強みを活用します。
🚀 クイックスタート
このモデルは、コンピュータビジョンにおける初のハイブリッドモデルで、MambaとTransformerの強みを最大限に活用しています。以下に、モデルの概要、性能、使用方法などを紹介します。
✨ 主な機能
モデル概要
コンピュータビジョン用の最初のハイブリッドモデルを開発しました。このモデルは、MambaとTransformerの強みを活用しています。具体的には、Mambaの定式化を再設計し、視覚的特徴を効率的にモデリングする能力を強化しました。また、Vision Transformers (ViT) とMambaの統合の実現可能性について包括的なアブレーション研究を行いました。結果として、Mambaアーキテクチャの最終層にいくつかの自己注意ブロックを備えることで、長距離の空間依存関係を捉えるモデリング能力が大幅に向上することが示されました。これらの知見に基づき、様々な設計基準を満たす階層的アーキテクチャを持つMambaVisionモデルファミリーを導入しました。
モデル性能
MambaVision-B-21Kは、ImageNet-21Kデータセットで事前学習され、ImageNet-1Kで微調整されています。
名前 |
Acc@1(%) |
Acc@5(%) |
#Params(M) |
FLOPs(G) |
解像度 |
MambaVision-B-21K |
84.9 |
97.5 |
97.7 |
15.0 |
224x224 |
さらに、MambaVisionモデルは、Top-1精度とスループットの面で新しいSOTAパレートフロントを達成し、強力な性能を示しています。
📦 インストール
MambaVisionの要件をインストールするには、以下のコマンドを実行することを強くおすすめします。
コード: https://github.com/NVlabs/MambaVision
pip install mambavision
💻 使用例
基本的な使用法
各モデルには、画像分類と特徴抽出の2つのバリアントがあり、1行のコードでインポートできます。
画像分類
以下の例では、MambaVisionを画像分類に使用する方法を示します。
COCOデータセット の検証セットからの画像を入力として使用します。
以下のコードスニペットを画像分類に使用できます。
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-B-21K", trust_remote_code=True)
model.cuda().eval()
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224)
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
outputs = model(inputs)
logits = outputs['logits']
predicted_class_idx = logits.argmax(-1).item()
print("予測されたクラス:", model.config.id2label[predicted_class_idx])
予測されたラベルは brown bear, bruin, Ursus arctos.
です。
特徴抽出
MambaVisionは、一般的な特徴抽出器としても使用できます。
具体的には、モデルの各段階(4段階)の出力と、最終的な平均プーリングされた平坦化された特徴を抽出できます。
以下のコードスニペットを特徴抽出に使用できます。
from transformers import AutoModel
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModel.from_pretrained("nvidia/MambaVision-B-21K", trust_remote_code=True)
model.cuda().eval()
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224)
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_pct,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
out_avg_pool, features = model(inputs)
print("平均プーリングされた特徴のサイズ:", out_avg_pool.size())
print("抽出された特徴の段階数:", len(features))
print("段階1で抽出された特徴のサイズ:", features[0].size())
print("段階4で抽出された特徴のサイズ:", features[3].size())
📄 ライセンス
NVIDIA Source Code License-NC