🚀 Jamba Model
Jamba is a state-of-the-art hybrid SSM-Transformer large language model. It delivers significant throughput gains over conventional Transformer-based models, while outperforming or matching the leading models of its size class on most common benchmarks. This is the base version of Jamba; a better-performing instruction-tuned version, Jamba-1.5-Mini, has since been released. For even higher performance, see the scaled-up Jamba-1.5-Large.
🚀 Quick Start
Prerequisites
To use Jamba, we recommend transformers version 4.40.0 or higher (4.39.0 is the minimum requirement):
pip install "transformers>=4.40.0"
To run the optimized Mamba implementations, you first need to install mamba-ssm and causal-conv1d:
pip install mamba-ssm "causal-conv1d>=1.2.0"
The model also has to be on a CUDA device.
You can run the model without the optimized Mamba kernels, but this is not recommended, as it significantly reduces throughput. To do so, specify use_mamba_kernels=False when loading the model.
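For example, a minimal sketch of loading the model with the optimized kernels disabled:
from transformers import AutoModelForCausalLM
# Fall back to the non-optimized Mamba path (not recommended: noticeably slower).
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
                                             use_mamba_kernels=False)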
Run the model
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]
outputs = model.generate(input_ids, max_new_tokens=216)
print(tokenizer.batch_decode(outputs))
# ["<|startoftext|>In the recent Super Bowl LVIII, the Kansas City Chiefs emerged victorious, defeating the San Francisco 49ers in a thrilling overtime showdown. The game was a nail-biter, with both teams showcasing their skills and determination.\n\nThe Chiefs, led by their star quarterback Patrick Mahomes, displayed their offensive prowess, while the 49ers, led by their strong defense, put up a tough fight. The game went into overtime, with the Chiefs ultimately securing the win with a touchdown.\n\nThe victory marked the Chiefs' second Super Bowl win in four years, solidifying their status as one of the top teams in the NFL. The game was a testament to the skill and talent of both teams, and a thrilling end to the NFL season.\n\nThe Super Bowl is not just about the game itself, but also about the halftime show and the commercials. This year's halftime show featured a star-studded lineup, including Usher, Alicia Keys, and Lil Jon. The show was a spectacle of music and dance, with the performers delivering an energetic and entertaining performance.\n"]
Note that if you use transformers<4.40.0, you need to pass trust_remote_code=True in order to run the new Jamba architecture.
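For example, a minimal sketch on an older transformers version:
from transformers import AutoModelForCausalLM
# On transformers < 4.40.0, trust_remote_code fetches the Jamba modeling code from the Hub.
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
                                             trust_remote_code=True)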
Load the model in half precision
The published checkpoint is saved in BF16. To load it into memory in BF16/FP16, you need to specify torch_dtype:
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
torch_dtype=torch.bfloat16) # you can also use torch_dtype=torch.float16
When using half precision, you can enable the FlashAttention2 implementation of the Attention blocks. To use it, the model also needs to be on a CUDA device. Since in this precision the model is too big to fit on a single 80GB GPU, you will also need to parallelize it using accelerate:
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto")
Load the model in 8-bit
With 8-bit precision, a single 80GB GPU can handle sequence lengths of up to 140K tokens. You can easily quantize the model to 8-bit using bitsandbytes. To avoid degrading model quality, we recommend excluding the Mamba blocks from quantization:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True,
llm_int8_skip_modules=["mamba"])
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
quantization_config=quantization_config)
Fine-tuning example
Jamba is a base model that can be fine-tuned for custom solutions (including chat/instruct versions). You can fine-tune it using any technique of your choice. Here is an example of fine-tuning with the PEFT library (requires roughly 120GB of GPU memory, e.g. 2x A100 80GB GPUs):
import torch
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
model = AutoModelForCausalLM.from_pretrained(
"ai21labs/Jamba-v0.1", device_map='auto', torch_dtype=torch.bfloat16)
lora_config = LoraConfig(
r=8,
target_modules=[
"embed_tokens",
"x_proj", "in_proj", "out_proj", # mamba
"gate_proj", "up_proj", "down_proj", # mlp
"q_proj", "k_proj", "v_proj" # attention
],
task_type="CAUSAL_LM",
bias="none"
)
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = SFTConfig(
output_dir="./results",
num_train_epochs=2,
per_device_train_batch_size=4,
logging_dir='./logs',
logging_steps=10,
learning_rate=1e-5,
dataset_text_field="quote",
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
peft_config=lora_config,
train_dataset=dataset,
)
trainer.train()
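After training, you will typically want to persist and reuse the LoRA adapter. A minimal sketch of the standard PEFT workflow; the "./jamba-lora" path is just an illustrative placeholder:
# Save only the LoRA adapter weights (the base model itself is unchanged).
trainer.save_model("./jamba-lora")
# Later, reload the adapter on top of the base model for inference.
from peft import PeftModel
base_model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/Jamba-v0.1", device_map='auto', torch_dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base_model, "./jamba-lora")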
✨ Key Features
- A Joint Attention and Mamba (Jamba) model, combining the strengths of both approaches.
- 256K context length, with up to 140K tokens fitting on a single 80GB GPU (see the quick check below).
- Knowledge cutoff date: March 5, 2024.
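As a quick sanity check of the advertised context window, you can inspect the checkpoint configuration; this is a minimal sketch and assumes the config exposes max_position_embeddings:
from transformers import AutoConfig
config = AutoConfig.from_pretrained("ai21labs/Jamba-v0.1")
# Expected to reflect the 256K-token context window.
print(config.max_position_embeddings)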
🔧 Technical Details
Jamba is the first production-scale implementation of Mamba, opening up new opportunities for research and applications. While these initial experiments already show encouraging results, we expect further gains from future optimizations and exploration.
📄 License
This model is released under the Apache 2.0 license.
Common benchmark results

| Benchmark | Score |
|---|---|
| HellaSwag | 87.1% |
| Arc Challenge | 64.4% |
| WinoGrande | 82.5% |
| PIQA | 83.2% |
| MMLU | 67.4% |
| BBH | 45.4% |
| TruthfulQA | 46.4% |
| GSM8K (CoT) | 59.9% |
⚠️ Important Notes
All prompts must have the 'BOS' token prepended, which not all evaluation frameworks enable by default.
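As a quick check, you can verify that the tokenizer prepends the BOS token (the <|startoftext|> token seen in the generation example above); a minimal sketch:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
ids = tokenizer("In the recent Super Bowl LVIII,")["input_ids"]
# The first id should be the BOS token; if it is not, enable BOS in your evaluation framework.
print(tokenizer.convert_ids_to_tokens(ids)[0], ids[0] == tokenizer.bos_token_id)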
💡 Usage Recommendations
Jamba is a pretrained base model and has not been aligned for instruction/chat interactions. As a base model, it is intended to serve as a foundation layer for fine-tuning, training, and developing custom solutions. Jamba does not have safety moderation mechanisms; to ensure safe and responsible use, appropriate guardrails should be added.
About AI21
AI21 builds reliable, practical, and scalable AI solutions for enterprises. Jamba is the first in AI21's new family of models, and an instruct version of Jamba will soon be available on the AI21 platform.