模型简介
模型特点
模型能力
使用案例
🚀 Jamba-v0.1-9B
Jamba-v0.1-9B 是 Jamba-v0.1 的密集版本,它提取了第一个专家的权重,不再使用混合专家(MoE)架构。具体细节请参考 此脚本。该模型可以使用单张 3090/4090 显卡进行推理,使用方法与 Jamba-v0.1 完全相同。
🚀 快速开始
环境准备
Jamba 需要使用 transformers
4.39.0 或更高版本:
pip install transformers>=4.39.0
为了运行优化后的 Mamba 实现,你首先需要安装 mamba-ssm
和 causal-conv1d
:
pip install mamba-ssm causal-conv1d>=1.2.0
同时,你需要将模型部署在 CUDA 设备上。
你也可以在不使用优化的 Mamba 内核的情况下运行模型,但不建议这样做,因为这会显著降低推理速度。若要这样运行,在加载模型时需要指定 use_mamba_kernels=False
。
运行模型
请注意,目前运行新的 Jamba 架构需要设置 trust_remote_code=True
:
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]
outputs = model.generate(input_ids, max_new_tokens=216)
print(tokenizer.batch_decode(outputs))
# ["<|startoftext|>In the recent Super Bowl LVIII, the Kansas City Chiefs emerged victorious, defeating the San Francisco 49ers in a thrilling overtime showdown. The game was a nail-biter, with both teams showcasing their skills and determination.\n\nThe Chiefs, led by their star quarterback Patrick Mahomes, displayed their offensive prowess, while the 49ers, led by their strong defense, put up a tough fight. The game went into overtime, with the Chiefs ultimately securing the win with a touchdown.\n\nThe victory marked the Chiefs' second Super Bowl win in four years, solidifying their status as one of the top teams in the NFL. The game was a testament to the skill and talent of both teams, and a thrilling end to the NFL season.\n\nThe Super Bowl is not just about the game itself, but also about the halftime show and the commercials. This year's halftime show featured a star-studded lineup, including Usher, Alicia Keys, and Lil Jon. The show was a spectacle of music and dance, with the performers delivering an energetic and entertaining performance.\n"]
以半精度加载模型
发布的检查点以 BF16 格式保存。若要以 BF16/FP16 格式将其加载到内存中,需要指定 torch_dtype
:
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True,
torch_dtype=torch.bfloat16) # 你也可以使用 torch_dtype=torch.float16
使用半精度时,可以启用 Attention 块的 FlashAttention2 实现。要使用此功能,模型也需要部署在 CUDA 设备上。由于在这种精度下,模型太大无法装入单张 80GB GPU,因此还需要使用 accelerate 进行并行化:
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto")
以 8 位精度加载模型
使用 8 位精度时,单张 80GB GPU 最多可以处理 140K 的序列长度。 你可以使用 bitsandbytes 轻松将模型量化为 8 位。为了不降低模型质量,建议在量化时排除 Mamba 块:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(load_in_8bit=True,
llm_int8_skip_modules=["mamba"])
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
quantization_config=quantization_config)
微调示例
Jamba 是一个基础模型,可以针对自定义解决方案进行微调(包括聊天/指令版本)。你可以使用任何你选择的技术进行微调。以下是使用 PEFT 库进行微调的示例:
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", trust_remote_code=True, device_map='auto')
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=3,
per_device_train_batch_size=4,
logging_dir='./logs',
logging_steps=10,
learning_rate=2e-3
)
lora_config = LoraConfig(
r=8,
target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
task_type="CAUSAL_LM",
bias="none"
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
peft_config=lora_config,
train_dataset=dataset,
dataset_text_field="quote",
)
trainer.train()
✨ 主要特性
- Jamba 是最先进的混合 SSM-Transformer 大语言模型,与传统基于 Transformer 的模型相比,它提高了吞吐量,并且在大多数常见基准测试中,表现优于或等同于同规模的领先模型。
- 它是第一个生产规模的 Mamba 实现,为研究和应用开辟了新的机会。
📦 安装指南
- 安装
transformers
4.39.0 或更高版本:pip install transformers>=4.39.0
- 安装
mamba-ssm
和causal-conv1d
:pip install mamba-ssm causal-conv1d>=1.2.0
💻 使用示例
基础用法
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]
outputs = model.generate(input_ids, max_new_tokens=216)
print(tokenizer.batch_decode(outputs))
高级用法
半精度加载模型
from transformers import AutoModelForCausalLM
import torch
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
device_map="auto")
8 位精度加载模型
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch
quantization_config = BitsAndBytesConfig(load_in_8bit=True,
llm_int8_skip_modules=["mamba"])
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_2",
quantization_config=quantization_config)
模型微调
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", trust_remote_code=True, device_map='auto')
dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=3,
per_device_train_batch_size=4,
logging_dir='./logs',
logging_steps=10,
learning_rate=2e-3
)
lora_config = LoraConfig(
r=8,
target_modules=["embed_tokens", "x_proj", "in_proj", "out_proj"],
task_type="CAUSAL_LM",
bias="none"
)
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
peft_config=lora_config,
train_dataset=dataset,
dataset_text_field="quote",
)
trainer.train()
📚 详细文档
模型详情
属性 | 详情 |
---|---|
开发者 | AI21 |
模型类型 | 联合注意力与 Mamba(Jamba) |
许可证 | Apache 2.0 |
上下文长度 | 256K |
知识截止日期 | 2024 年 3 月 5 日 |
常见基准测试结果
基准测试 | 得分 |
---|---|
HellaSwag | 87.1% |
Arc Challenge | 64.4% |
WinoGrande | 82.5% |
PIQA | 83.2% |
MMLU | 67.4% |
BBH | 45.4% |
TruthfulQA | 46.4% |
GSM8K (CoT) | 59.9% |
需要注意的是,所有提示都必须添加 'BOS' 标记,在某些评估框架中,该标记可能默认未启用。
注意事项
Jamba 是预训练的基础模型,未针对指令/聊天交互进行对齐。作为基础模型,Jamba 旨在作为微调、训练和开发自定义解决方案的基础层。它没有安全 moderation 机制,为了安全和负责任地使用,应该添加相应的防护措施。
关于 AI21
AI21 为企业构建可靠、实用且可扩展的 AI 解决方案。Jamba 是 AI21 新模型系列中的第一个,Jamba 的指令版本可通过 AI21 平台 进行测试。
📄 许可证
本项目采用 Apache 2.0 许可证。



