🚀 Jamba-v0.1-9B
Jamba-v0.1 の高密度バージョンで、最初のエキスパートの重みを抽出しています。
これはもはやMoEを使用しなくなりまし。詳細については、このスクリプト を参照してください。
単一の3090/4090で推論を行うことができ、使用方法はJamba-v0.1とまったく同じです。
🚀 クイックスタート
Jambaは最先端のハイブリッドSSM-Transformer LLMです。従来のTransformerベースのモデルに比べてスループットが向上し、同サイズクラスの主要モデルと同等またはそれ以上の性能を、ほとんどの一般的なベンチマークで発揮します。
このモデルカードはJambaのベースバージョンに関するものです。事前学習済みのエキスパート混合(MoE)生成テキストモデルで、アクティブパラメータが120億、すべてのエキスパートを合わせた総パラメータは520億です。256Kのコンテキスト長をサポートし、単一の80GB GPUで最大140Kトークンを扱うことができます。
このモデルの詳細については、リリースブログ記事 をご覧ください。
✨ 主な機能
- 最先端のハイブリッドSSM-Transformer LLMで、従来のTransformerベースのモデルに比べてスループットが向上します。
- 同サイズクラスの主要モデルと同等またはそれ以上の性能を、ほとんどの一般的なベンチマークで発揮します。
- 256Kのコンテキスト長をサポートし、単一の80GB GPUで最大140Kトークンを扱うことができます。
📦 インストール
前提条件
Jambaでは transformers
バージョン4.39.0以上が必要です。
pip install transformers>=4.39.0
最適化されたMambaの実装を実行するには、まず mamba-ssm
と causal-conv1d
をインストールする必要があります。
pip install mamba-ssm causal-conv1d>=1.2.0
また、モデルをCUDAデバイス上で実行する必要があります。
最適化されたMambaカーネルを使用せずにモデルを実行することもできますが、これは大幅にレイテンシが増加するため お勧めしません。その場合は、モデルをロードする際に use_mamba_kernels=False
を指定する必要があります。
💻 使用例
基本的な使用法
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]
outputs = model.generate(input_ids, max_new_tokens=216)
print(tokenizer.batch_decode(outputs))
高度な使用法
半精度でモデルをロードする
公開されているチェックポイントはBF16で保存されています。BF16/FP16でRAMにロードするには、torch_dtype
を指定する必要があります。
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True,
torch_dtype=torch.bfloat16)
半精度を使用する場合、Attentionブロックの FlashAttention2 実装を有効にすることができます。これを使用するには、モデルをCUDAデバイス上に配置する必要もあります。この精度ではモデルが大きすぎて単一の80GB GPUに収まらないため、accelerate を使用して並列化する必要もあります。
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto")
8ビットでモデルをロードする
8ビット精度を使用すると、単一の80GB GPUで最大140Kのシーケンス長を収めることができます。 bitsandbytes を使用してモデルを簡単に8ビットに量子化することができます。モデルの品質を低下させないために、Mambaブロックを量子化から除外することをお勧めします。
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True,
llm_int8_skip_modules=["mamba"])
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
quantization_config=quantization_config)
ファインチューニングの例
Jambaは、カスタムソリューション(チャット/命令バージョンを含む)のためにファインチューニングできるベースモデルです。好きな手法を使ってファインチューニングすることができます。以下は PEFT ライブラリを使用したファインチューニングの例です。
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", trust_remote_code=True, device_map='auto')
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=3,
per_device_train_batch_size=4,
logging_dir='./logs',
logging_steps=10,
learning_rate=2e-3
)
lora_config = LoraConfig(
r=8,
target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
task_type="CAUSAL_LM",
bias="none"
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
peft_config=lora_config,
train_dataset=dataset,
dataset_text_field="quote",
)
trainer.train()
📚 ドキュメント
一般的なベンチマークの結果
ベンチマーク |
スコア |
HellaSwag |
87.1% |
Arc Challenge |
64.4% |
WinoGrande |
82.5% |
PIQA |
83.2% |
MMLU |
67.4% |
BBH |
45.4% |
TruthfulQA |
46.4% |
GSM8K (CoT) |
59.9% |
すべてのプロンプトに 'BOS' トークンを追加することが重要です。これは、すべての評価フレームワークでデフォルトで有効になっていない場合があります。
注意事項
Jambaは事前学習済みのベースモデルであり、命令/チャットインタラクションのためのアライメントは行われていません。
ベースモデルとして、Jambaはファインチューニング、トレーニング、およびカスタムソリューションの開発のための基礎層として使用することを目的としています。Jambaには安全モデレーションメカニズムがないため、責任ある安全な使用のためにガードレールを追加する必要があります。
AI21について
AI21は、企業向けの信頼性が高く、実用的で、スケーラブルなAIソリューションを構築しています。
JambaはAI21の新しいモデルファミリーの最初のモデルで、Jambaの命令バージョンは AI21プラットフォーム を通じてベータ版で利用できます。
📄 ライセンス
このモデルはApache 2.0ライセンスの下で提供されています。