🚀 MambaVision: ハイブリッドMamba-Transformerビジョンバックボーン
我々は、ビジョンアプリケーションに特化した、新しいハイブリッドMamba-TransformerバックボーンであるMambaVisionを提案します。このモデルは、視覚特徴の効率的なモデリング能力を向上させるためにMambaの定式化を再設計しています。
プロパティ |
詳細 |
データセット |
ILSVRC/imagenet-21k |
ライセンス |
other |
ライセンス名 |
nvclv1 |
ライセンスリンク |
LICENSE |
パイプラインタグ |
image-classification |
ライブラリ名 |
transformers |
🚀 クイックスタート
MambaVisionを使用するためには、まず必要なパッケージをインストールすることをおすすめします。以下のコマンドを実行してください。
pip install mambavision
✨ 主な機能
- 我々は、ビジョンアプリケーションに特化した新しいハイブリッドMamba-TransformerバックボーンであるMambaVisionを提案します。
- Mambaの定式化を再設計し、視覚特徴の効率的なモデリング能力を向上させます。
- Vision Transformers (ViT) とMambaの統合の実現可能性について包括的なアブレーション研究を行います。
- Mambaアーキテクチャの最終層にいくつかの自己注意ブロックを備えることで、長距離の空間依存関係を捉えるモデリング能力が大幅に向上します。
- 階層的なアーキテクチャを持つMambaVisionモデルファミリーを導入し、様々な設計基準を満たします。
- ImageNet-1Kデータセットでの画像分類において、MambaVisionモデルのバリアントはTop-1精度と画像スループットの面で新たな最先端 (SOTA) 性能を達成します。
- MS COCOおよびADE20Kデータセットでの物体検出、インスタンスセグメンテーション、セマンティックセグメンテーションなどの下流タスクにおいて、MambaVisionは同等サイズのバックボーンを上回り、より良好な性能を示します。
📚 ドキュメント
モデルの説明
MambaVision: A Hybrid Mamba-Transformer Vision Backbone では、ビジョンアプリケーションに特化した新しいハイブリッドMamba-TransformerバックボーンであるMambaVisionを提案しています。我々の主要な貢献は、視覚特徴の効率的なモデリング能力を向上させるためにMambaの定式化を再設計することです。また、Vision Transformers (ViT) とMambaの統合の実現可能性について包括的なアブレーション研究を行いました。結果として、Mambaアーキテクチャの最終層にいくつかの自己注意ブロックを備えることで、長距離の空間依存関係を捉えるモデリング能力が大幅に向上することがわかりました。これらの知見に基づき、様々な設計基準を満たすために階層的なアーキテクチャを持つMambaVisionモデルファミリーを導入しました。ImageNet-1Kデータセットでの画像分類において、MambaVisionモデルのバリアントはTop-1精度と画像スループットの面で新たな最先端 (SOTA) 性能を達成しました。MS COCOおよびADE20Kデータセットでの物体検出、インスタンスセグメンテーション、セマンティックセグメンテーションなどの下流タスクにおいても、MambaVisionは同等サイズのバックボーンを上回り、より良好な性能を示しました。コードは こちら から入手できます。
モデルの性能
MambaVision-L-21Kは、ImageNet-21Kデータセットで事前学習され、ImageNet-1Kで微調整されています。
名前 |
Acc@1(%) |
Acc@5(%) |
#Params(M) |
FLOPs(G) |
解像度 |
MambaVision-L-21K |
86.1 |
97.9 |
227.9 |
34.9 |
224x224 |
また、MambaVisionモデルは、Top-1精度とスループットの面で新たなSOTAパレートフロントを達成し、強力な性能を示しています。
💻 使用例
基本的な使用法
画像分類
以下の例では、MambaVisionを画像分類に使用する方法を示します。
COCOデータセット の検証セットからの以下の画像を入力として与えます。
以下のコードスニペットを使用して画像分類を行うことができます。
from transformers import AutoModelForImageClassification
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModelForImageClassification.from_pretrained("nvidia/MambaVision-L-21K", trust_remote_code=True)
model.cuda().eval()
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224)
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
outputs = model(inputs)
logits = outputs['logits']
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
予測されたラベルは brown bear, bruin, Ursus arctos.
です。
特徴抽出
MambaVisionは、一般的な特徴抽出器としても使用できます。具体的には、モデルの各段階 (4段階) の出力と、最終的な平均プーリングされた特徴を平坦化したものを抽出することができます。
以下のコードスニペットを使用して特徴抽出を行うことができます。
from transformers import AutoModel
from PIL import Image
from timm.data.transforms_factory import create_transform
import requests
model = AutoModel.from_pretrained("nvidia/MambaVision-L-21K", trust_remote_code=True)
model.cuda().eval()
url = 'http://images.cocodataset.org/val2017/000000020247.jpg'
image = Image.open(requests.get(url, stream=True).raw)
input_resolution = (3, 224, 224)
transform = create_transform(input_size=input_resolution,
is_training=False,
mean=model.config.mean,
std=model.config.std,
crop_mode=model.config.crop_mode,
crop_pct=model.config.crop_pct)
inputs = transform(image).unsqueeze(0).cuda()
out_avg_pool, features = model(inputs)
print("Size of the averaged pool features:", out_avg_pool.size())
print("Number of stages in extracted features:", len(features))
print("Size of extracted features in stage 1:", features[0].size())
print("Size of extracted features in stage 4:", features[3].size())
📄 ライセンス
このモデルは NVIDIA Source Code License-NC の下で提供されています。