Zamba Model Card
Zamba-7B-v1-phase1 is a hybrid model that combines Mamba, a state-space model (SSM), with transformers. It uses a Mamba backbone and inserts a shared transformer layer every 6 blocks. Zamba was trained with next-token prediction and uses the Mistral v0.1 tokenizer. This architecture was selected after a series of small-scale ablation experiments. Zamba-7B-v1-phase1 was pre-trained on 1T tokens of text and code data from open web datasets. Unlike Zamba-v1, which additionally went through our annealing process, this model is the checkpoint obtained after pure pre-training on web datasets alone. We primarily intend it as a comparison point for studying the effects of that annealing process.
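As a rough illustration of the layer layout, the sketch computes where a single shared attention block would be invoked in a stack of Mamba blocks, and how weight sharing affects the attention parameter count. It assumes the shared block runs after every 6th Mamba block; the function names, block counts, and per-block parameter counts are hypothetical and do not describe Zamba's actual configuration.

```python
from typing import List


def attention_call_sites(n_mamba_blocks: int, period: int = 6) -> List[int]:
    """1-based indices of the Mamba blocks after which the shared attention block runs.

    Assumption: the shared block is applied after every `period`-th Mamba block.
    """
    return list(range(period, n_mamba_blocks + 1, period))


def attention_params(n_call_sites: int, params_per_block: int, shared: bool) -> int:
    """Attention parameters that must be stored for `n_call_sites` invocations."""
    if n_call_sites == 0:
        return 0
    # Shared: one set of weights serves every call site. Unshared: one copy per site.
    return params_per_block if shared else n_call_sites * params_per_block
```

For example, with a hypothetical stack of 30 Mamba blocks, attention_call_sites(30) gives [6, 12, 18, 24, 30]: five call sites that read one set of attention weights instead of five separate copies.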
Important Note
The current Hugging Face implementation of Zamba runs slower than our internal implementation. We are collaborating with the Hugging Face team to resolve this issue.
Our technical report detailing the training of Zamba is available on arXiv (arXiv:2405.16712).
Quick Start
Prerequisites
To use Zamba, clone Zyphra's fork of transformers:
git clone https://github.com/Zyphra/transformers_zamba
cd transformers_zamba
Then install the repository:
pip install -e .
To run the optimized Mamba implementations on a CUDA device, you need to install mamba-ssm and causal-conv1d:
pip install mamba-ssm "causal-conv1d>=1.2.0"
You can run the model without the optimized Mamba kernels, but this is not recommended because it leads to significantly higher latency.
To run on CPU, specify use_mamba_kernels=False when loading the model with AutoModelForCausalLM.from_pretrained.
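One way to apply the CPU advice automatically is to check for the kernel packages before loading. This is a minimal sketch: it assumes the pip packages mamba-ssm and causal-conv1d expose the importable modules mamba_ssm and causal_conv1d, and the helper names mamba_kernels_available and zamba_load_kwargs are hypothetical, not part of transformers.

```python
import importlib.util


def mamba_kernels_available() -> bool:
    """Return True if both optional kernel packages can be found.

    Assumes the modules are named `mamba_ssm` and `causal_conv1d`.
    find_spec only locates the modules; it does not import them.
    """
    return all(
        importlib.util.find_spec(name) is not None
        for name in ("mamba_ssm", "causal_conv1d")
    )


def zamba_load_kwargs(cuda_available: bool) -> dict:
    """Hypothetical helper: extra keyword arguments for from_pretrained.

    Optimized kernels are requested only on a CUDA device with both packages
    present; otherwise fall back to use_mamba_kernels=False (e.g. on CPU).
    """
    use_kernels = cuda_available and mamba_kernels_available()
    return {"use_mamba_kernels": use_kernels}
```

Usage: model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba-7B-v1-phase1", **zamba_load_kwargs(torch.cuda.is_available())). Because find_spec does not import anything, a package that is installed but fails at import time (for example, because of a CUDA mismatch) would still be reported as available.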
Inference
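from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Load the tokenizer and the bfloat16 model weights, placed automatically on the available devices
tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba-7B-v1-phase1")
model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba-7B-v1-phase1", device_map="auto", torch_dtype=torch.bfloat16)
# Tokenize the prompt and move the tensors to the model's device
input_text = "What factors contributed to the fall of the Roman Empire?"
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
# Generate up to 100 new tokens and decode the full sequence
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))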
To load a different checkpoint, for example the one from iteration 2500, pass the corresponding revision:
model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba-7B-v1-phase1", device_map="auto", torch_dtype=torch.bfloat16, revision="iter2500")
The default revision is the fully trained phase 1 model, which corresponds to iteration 462070, i.e., the number of training iterations performed starting from random initialization. See arXiv:2405.16712 for more details on training.
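To see which intermediate checkpoints exist before picking a revision, you can query the Hub for the repository's refs. A minimal sketch, assuming the checkpoints are published as branches or tags named iter<N> (as the iter2500 example suggests); the helper names are hypothetical, and list_zamba_checkpoints needs network access and the huggingface_hub package.

```python
import re
from typing import Iterable, List

# Assumed naming scheme for checkpoint revisions, e.g. "iter2500".
_ITER_RE = re.compile(r"^iter(\d+)$")


def iteration_revisions(ref_names: Iterable[str]) -> List[int]:
    """Extract iteration numbers from ref names like 'iter2500', sorted ascending."""
    iterations = []
    for name in ref_names:
        match = _ITER_RE.match(name)
        if match:
            iterations.append(int(match.group(1)))
    return sorted(iterations)


def list_zamba_checkpoints(repo_id: str = "Zyphra/Zamba-7B-v1-phase1") -> List[int]:
    """Query the Hub for the repo's branches and tags (requires network access)."""
    from huggingface_hub import list_repo_refs

    refs = list_repo_refs(repo_id)
    names = [ref.name for ref in refs.branches] + [ref.name for ref in refs.tags]
    return iteration_revisions(names)
```

Any number returned this way can then be passed as revision=f"iter{n}" to from_pretrained; refs that do not match the assumed pattern, such as main, are ignored.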
Features
Zamba uses a distinctive hybrid SSM architecture: a backbone of Mamba layers interspersed with a shared attention layer. The weights of this attention layer are shared across all the positions where it is applied, which minimizes the parameter cost of the model. We find that concatenating the original model embeddings to the input of this attention block improves performance, likely because it helps preserve information across depth.
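The data flow described above can be sketched with toy stand-ins. This is an illustration only, assuming the shared block runs after every 6th Mamba block and maps the concatenated input back to the model width; SharedAttention and hybrid_forward are hypothetical names, the "attention" is a placeholder averaging projection rather than real attention, and the Mamba blocks are arbitrary callables.

```python
from typing import Callable, List

Vec = List[float]
Block = Callable[[Vec], Vec]


class SharedAttention:
    """Toy stand-in for the single shared attention block.

    One instance (one set of weights) is reused at every call site;
    `calls` counts how many times it has been invoked.
    """

    def __init__(self, dim: int) -> None:
        self.dim = dim
        self.calls = 0

    def __call__(self, hidden: Vec, embeddings: Vec) -> Vec:
        if len(hidden) != self.dim or len(embeddings) != self.dim:
            raise ValueError("hidden and embeddings must both have length dim")
        self.calls += 1
        # Concatenate the current hidden state with the original embeddings (2 * dim features).
        x = hidden + embeddings
        # Placeholder projection from 2 * dim back to dim (stands in for attention + output projection).
        return [0.5 * (x[j] + x[j + self.dim]) for j in range(self.dim)]


def hybrid_forward(
    embeddings: Vec, mamba_blocks: List[Block], shared_attn: SharedAttention, period: int = 6
) -> Vec:
    """Run the Mamba blocks in order, applying the shared block after every `period`-th one."""
    hidden = list(embeddings)
    for i, block in enumerate(mamba_blocks, start=1):
        hidden = block(hidden)
        if i % period == 0:
            # Always pass the ORIGINAL embeddings, not the current hidden state.
            hidden = shared_attn(hidden, embeddings)
    return hidden
```

Because one SharedAttention instance is passed in and reused, every call site reads the same weights, and each call still receives the original embeddings, which is how deep call sites get direct access to the network's input.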
Performance
We find that Zamba performs significantly better than existing open models (with open datasets and training details) at this scale. However, it performs slightly worse than the leading open-weight models at the 7B scale, with most of the gap coming from MMLU and reasoning evaluations. Zamba is trained on significantly fewer tokens than these models, and it is the most sample-efficient model in terms of performance per training token.
Thanks to its SSM architecture, Zamba is highly efficient at inference. It substantially outperforms comparable 7B and 8B models in inference latency and in memory cost during generation, because its KV cache is much smaller.
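The memory argument can be made concrete with the standard KV cache size estimate: each attention invocation stores one key and one value tensor per token, so the cache grows as 2 × invocations × KV heads × head dimension × sequence length × batch size × bytes per element. Mamba blocks instead carry a fixed-size recurrent state that does not grow with sequence length, so they add nothing to this sum. In a shared-attention design each invocation still caches its own keys and values, since its inputs differ at each depth; sharing weights saves parameters, not cache. The configurations in the usage note are hypothetical and are not Zamba's or any specific model's.

```python
def kv_cache_bytes(
    n_attn_invocations: int,
    n_kv_heads: int,
    head_dim: int,
    seq_len: int,
    batch_size: int = 1,
    bytes_per_elem: int = 2,  # 2 bytes for bfloat16 / float16
) -> int:
    """Estimate the KV cache size in bytes.

    The leading factor 2 accounts for one key and one value tensor per
    attention invocation. SSM (Mamba) state is not included because it
    does not grow with sequence length.
    """
    return 2 * n_attn_invocations * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem
```

For a hypothetical model with 8 KV heads of dimension 128 in bfloat16 at 4096 tokens, 32 attention invocations give kv_cache_bytes(32, 8, 128, 4096) = 536,870,912 bytes (512 MiB), while 6 invocations give 100,663,296 bytes (96 MiB).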
Citation
If you find Zamba useful in your work, please cite it as:
@article{glorioso2024zamba,
title={Zamba: A Compact 7B SSM Hybrid Model},
author={Glorioso, Paolo and Anthony, Quentin and Tokpanov, Yury and Whittington, James and Pilault, Jonathan and Ibrahim, Adam and Millidge, Beren},
journal={arXiv preprint arXiv:2405.16712},
year={2024}
}
Notice
Zamba is a pretrained base model and therefore has no moderation mechanism. In addition, it has not been fine-tuned for chat, so good chat performance should not be expected.
License
Zamba is released under the Apache 2.0 license.