🚀 Mambaoutai 1.6B
Mambaoutai是此博客文章中描述的所有實驗和訓練運行的成果,該文章分享了該模型系列的所有細節。Mambaoutai是一系列小型Mamba檢查點,供社區探索使用,在法語、英語和代碼數據上進行訓練。我們使用WSD調度器運行了兩個不同的衰減階段,併發布了有無指令數據預訓練的模型檢查點。
🚀 快速開始
Mambaoutai可用於文本生成、推理等任務。你可以按照以下步驟使用該模型。
✨ 主要特性
- 多語言支持:在法語、英語和代碼數據上進行訓練。
- 不同訓練階段檢查點:發佈了有無指令數據預訓練的模型檢查點。
- 輕量級模型:僅有1.6B參數,可在CPU上以合理速度運行。
📦 安裝指南
你需要從main
分支安裝transformers
,直到transformers=4.39.0
版本發佈。
pip install git+https://github.com/huggingface/transformers@main
我們還建議你使用以下命令安裝causal-conv1d
和mamba-ssm
:
pip install causal-conv1d>=1.2.0
pip install mamba-ssm>=1.2.0
如果這兩個庫未安裝,將使用“eager”實現(不推薦),否則將使用更優化的CUDA
內核。
💻 使用示例
基礎用法
使用以下代碼片段從模型生成文本:
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch
if model_has_instruct_data:
prompt = ”<start_user>Tell me something about Paris.<end_message><start_assistant>”
else:
prompt = ”This is a text about Paris. Paris is”
tokenizer = AutoTokenizer.from_pretrained("lightonai/mambaoutai")
model = MambaForCausalLM.from_pretrained("lightonai/mambaoutai")
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
高級用法
你可以在倉庫分支中找到一些訓練檢查點。在訓練過程中的某個時間點對應的分支上。你可以通過在from_pretrained
方法中添加revision
參數,使用這些訓練檢查點進行推理。例如,要加載預訓練30000步後的模型檢查點,可以使用以下代碼:
from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("lightonai/mambaoutai", revision="pre-30000")
model = MambaForCausalLM.from_pretrained("lightonai/mambaoutai", revision="pre-30000")
input_ids = tokenizer("What is a mamba?", return_tensors="pt")["input_ids"]
out = model.generate(input_ids, max_new_tokens=10)
print(tokenizer.batch_decode(out))
設備上推理
由於Mambaoutai僅有1.6B參數,它可以在CPU上以合理速度運行。以下是在llama.cpp上運行它的示例:
git clone https://github.com/ggerganov/llama.cpp
cd llama.cpp
make
conda create -n mamba-cpp python=3.10
conda activate mamba-cpp
pip install -r requirements/requirements-convert-hf-to-gguf.txt
mkdir Mambaoutai
python convert-hf-to-gguf.py Mambaoutai
./main -m Mambaoutai/ggml-model-f16.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 1
🔧 技術細節
訓練硬件
無指令數據的模型檢查點在OVH Cloud提供的NVIDIA DGX H100上進行了全面訓練,而有指令數據的衰減階段在Orange Cloud Avenue的HPE Cray(配備8xH100)上進行。消融實驗在MeluXina的16個節點(4xA100 - 40GB)上進行。
模型超參數
模型超參數的更多細節如下表所示:
參數 |
詳情 |
d_model |
2688 |
n_layer |
28 |
vocab_size |
65024 |
context_len |
4096 |
rms_norm |
true |
residual_in_fp32 |
true |
fused_add_norm |
true |
conv_kernel |
4 |
d_inner |
5376 |
state_size |
16 |
dtype |
bfloat16 |
tie_word_embeddings |
false |
non embeddings params |
1.27B |
📄 許可證
本項目採用Apache-2.0許可證。
數據集
- togethercomputer/RedPajama-Data-V2
- stingning/ultrachat
語言
評估指標