# Jamba Model
Jamba is a state-of-the-art, hybrid SSM-Transformer LLM. It offers throughput gains over traditional Transformer-based models and performs well on common benchmarks. This README provides details about the base version of Jamba, including its usage, performance, and more.
## Quick Start
This is the base version of the Jamba model. We've since released a better, instruct-tuned version, [Jamba-1.5-Mini](https://huggingface.co/ai21labs/AI21-Jamba-1.5-Mini). For even greater performance, check out the scaled-up [Jamba-1.5-Large](https://huggingface.co/ai21labs/AI21-Jamba-1.5-Large).
For full details of this model please read the white paper and the release blog post.
## Features
- High-performance LLM: Delivers throughput gains over traditional Transformer-based models.
- Large context length: Supports a 256K context length and can fit up to 140K tokens on a single 80GB GPU.
- Mixture-of-experts architecture: 12B active parameters out of a total of 52B parameters across all experts (see the config sketch below).
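To check these architecture details yourself, you can inspect the published configuration without downloading the weights. The snippet below is a minimal sketch that assumes access to the Hugging Face Hub; the exact field names shown depend on the `JambaConfig` version bundled with your `transformers` install.

```python
from transformers import AutoConfig

# Fetch only the model configuration (no weights) from the Hub.
config = AutoConfig.from_pretrained("ai21labs/Jamba-v0.1")

# Printing the config shows the hybrid attention/Mamba layout and the
# mixture-of-experts settings (expert counts, hidden sizes, context length).
print(config)
```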
## Installation

### Prerequisites
To use Jamba, we recommend `transformers` version 4.40.0 or higher (version 4.39.0 or higher is required):

```bash
pip install "transformers>=4.40.0"
```
To run the optimized Mamba implementations, you first need to install `mamba-ssm` and `causal-conv1d`:

```bash
pip install mamba-ssm "causal-conv1d>=1.2.0"
```
You also need the model to be on a CUDA device.

You can run the model without the optimized Mamba kernels, but this is not recommended, as it results in significantly higher latency. To do so, specify `use_mamba_kernels=False` when loading the model.
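For example, here is a minimal sketch of loading without the optimized kernels (useful when `mamba-ssm` and `causal-conv1d` are not installed); the flag is simply passed through `from_pretrained`:

```python
from transformers import AutoModelForCausalLM

# Fall back to the pure-PyTorch Mamba path; slower, but it does not
# require mamba-ssm or causal-conv1d to be installed.
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
                                             use_mamba_kernels=False)
```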
## Usage Examples

### Basic Usage
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

input_ids = tokenizer("In the recent Super Bowl LVIII,", return_tensors='pt').to(model.device)["input_ids"]

outputs = model.generate(input_ids, max_new_tokens=216)
print(tokenizer.batch_decode(outputs))
```
Please note that if you're using `transformers<4.40.0`, `trust_remote_code=True` is required for running the new Jamba architecture.
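For example, a minimal sketch for older `transformers` versions (only needed below 4.40.0; note that this executes modeling code from the model repository, so enable it only if you trust the source):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Only needed with transformers < 4.40.0, where the Jamba architecture
# is not yet bundled with the library.
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
                                             trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1",
                                          trust_remote_code=True)
```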
### Advanced Usage - Loading the model in half precision
The published checkpoint is saved in BF16. In order to load it into RAM in BF16/FP16, you need to specify `torch_dtype`:
```python
from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
                                             torch_dtype=torch.bfloat16)
```
When using half precision, you can enable the [FlashAttention2](https://github.com/Dao-AILab/flash-attention) implementation of the Attention blocks. In order to use it, you also need the model on a CUDA device. Since in this precision the model is too big to fit on a single 80GB GPU, you'll also need to parallelize it using `accelerate`:
```python
from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
                                             torch_dtype=torch.bfloat16,
                                             attn_implementation="flash_attention_2",
                                             device_map="auto")
```
### Advanced Usage - Load the model in 8-bit
Using 8-bit precision, it is possible to fit up to 140K-token sequence lengths on a single 80GB GPU. You can easily quantize the model to 8-bit using `bitsandbytes`. To avoid degrading model quality, we recommend excluding the Mamba blocks from the quantization:
```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True,
                                         llm_int8_skip_modules=["mamba"])
model = AutoModelForCausalLM.from_pretrained("ai21labs/Jamba-v0.1",
                                             torch_dtype=torch.bfloat16,
                                             attn_implementation="flash_attention_2",
                                             quantization_config=quantization_config)
```
### Advanced Usage - Fine-tuning example
Jamba is a base model that can be fine-tuned for custom solutions (including chat/instruct versions). You can fine-tune it using any technique of your choice. Here is an example of fine-tuning with the PEFT library (requires ~120GB GPU RAM, e.g. 2x A100 80GB):
```python
import torch
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
model = AutoModelForCausalLM.from_pretrained(
    "ai21labs/Jamba-v0.1", device_map='auto', torch_dtype=torch.bfloat16)

# LoRA adapters on the embedding, Mamba, MLP/expert and attention projections
lora_config = LoraConfig(
    r=8,
    target_modules=[
        "embed_tokens",
        "x_proj", "in_proj", "out_proj",      # Mamba blocks
        "gate_proj", "up_proj", "down_proj",  # MLP / expert blocks
        "q_proj", "k_proj", "v_proj"          # attention blocks
    ],
    task_type="CAUSAL_LM",
    bias="none"
)

dataset = load_dataset("Abirate/english_quotes", split="train")
training_args = SFTConfig(
    output_dir="./results",
    num_train_epochs=2,
    per_device_train_batch_size=4,
    logging_dir='./logs',
    logging_steps=10,
    learning_rate=1e-5,
    dataset_text_field="quote",
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
)
trainer.train()
```
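One optional follow-up is to persist the result after training. With a `peft_config`, `SFTTrainer` wraps the model in a PEFT model, so saving the trained model stores only the LoRA adapter weights; the `./jamba-lora-adapter` path below is an arbitrary illustrative choice:

```python
# Save the trained LoRA adapter and tokenizer for later reuse with PEFT.
trainer.save_model("./jamba-lora-adapter")        # hypothetical output path
tokenizer.save_pretrained("./jamba-lora-adapter")
```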
## Documentation

### Model Details
| Property | Details |
|---|---|
| Developed by | AI21 |
| Model Type | Joint Attention and Mamba (Jamba) |
| License | Apache 2.0 |
| Context length | 256K |
| Knowledge cutoff date | March 5, 2024 |
### Results on common benchmarks
| Benchmark | Score |
|---|---|
| HellaSwag | 87.1% |
| Arc Challenge | 64.4% |
| WinoGrande | 82.5% |
| PIQA | 83.2% |
| MMLU | 67.4% |
| BBH | 45.4% |
| TruthfulQA | 46.4% |
| GSM8K (CoT) | 59.9% |
It's crucial that the BOS token is added to all prompts; this might not be enabled by default in all eval frameworks, so it's worth verifying, as in the sketch below.
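A minimal sketch for verifying this with the Jamba tokenizer (assuming default tokenizer settings; the check simply confirms that the first encoded id is the BOS id):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")

ids = tokenizer("Hello, world!")["input_ids"]
# Every prompt fed to the model should start with the BOS token.
assert ids[0] == tokenizer.bos_token_id, "BOS token missing - add it explicitly"
print(tokenizer.convert_ids_to_tokens(ids))
```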
### Notice
Jamba is a pretrained base model and did not undergo any alignment for instruct/chat interactions.
As a base model, Jamba is intended for use as a foundation layer for fine-tuning, training, and developing custom solutions. Jamba does not have safety moderation mechanisms, and guardrails should be added for responsible and safe use.
### About AI21
AI21 builds reliable, practical, and scalable AI solutions for the enterprise.
Jamba is the first in AI21's new family of models, and the Instruct version of Jamba is coming soon to the AI21 platform.
## License
This project is licensed under the Apache 2.0 license.