🚀 Jamba Model
Jamba is a state-of-the-art hybrid SSM-Transformer large language model. It delivers significantly higher throughput than conventional Transformer-based models, while matching or outperforming leading models of comparable size on most common benchmarks. This is the base version of Jamba; a better-performing instruction-tuned version, Jamba-1.5-Mini, has since been released. For even higher performance, see the scaled-up Jamba-1.5-Large.
🚀 Quick Start
Prerequisites
To use Jamba, we recommend transformers version 4.40.0 or later (the minimum requirement is 4.39.0):
pip install "transformers>=4.40.0"
To run the optimized Mamba implementations, you also need to install mamba-ssm and causal-conv1d:
pip install mamba-ssm "causal-conv1d>=1.2.0"
The model must also be on a CUDA device.
You can run the model without the optimized Mamba kernels, but this is not recommended because it significantly reduces throughput. To do so, pass use_mamba_kernels=False when loading the model.
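If you do choose to run without the optimized kernels, the flag is passed at load time; a minimal sketch:
from transformers import AutoModelForCausalLM
# Not recommended: without mamba-ssm/causal-conv1d, Mamba falls back to a much slower pure-PyTorch path
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", use_mamba_kernels=False)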
Running the model
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]
outputs = model.generate(input_ids, max_new_tokens=216)
print(tokenizer.batch_decode(outputs))
# ["<|startoftext|>In the recent Super Bowl LVIII, the Kansas City Chiefs emerged victorious, defeating the San Francisco 49ers in a thrilling overtime showdown. The game was a nail-biter, with both teams showcasing their skills and determination.\n\nThe Chiefs, led by their star quarterback Patrick Mahomes, displayed their offensive prowess, while the 49ers, led by their strong defense, put up a tough fight. The game went into overtime, with the Chiefs ultimately securing the win with a touchdown.\n\nThe victory marked the Chiefs' second Super Bowl win in four years, solidifying their status as one of the top teams in the NFL. The game was a testament to the skill and talent of both teams, and a thrilling end to the NFL season.\n\nThe Super Bowl is not just about the game itself, but also about the halftime show and the commercials. This year's halftime show featured a star-studded lineup, including Usher, Alicia Keys, and Lil Jon. The show was a spectacle of music and dance, with the performers delivering an energetic and entertaining performance.\n"]
Note that with transformers versions earlier than 4.40.0 you need to pass trust_remote_code=True to run the new Jamba architecture.
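For example, on an older transformers release the same loading call would look like this (a minimal sketch; the flag is unnecessary from 4.40.0 onwards):
from transformers import AutoModelForCausalLM
# Only needed with transformers < 4.40.0, where the Jamba architecture is not yet built in
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", trust_remote_code=True)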
Loading the model in half precision
The published checkpoint is saved in BF16. To load it into memory in BF16/FP16, pass torch_dtype:
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
torch_dtype=torch.bfloat16) # you can also use torch_dtype=torch.float16
When using half precision, you can enable the FlashAttention2 implementation of the attention blocks (this requires the flash-attn package to be installed). To use it, the model must also be on a CUDA device. Since at this precision the model is too large to fit on a single 80GB GPU, you will also need to parallelize it with accelerate:
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto")
Loading the model in 8-bit precision
In 8-bit precision, a single 80GB GPU can handle sequence lengths of up to 140K tokens. You can easily quantize the model to 8-bit with bitsandbytes. To avoid degrading model quality, we recommend excluding the Mamba blocks from quantization:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True,
llm_int8_skip_modules=["mamba"])
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
quantization_config=quantization_config)
Fine-tuning example
Jamba is a base model that can be fine-tuned for custom solutions (including chat/instruct versions). It can be fine-tuned with any technique of your choice. Below is an example of fine-tuning with the PEFT library (requires roughly 120GB of GPU memory, e.g. 2x A100 80GB GPUs):
import torch
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
model = AutoModelForCausalLM.from_pretrained(
"ai21labs/Jamba-v0.1", device_map='auto', torch_dtype=torch.bfloat16)
lora_config = LoraConfig(
r=8,
target_modules=[
"embed_tokens",
"x_proj", "in_proj", "out_proj", # mamba
"gate_proj", "up_proj", "down_proj", # mlp
"q_proj", "k_proj", "v_proj" # attention
],
task_type="CAUSAL_LM",
bias="none"
)
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = SFTConfig(
output_dir="./results",
num_train_epochs=2,
per_device_train_batch_size=4,
logging_dir='./logs',
logging_steps=10,
learning_rate=1e-5,
dataset_text_field="quote",
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
peft_config=lora_config,
train_dataset=dataset,
)
trainer.train()
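After training, you will typically want to persist the LoRA adapter separately from the base model. A minimal sketch, assuming the standard PEFT save API and a hypothetical output directory:
# Save only the LoRA adapter weights, not the full base model
trainer.model.save_pretrained("./jamba-lora-adapter")  # "./jamba-lora-adapter" is a placeholder path
tokenizer.save_pretrained("./jamba-lora-adapter")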
✨ Key Features
- As a Joint Attention and Mamba (Jamba) model, it combines the strengths of both architectures.
- 256K context length, with up to 140K tokens fitting on a single 80GB GPU (see the sketch after this list).
- Knowledge cutoff date: March 5, 2024.
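To get a feel for the 256K window, you can count the tokens in a long document before handing it to the model; a minimal sketch (the file path is a hypothetical placeholder):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
long_text = open("long_report.txt").read()  # hypothetical long input document
n_tokens = len(tokenizer(long_text)["input_ids"])
print(f"{n_tokens} tokens; the context window is 256K, and ~140K fits on a single 80GB GPU")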
📦 Installation Guide
pip install "transformers>=4.40.0"
pip install mamba-ssm "causal-conv1d>=1.2.0"
📚 Documentation
🔧 Technical Details
Jamba is the first production-scale implementation of Mamba, opening up new opportunities for research and applications. While these initial experiments already show encouraging results, further gains are expected with additional optimization and exploration.
📄 License
This model is released under the Apache 2.0 license.
Common benchmark results
| Benchmark | Score |
|---|---|
| HellaSwag | 87.1% |
| Arc Challenge | 64.4% |
| WinoGrande | 82.5% |
| PIQA | 83.2% |
| MMLU | 67.4% |
| BBH | 45.4% |
| TruthfulQA | 46.4% |
| GSM8K (CoT) | 59.9% |
⚠️ Important Note
All prompts must have the 'BOS' token prepended, but not all evaluation frameworks enable this by default.
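When calling the model directly through transformers, the tokenizer's default settings already prepend BOS (as the '<|startoftext|>' in the example output above shows); the sketch below simply makes the check explicit:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
ids = tokenizer("In the recent Super Bowl LVIII,")["input_ids"]
# The first id should be the BOS token ("<|startoftext|>"); fail loudly if an eval harness stripped it
assert ids[0] == tokenizer.bos_token_id, "prompt is missing the BOS token"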
💡 Usage Tips
Jamba is a pretrained base model and has not been aligned for instruction/chat interactions. As a base model, it is intended to serve as a foundation layer for fine-tuning, training, and developing custom solutions. Jamba has no built-in safety moderation; appropriate guardrails should be added to ensure safe and responsible use.
About AI21
AI21 builds reliable, practical, and scalable AI solutions for the enterprise. Jamba is the first model in AI21's new model family, and the instruct version of Jamba will be available soon on the AI21 platform.



