模型简介
模型特点
模型能力
使用案例
🚀 LayerSkip Llama2 7B
LayerSkip Llama2 7B 模型基于 Layer Skip: Enabling Early Exit Inference and Self-Speculative Decoding 中提出的 LayerSkip 方法进行持续预训练,能够执行自推测解码,即利用前面的层进行解码,并使用其余层进行验证。
🚀 快速开始
🔍 模型使用方式
我们提供了 3 种运行模型的方式:
🤗 HuggingFace
HuggingFace 目前尚不支持自推测解码。不过,我们可以通过使用主模型的部分层创建一个草稿模型,来复用其推测解码功能:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> import torch
>>> from copy import deepcopy
>>> checkpoint = "facebook/layerskip-llama2-7B"
>>> early_exit = 4
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> prompt = "typing import List\ndef bucket_sort(A: List):"
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> generation_config = model.generation_config
>>> weights_memo = {id(w): w for w in model.parameters()}
>>> assistant_model = deepcopy(model, memo=weights_memo) # Clone main model with shared weights
>>> assistant_model.model.layers = assistant_model.model.layers[:early_exit] # Apply early exit
>>> del assistant_model.model.layers[early_exit:]
>>> model.to(device)
>>> assistant_model.to(device)
>>> inputs = tokenizer(prompt, return_tensors="pt").to(device)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model, generation_config=generation_config)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
请注意,这并非最优实现,因为它需要更多内存来保存重复层的 KV 缓存和激活值。复用早期层的优化实现可在我们的 自定义实现 或 gpt-fast 实现 中找到。
📊 基准测试
如果你想衡量自推测解码和自回归解码之间的速度提升,我们编写了以下脚本: ```python from transformers import LlamaForCausalLM, LlamaConfig, LlamaTokenizer, GenerationConfig import torch from copy import deepcopy from time import time from tqdm import tqdmprompt = "typing import List\ndef bucket_sort(A: List):"
checkpoint = "facebook/layerskip-llama2-7B" early_exit = 4 device = "cuda" if torch.cuda.is_available() else "cpu"
max_new_tokens = 512 do_sample = True top_p = 0.9 temperature = 0.6
warmup = 2 repeat = 10
config = LlamaConfig.from_pretrained(checkpoint) model = LlamaForCausalLM.from_pretrained(checkpoint, config=config, torch_dtype=torch.float16)
Draft model
Clone main model with shared weights
weights_memo = {id(w): w for w in model.parameters()} assistant_model = deepcopy(model, memo=weights_memo)
Create early exit version
assistant_model.model.layers = assistant_model.model.layers[:early_exit] del assistant_model.model.layers[early_exit:]
model.to(device) assistant_model.to(device)
tokenizer = LlamaTokenizer.from_pretrained(checkpoint, use_fast=False) inputs = tokenizer(prompt, return_tensors="pt").to(device)
generation_config = { "max_new_tokens": max_new_tokens, "do_sample": do_sample, "top_p": top_p, "temperature": temperature, "pad_token_id": tokenizer.eos_token_id, }
Warmup
print("Warmup") for i in tqdm(range(warmup)): _ = model.generate(**inputs, **generation_config) _ = model.generate(**inputs, **generation_config, assistant_model=assistant_model)
print("Autoregressive Decoding") total_time = 0 total_tokens = 0 for i in tqdm(range(repeat)): start = time() outputs = model.generate(**inputs, **generation_config) total_time += time() - start total_tokens += outputs.numel() print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) print("\n\t=========================") print(f"\tAverage Generation Time: {total_time / repeat:.2f} s") print(f"\tAverage Tokens per Second: {total_tokens / total_time:.2f} tokens per sec\n\n")
print("Self-Speculative Decoding") total_time = 0 total_tokens = 0 for i in tqdm(range(repeat)): start = time() outputs = model.generate(**inputs, **generation_config, assistant_model=assistant_model) total_time += time() - start total_tokens += outputs.numel() print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) print("\n\t=========================") print(f"\tAverage Generation Time: {total_time / repeat:.2f} s") print(f"\tAverage Tokens per Second: {total_tokens / total_time:.2f} tokens per sec\n\n")
在配备 `transformers==4.34.1`、`torch==2.2.1` 和 `triton==2.2.0` 的单块 NVIDIA A100 GPU 上运行此脚本,我们得到以下结果:
Autoregressive Decoding ========================= Average Generation Time: 12.60 s Average Tokens per Second: 34.87 tokens per sec
Self-Speculative Decoding ========================= Average Generation Time: 7.38 s Average Tokens per Second: 56.10 tokens per sec
</details>
### 📦 LayerSkip 代码库<a name="custom"></a>
我们在 [github.com/facebookresearch/LayerSkip](https://github.com/facebookresearch/LayerSkip) 上的自推测解码实现有一个优化版本,它不会消耗额外内存,并且在草稿和验证阶段都会复用早期层的权重和 KV 缓存。
要运行该模型,请执行以下操作:
```console
> git clone git@github.com:facebookresearch/LayerSkip.git
> cd LayerSkip
> conda create --name layer_skip python=3.10
> conda activate layer_skip
> pip install -r requirements.txt
> torchrun generate.py --model facebook/layerskip-llama2-7B --generation_strategy self_speculative --exit_layer 6 --num_speculations 4
你可以在 GitHub 仓库中找到更多选项和脚本的详细信息。
💨 gpt-fast
如果你想将我们的解决方案与其他优化(如 torch.compile()
和量化)相结合,我们还在 PyTorch 的 gpt-fast 的一个独立分支 中实现了自推测解码。我们的 gpt-fast 实现经过优化,不会消耗额外内存,并且在草稿和验证阶段都会复用早期层的权重和 KV 缓存。
要运行该模型,请执行以下操作:
> git clone git@github.com:pytorch-labs/gpt-fast.git -b LayerSkip
> cd gpt-fast
> conda create --name gpt_fast python=3.10
> conda activate gpt_fast
> # Install PyTorch (check [here](https://pytorch.org/get-started/locally/) for other hardwares and operating systems)
> pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
> pip install sentencepiece huggingface_hub tiktoken
> mkdir checkpoints
> MODEL_REPO=facebook/layerskip-llama2-7B
> ./scripts/prepare.sh $MODEL_REPO
> python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --top_k 100 --temperature 0.6 --self_speculative --early_exit 5 --speculate_k 3
📊 基准测试
- 自回归解码: ```console > python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --top_k 100 --temperature 0.6 ========== Average tokens/sec: 110.50 Memory used: 13.88 GB ``` - 自推测解码: ```console > python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --top_k 100 --temperature 0.6 --self_speculative --early_exit 5 --speculate_k 3 ========== {'tokens_per_sec': [120.16508373150057, 141.77910376715855, 132.42363092761354, 138.73840444421148, 121.55019835742718], 'accept_counts': [[32, 15, 19, 20], [50, 23, 21, 10], [31, 22, 16, 19], [41, 19, 19, 16], [35, 20, 15, 20], [47, 32, 9, 16]]} Acceptance probs: [0.41622574955908287, 0.2310405643738977, 0.1746031746031746, 0.1781305114638448] Mean Accepted: 1.1146384479717812 Average tokens/sec: 130.93 Memory used: 13.91 GB ```🛠️ 训练
我们的训练实现仍在进行中。你可以查看这个 拉取请求 以获取详细信息和讨论内容。
📈 评估
我们在模型卡片中提供了该模型在各种自然语言和编码任务上的评估结果。你可以在屏幕右上角的侧边栏中查看这些结果。 模型卡片中报告的数值是使用 Eluether 评估工具 和 BigCode 评估工具 进行评估的,而我们论文中提供的数值是使用 Meta 的内部代码库进行评估的。
❗ 问题反馈
请通过以下方式之一报告任何软件“漏洞”或模型相关的其他问题:
- 报告模型问题:https://github.com/facebookresearch/LayerSkip/issues
- 报告模型生成的风险内容:developers.facebook.com/llama_output_feedback
- 报告漏洞和安全问题:facebook.com/whitehat/info
📄 许可证
请参阅 LICENSE 文件。
⚠️ 重要提示
你需要与 Meta 共享联系信息才能访问此模型。
FAIR 非商业研究许可证
最新更新日期:[2024 年 10 月 16 日]
“可接受使用政策”指适用于研究材料的 FAIR 可接受使用政策,该政策已纳入本协议。
“协议”指本协议中规定的研究材料的使用、复制、分发和修改的条款和条件。
“文档”指 Meta 分发的研究材料附带的规格说明、手册和文档。
“被许可方”或“你”指你本人,或你的雇主,或任何其他个人或实体(如果你代表该个人或实体签订本协议),且该个人或实体已达到适用法律、规则或法规要求的提供法律同意的年龄,并且如果你代表其签订本协议,该个人或实体具有约束你的雇主或该其他个人或实体的法律权力。
“Meta”或“我们”指 Meta Platforms Ireland Limited(如果你位于欧洲经济区或瑞士,或者如果你是一个实体,你的主要营业地位于欧洲经济区或瑞士)和 Meta Platforms, Inc.(如果你位于欧洲经济区或瑞士以外的地区)。
“非商业研究用途”指与研究、开发、教育、处理或分析相关的非商业研究用例,并且在每种情况下,其主要目的不是为你或他人带来商业利益或金钱补偿。
“研究材料”指文档以及模型、软件和算法的集合,包括机器学习模型代码、训练好的模型权重、推理启用代码、训练启用代码、微调启用代码、演示材料以及 Meta 分发并根据本协议提供的上述各项的其他元素。
通过点击下面的“我接受”,或者通过使用或分发研究材料的任何部分或元素,你同意受本协议的约束。
-
许可权利和再分发
- 权利授予:你被授予在 Meta 体现在研究材料中的知识产权或其他权利下的非排他性、全球性、不可转让且免版税的有限许可,以使用、复制、分发、拷贝、创作衍生作品并对研究材料进行修改。
- 再分发和使用:
- 你不得将研究材料或研究材料的任何输出或结果用于任何商业用途,或用于非商业研究用途以外的任何用途。
- 研究材料及其任何衍生作品的分发须遵守本协议的条款。如果你将研究材料或其任何衍生作品分发给第三方,你只能根据本协议的条款进行分发。你还应向该第三方提供本协议的副本。
- 如果你发表使用研究材料进行的研究结果,你必须在出版物中承认使用了研究材料。
- 你对研究材料的使用必须遵守适用的法律和法规(包括贸易管制法律),并遵守 FAIR 可接受使用政策,该政策特此通过引用纳入本协议。
-
用户支持:你对研究材料的非商业研究使用由你自行决定;Meta 不会处理任何与该使用相关的信息,也不会提供任何服务。Meta 没有义务为研究材料提供任何支持服务。提供的任何支持均“按现状”、“带有所有缺陷”提供,且不提供任何形式的保证。
-
保证免责声明:除非适用法律要求,研究材料及其任何输出和结果均“按现状”提供,不提供任何形式的保证,Meta 免除所有形式的明示或暗示保证,包括但不限于所有权、不侵权、适销性或特定用途适用性的任何保证。你独自负责确定使用或再分发研究材料的适当性,并承担与你使用研究材料及其任何输出和结果相关的任何风险。
-
责任限制:在任何情况下,Meta 或其关联公司均不对因本协议引起的任何责任理论(无论是合同、侵权、疏忽、产品责任还是其他)下的任何利润损失或任何直接或间接、特殊、后果性、偶发性、示范性或惩罚性损害承担责任,即使 Meta 或其关联公司已被告知存在此类损害的可能性。
-
知识产权
- 除 Meta 对研究材料及其由 Meta 或代表 Meta 制作的衍生作品的所有权外,就你制作的研究材料的任何衍生作品和修改而言,在你和 Meta 之间,你是此类衍生作品和修改的所有者。
- 如果你对 Meta 或任何实体提起诉讼或其他法律程序(包括在诉讼中的交叉索赔或反诉),声称研究材料、输出或结果或上述任何部分构成侵犯你拥有或可许可的知识产权或其他权利,则本协议授予你的任何许可将自该诉讼或索赔提起之日起终止。你将赔偿并使 Meta 免受任何第三方因你使用或分发研究材料而产生的或与之相关的任何索赔的损害。
-
期限和终止:本协议的期限将自你接受本协议或访问研究材料时开始,并将持续有效,直至根据本协议的条款和条件终止。如果你违反本协议的任何条款和条件,Meta 可终止本协议。本协议终止后,你应删除并停止使用研究材料。第 5、6 和 9 条在本协议终止后仍然有效。
-
适用法律和管辖权:本协议将受加利福尼亚州法律管辖并依其解释,不考虑法律选择原则,并且《联合国国际货物销售合同公约》不适用于本协议。加利福尼亚州的法院对因本协议引起的任何争议具有专属管辖权。
-
修改和修订:Meta 可不时通过在 https://huggingface.co/facebook/layerskip-llama2-7B/blob/main/LICENSE 上发布修订版本来修改本协议;前提是这些修订在精神上与本协议的当前版本相似,但在细节上可能有所不同,以解决新的问题或担忧。所有此类更改将立即生效。在本协议进行任何修改后,你继续使用研究材料即表示你同意该修改。除非本协议另有规定,否则对本协议任何条款的修改或补充均不具有约束力,除非该修改或补充以书面形式作出,并由你和 Meta 的授权代表签字。
FAIR 可接受使用政策
Meta 的基础人工智能研究 (FAIR) 团队致力于通过开放研究推动人工智能的发展,以造福所有人,从而进一步理解新的和现有的研究领域。
作为这一使命的一部分,Meta 提供某些研究材料供非商业研究使用。Meta 致力于促进对这些研究材料的安全和负责任的使用。
禁止使用情况
你同意不会使用或允许他人使用研究材料来:
- 违反法律或他人权利,包括:
- 从事、促进、生成、促成、鼓励、策划、煽动或推动非法或违法活动或内容,例如:
- 暴力或恐怖主义
- 对儿童的剥削或伤害,包括征集、创建、获取或传播儿童剥削内容,或未报告儿童性虐待材料
- 人口贩卖、剥削和性暴力
- 向未成年人非法分发信息或材料,包括淫秽材料,或未对此类信息或材料采用法律要求的年龄限制
- 性引诱
- 任何其他犯罪活动
- 从事、促进、煽动或便利对个人或群体的骚扰、虐待、威胁或欺凌
- 从事、促进、煽动或便利在就业、就业福利、信贷、住房、其他经济利益或其他基本商品和服务的提供方面的歧视或其他非法或有害行为
- 从事未经授权或无执照的任何专业实践,包括但不限于金融、法律、医疗/健康或相关专业实践
- 在未获得适用法律要求的权利和同意的情况下收集、处理、披露、生成或推断个人的健康、人口统计或其他敏感个人或私人信息
- 从事或便利任何侵犯、盗用或以其他方式侵犯任何第三方权利的行动或生成任何内容,包括使用 FAIR 研究材料的任何技术的输出或结果
- 创建、生成或便利创建恶意代码、恶意软件、计算机病毒,或做任何可能禁用、使负担过重、干扰或损害网站或计算机系统的正常运行、完整性、操作或外观的事情
- 从事、促进、生成、促成、鼓励、策划、煽动或推动非法或违法活动或内容,例如:
- 从事、促进、煽动、便利或协助策划或开展对个人造成死亡或身体伤害风险的活动,包括使用与以下相关的研究工件:
- 军事、战争、核工业或应用、间谍活动,以及受美国国务院维护的《国际武器贸易条例》(ITAR) 约束的材料或活动
- 枪支和非法武器(包括武器开发)
- 非法药物和受管制/受控物质
- 关键基础设施、运输技术或重型机械的操作
- 自我伤害或伤害他人,包括自杀、自残和饮食失调
- 任何旨在煽动或促进对个人的暴力、虐待或任何身体伤害的内容
- 故意欺骗或误导他人,包括使用与以下相关的 FAIR 研究材料:
- 生成、促进或推动欺诈或创建或推广虚假信息
- 生成、促进或推动诽谤性内容,包括创建诽谤性声明、图像或其他内容
- 生成、促进或进一步分发垃圾邮件
- 在未经同意、授权或合法权利的情况下冒充他人
- 声称 FAIR 研究材料的输出或使用 FAIR 研究材料的技术的输出是人类生成的
- 生成或促进虚假的在线互动,包括虚假评论和其他虚假在线互动方式
- 未能向最终用户适当披露研究材料的任何已知危险。
请通过 此处 提交报告,报告任何违反本政策的行为或可能导致违反本政策的其他问题。
额外信息收集
属性 | 详情 |
---|---|
名字 | 文本输入 |
姓氏 | 文本输入 |
出生日期 | 日期选择器 |
国家 | 国家选择 |
所属机构 | 文本输入 |
地理位置 | IP 定位 |
通过点击下面的“提交”,你接受许可证的条款,并确认你提供的信息将根据 Meta 隐私政策 进行收集、存储、处理和共享。
模型指标
任务类型 | 数据集 | 指标 | 值 | 验证状态 |
---|---|---|---|---|
问答 | google/boolq (BoolQ) | acc | 0.776 | 未验证 |
问答 | ybisk/piqa (PIQA) | acc | 0.775 | 未验证 |
问答 | allenai/social_i_qa (SIQA) | acc | 0.454 | 未验证 |
文本生成 | Rowan/hellaswag (HellaSwag) | acc | 0.567 | 未验证 |
问答 | allenai/winogrande (WinoGrande) | acc | 0.701 | 未验证 |
问答 | allenai/ai2_arc (ARC (Easy)) | acc | 0.765 | 未验证 |
问答 | allenai/ai2_arc (ARC (Challenge)) | acc | 0.437 | 未验证 |
问答 | allenai/openbookqa (OpenBookQA) | acc | 0.328 | 未验证 |
问答 | ehovy/race (RACE) | acc | 0.389 | 未验证 |
问答 | cais/mmlu (MMLU) | acc | 0.376 | 未验证 |
文本生成 | google-research-datasets/nq_open (Natural Questions) | exact_match | 0.156 | 未验证 |
问答 | sentence-transformers/trivia-qa (TriviaQA) | acc | 0.529 | 未验证 |
文本生成 | openai/gsm8k (GSM8K) | exact_match | 0.121 | 未验证 |
问答 | allenai/math_qa (MathQA) | acc | 0.276 | 未验证 |
问答 | rajpurkar/squad_v2 (SQuAD2.0) | exact | 0.164 | 未验证 |
文本分类 | toxigen/toxigen-data (ToxiGen) | acc | 0.428 | 未验证 |
文本生成 | openai_humaneval (HumanEval) | pass@1 | 0.134 | 未验证 |
文本生成 | mbpp (MBPP) | pass@1 | 0.19 | 未验证 |



