模型概述
模型特點
模型能力
使用案例
🚀 Jamba-v0.1-9B
Jamba-v0.1-9B 是 Jamba-v0.1 的密集版本,它提取了第一個專家的權重,不再使用混合專家(MoE)架構。具體細節請參考 此腳本。該模型可以使用單張 3090/4090 顯卡進行推理,使用方法與 Jamba-v0.1 完全相同。
🚀 快速開始
環境準備
Jamba 需要使用 transformers
4.39.0 或更高版本:
pip install transformers>=4.39.0
為了運行優化後的 Mamba 實現,你首先需要安裝 mamba-ssm
和 causal-conv1d
:
pip install mamba-ssm causal-conv1d>=1.2.0
同時,你需要將模型部署在 CUDA 設備上。
你也可以在不使用優化的 Mamba 內核的情況下運行模型,但不建議這樣做,因為這會顯著降低推理速度。若要這樣運行,在加載模型時需要指定 use_mamba_kernels=False
。
運行模型
請注意,目前運行新的 Jamba 架構需要設置 trust_remote_code=True
:
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]
outputs = model.generate(input_ids, max_new_tokens=216)
print(tokenizer.batch_decode(outputs))
# ["<|startoftext|>In the recent Super Bowl LVIII, the Kansas City Chiefs emerged victorious, defeating the San Francisco 49ers in a thrilling overtime showdown. The game was a nail-biter, with both teams showcasing their skills and determination.\n\nThe Chiefs, led by their star quarterback Patrick Mahomes, displayed their offensive prowess, while the 49ers, led by their strong defense, put up a tough fight. The game went into overtime, with the Chiefs ultimately securing the win with a touchdown.\n\nThe victory marked the Chiefs' second Super Bowl win in four years, solidifying their status as one of the top teams in the NFL. The game was a testament to the skill and talent of both teams, and a thrilling end to the NFL season.\n\nThe Super Bowl is not just about the game itself, but also about the halftime show and the commercials. This year's halftime show featured a star-studded lineup, including Usher, Alicia Keys, and Lil Jon. The show was a spectacle of music and dance, with the performers delivering an energetic and entertaining performance.\n"]
以半精度加載模型
發佈的檢查點以 BF16 格式保存。若要以 BF16/FP16 格式將其加載到內存中,需要指定 torch_dtype
:
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True,
torch_dtype=torch.bfloat16) # 你也可以使用 torch_dtype=torch.float16
使用半精度時,可以啟用 Attention 塊的 FlashAttention2 實現。要使用此功能,模型也需要部署在 CUDA 設備上。由於在這種精度下,模型太大無法裝入單張 80GB GPU,因此還需要使用 accelerate 進行並行化:
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto")
以 8 位精度加載模型
使用 8 位精度時,單張 80GB GPU 最多可以處理 140K 的序列長度。 你可以使用 bitsandbytes 輕鬆將模型量化為 8 位。為了不降低模型質量,建議在量化時排除 Mamba 塊:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True,
llm_int8_skip_modules=["mamba"])
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
quantization_config=quantization_config)
微調示例
Jamba 是一個基礎模型,可以針對自定義解決方案進行微調(包括聊天/指令版本)。你可以使用任何你選擇的技術進行微調。以下是使用 PEFT 庫進行微調的示例:
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", trust_remote_code=True, device_map='auto')
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=3,
per_device_train_batch_size=4,
logging_dir='./logs',
logging_steps=10,
learning_rate=2e-3
)
lora_config = LoraConfig(
r=8,
target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
task_type="CAUSAL_LM",
bias="none"
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
peft_config=lora_config,
train_dataset=dataset,
dataset_text_field="quote",
)
trainer.train()
✨ 主要特性
- Jamba 是最先進的混合 SSM-Transformer 大語言模型,與傳統基於 Transformer 的模型相比,它提高了吞吐量,並且在大多數常見基準測試中,表現優於或等同於同規模的領先模型。
- 它是第一個生產規模的 Mamba 實現,為研究和應用開闢了新的機會。
📦 安裝指南
- 安裝
transformers
4.39.0 或更高版本:pip install transformers>=4.39.0
- 安裝
mamba-ssm
和causal-conv1d
:pip install mamba-ssm causal-conv1d>=1.2.0
💻 使用示例
基礎用法
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]
outputs = model.generate(input_ids, max_new_tokens=216)
print(tokenizer.batch_decode(outputs))
高級用法
半精度加載模型
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto")
8 位精度加載模型
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch
quantization_config = BitsAndBytesConfig(load_in_8bit=True,
llm_int8_skip_modules=["mamba"])
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
quantization_config=quantization_config)
模型微調
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", trust_remote_code=True, device_map='auto')
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=3,
per_device_train_batch_size=4,
logging_dir='./logs',
logging_steps=10,
learning_rate=2e-3
)
lora_config = LoraConfig(
r=8,
target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
task_type="CAUSAL_LM",
bias="none"
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
peft_config=lora_config,
train_dataset=dataset,
dataset_text_field="quote",
)
trainer.train()
📚 詳細文檔
模型詳情
屬性 | 詳情 |
---|---|
開發者 | AI21 |
模型類型 | 聯合注意力與 Mamba(Jamba) |
許可證 | Apache 2.0 |
上下文長度 | 256K |
知識截止日期 | 2024 年 3 月 5 日 |
常見基準測試結果
基準測試 | 得分 |
---|---|
HellaSwag | 87.1% |
Arc Challenge | 64.4% |
WinoGrande | 82.5% |
PIQA | 83.2% |
MMLU | 67.4% |
BBH | 45.4% |
TruthfulQA | 46.4% |
GSM8K (CoT) | 59.9% |
需要注意的是,所有提示都必須添加 'BOS' 標記,在某些評估框架中,該標記可能默認未啟用。
注意事項
Jamba 是預訓練的基礎模型,未針對指令/聊天交互進行對齊。作為基礎模型,Jamba 旨在作為微調、訓練和開發自定義解決方案的基礎層。它沒有安全 moderation 機制,為了安全和負責任地使用,應該添加相應的防護措施。
關於 AI21
AI21 為企業構建可靠、實用且可擴展的 AI 解決方案。Jamba 是 AI21 新模型系列中的第一個,Jamba 的指令版本可通過 AI21 平臺 進行測試。
📄 許可證
本項目採用 Apache 2.0 許可證。



