🚀 Jambaモデル
このJambaモデルのベースバージョンです。その後、より良いインストラクションチューニング済みのバージョン Jamba-1.5-Mini をリリースしました。さらに高いパフォーマンスが必要な場合は、拡張版の Jamba-1.5-Large をご確認ください。
🚀 クイックスタート
Jambaは最先端のハイブリッドSSM-Transformer LLMです。従来のTransformerベースのモデルに比べてスループットが向上し、同サイズクラスの主要なモデルを多くの一般的なベンチマークで上回るか同等の性能を発揮します。
このモデルカードはJambaのベースバージョンに関するものです。これは事前学習されたエキスパート混合(MoE)生成テキストモデルで、アクティブなパラメータが120億個、すべてのエキスパートを合わせた総パラメータは520億個です。256Kのコンテキスト長をサポートし、単一の80GB GPUで最大140Kトークンを収容できます。
このモデルの詳細については、ホワイトペーパー と リリースブログ記事 をご覧ください。
✨ 主な機能
- Jambaは最先端のハイブリッドSSM-Transformer LLMで、従来のTransformerベースのモデルに比べてスループットが向上します。
- 同サイズクラスの主要なモデルを多くの一般的なベンチマークで上回るか同等の性能を発揮します。
- 最初の本番規模のMamba実装であり、興味深い研究とアプリケーションの機会を開拓します。
📦 インストール
前提条件
Jambaを使用するには、transformers
バージョン4.40.0以上(バージョン4.39.0以上が必要)を使用することをお勧めします。
pip install transformers>=4.40.0
最適化されたMamba実装を実行するには、まず mamba-ssm
と causal-conv1d
をインストールする必要があります。
pip install mamba-ssm causal-conv1d>=1.2.0
また、モデルをCUDAデバイス上で実行する必要があります。
最適化されたMambaカーネルを使用せずにモデルを実行することもできますが、これは大幅にレイテンシが増加するため お勧めしません。その場合は、モデルをロードする際に use_mamba_kernels=False
を指定する必要があります。
💻 使用例
基本的な使用法
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]
outputs = model.generate(input_ids, max_new_tokens=216)
print(tokenizer.batch_decode(outputs))
transformers<4.40.0
を使用している場合は、新しいJambaアーキテクチャを実行するために trust_remote_code=True
が必要です。
高度な使用法
半精度でのモデルのロード
公開されているチェックポイントはBF16で保存されています。BF16/FP16でRAMにロードするには、torch_dtype
を指定する必要があります。
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
torch_dtype=torch.bfloat16)
半精度を使用する場合、Attentionブロックの FlashAttention2 実装を有効にすることができます。これを使用するには、モデルをCUDAデバイス上に配置する必要もあります。この精度ではモデルが大きすぎて単一の80GB GPUに収まらないため、accelerate を使用して並列化する必要があります。
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto")
8ビットでのモデルのロード
8ビット精度を使用すると、単一の80GB GPUで最大140Kのシーケンス長を収容することができます。 bitsandbytes を使用してモデルを簡単に8ビットに量子化することができます。モデルの品質を低下させないために、Mambaブロックを量子化から除外することをお勧めします。
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True,
llm_int8_skip_modules=["mamba"])
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
quantization_config=quantization_config)
ファインチューニングの例
Jambaは、カスタムソリューション(チャット/インストラクションバージョンを含む)のためにファインチューニングできるベースモデルです。好きな手法を使用してファインチューニングすることができます。以下は、PEFT ライブラリを使用したファインチューニングの例です(約120GBのGPU RAMが必要です。例:2xA100 80GB)。
import torch
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
model = AutoModelForCausalLM.from_pretrained(
"ai21labs/Jamba-v0.1", device_map='auto', torch_dtype=torch.bfloat16)
lora_config = LoraConfig(
r=8,
target_modules=[
"embed_tokens",
"x_proj", "in_proj", "out_proj",
"gate_proj", "up_proj", "down_proj",
"q_proj", "k_proj", "v_proj"
],
task_type="CAUSAL_LM",
bias="none"
)
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = SFTConfig(
output_dir="./results",
num_train_epochs=2,
per_device_train_batch_size=4,
logging_dir='./logs',
logging_steps=10,
learning_rate=1e-5,
dataset_text_field="quote",
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
peft_config=lora_config,
train_dataset=dataset,
)
trainer.train()
📚 ドキュメント
一般的なベンチマークの結果
ベンチマーク |
スコア |
HellaSwag |
87.1% |
Arc Challenge |
64.4% |
WinoGrande |
82.5% |
PIQA |
83.2% |
MMLU |
67.4% |
BBH |
45.4% |
TruthfulQA |
46.4% |
GSM8K (CoT) |
59.9% |
すべてのプロンプトに 'BOS' トークンを追加することが重要です。これは、すべての評価フレームワークでデフォルトで有効になっていない場合があります。
注意事項
Jambaは事前学習されたベースモデルであり、インストラクション/チャットインタラクションのためのアライメントは行われていません。
ベースモデルとして、Jambaはファインチューニング、トレーニング、およびカスタムソリューションの開発のための基礎層として使用することを目的としています。Jambaにはセーフティモデレーションメカニズムがないため、責任ある安全な使用のためにガードレールを追加する必要があります。
AI21について
AI21は、企業向けの信頼性が高く、実用的で、スケーラブルなAIソリューションを構築しています。
JambaはAI21の新しいモデルファミリーの最初のモデルであり、Jambaのインストラクションバージョンは間もなく AI21プラットフォーム に登場します。
モデル情報
属性 |
详情 |
モデルタイプ |
Joint Attention and Mamba (Jamba) |
開発元 |
AI21 |
ライセンス |
Apache 2.0 |
コンテキスト長 |
256K |
知識の切断日 |
2024年3月5日 |
⚠️ 重要提示
すべてのプロンプトに 'BOS' トークンを追加することが重要です。これは、すべての評価フレームワークでデフォルトで有効になっていない場合があります。
💡 使用建议
Jambaは事前学習されたベースモデルであり、インストラクション/チャットインタラクションのためのアライメントは行われていません。責任ある安全な使用のためにガードレールを追加することをお勧めします。