模型概述
模型特點
模型能力
使用案例
🚀 LayerSkip Llama2 7B
LayerSkip Llama2 7B 模型基於 Layer Skip: Enabling Early Exit Inference and Self-Speculative Decoding 中提出的 LayerSkip 方法進行持續預訓練,能夠執行自推測解碼,即利用前面的層進行解碼,並使用其餘層進行驗證。
🚀 快速開始
🔍 模型使用方式
我們提供了 3 種運行模型的方式:
🤗 HuggingFace
HuggingFace 目前尚不支持自推測解碼。不過,我們可以通過使用主模型的部分層創建一個草稿模型,來複用其推測解碼功能:
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> import torch
>>> from copy import deepcopy
>>> checkpoint = "facebook/layerskip-llama2-7B"
>>> early_exit = 4
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> prompt = "typing import List\ndef bucket_sort(A: List):"
>>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
>>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)
>>> generation_config = model.generation_config
>>> weights_memo = {id(w): w for w in model.parameters()}
>>> assistant_model = deepcopy(model, memo=weights_memo) # Clone main model with shared weights
>>> assistant_model.model.layers = assistant_model.model.layers[:early_exit] # Apply early exit
>>> del assistant_model.model.layers[early_exit:]
>>> model.to(device)
>>> assistant_model.to(device)
>>> inputs = tokenizer(prompt, return_tensors="pt").to(device)
>>> outputs = model.generate(**inputs, assistant_model=assistant_model, generation_config=generation_config)
>>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
請注意,這並非最優實現,因為它需要更多內存來保存重複層的 KV 緩存和激活值。複用早期層的優化實現可在我們的 自定義實現 或 gpt-fast 實現 中找到。
📊 基準測試
如果你想衡量自推測解碼和自迴歸解碼之間的速度提升,我們編寫了以下腳本: ```python from transformers import LlamaForCausalLM, LlamaConfig, LlamaTokenizer, GenerationConfig import torch from copy import deepcopy from time import time from tqdm import tqdmprompt = "typing import List\ndef bucket_sort(A: List):"
checkpoint = "facebook/layerskip-llama2-7B" early_exit = 4 device = "cuda" if torch.cuda.is_available() else "cpu"
max_new_tokens = 512 do_sample = True top_p = 0.9 temperature = 0.6
warmup = 2 repeat = 10
config = LlamaConfig.from_pretrained(checkpoint) model = LlamaForCausalLM.from_pretrained(checkpoint, config=config, torch_dtype=torch.float16)
Draft model
Clone main model with shared weights
weights_memo = {id(w): w for w in model.parameters()} assistant_model = deepcopy(model, memo=weights_memo)
Create early exit version
assistant_model.model.layers = assistant_model.model.layers[:early_exit] del assistant_model.model.layers[early_exit:]
model.to(device) assistant_model.to(device)
tokenizer = LlamaTokenizer.from_pretrained(checkpoint, use_fast=False) inputs = tokenizer(prompt, return_tensors="pt").to(device)
generation_config = { "max_new_tokens": max_new_tokens, "do_sample": do_sample, "top_p": top_p, "temperature": temperature, "pad_token_id": tokenizer.eos_token_id, }
Warmup
print("Warmup") for i in tqdm(range(warmup)): _ = model.generate(**inputs, **generation_config) _ = model.generate(**inputs, **generation_config, assistant_model=assistant_model)
print("Autoregressive Decoding") total_time = 0 total_tokens = 0 for i in tqdm(range(repeat)): start = time() outputs = model.generate(**inputs, **generation_config) total_time += time() - start total_tokens += outputs.numel() print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) print("\n\t=========================") print(f"\tAverage Generation Time: {total_time / repeat:.2f} s") print(f"\tAverage Tokens per Second: {total_tokens / total_time:.2f} tokens per sec\n\n")
print("Self-Speculative Decoding") total_time = 0 total_tokens = 0 for i in tqdm(range(repeat)): start = time() outputs = model.generate(**inputs, **generation_config, assistant_model=assistant_model) total_time += time() - start total_tokens += outputs.numel() print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) print("\n\t=========================") print(f"\tAverage Generation Time: {total_time / repeat:.2f} s") print(f"\tAverage Tokens per Second: {total_tokens / total_time:.2f} tokens per sec\n\n")
在配備 `transformers==4.34.1`、`torch==2.2.1` 和 `triton==2.2.0` 的單塊 NVIDIA A100 GPU 上運行此腳本,我們得到以下結果:
Autoregressive Decoding ========================= Average Generation Time: 12.60 s Average Tokens per Second: 34.87 tokens per sec
Self-Speculative Decoding ========================= Average Generation Time: 7.38 s Average Tokens per Second: 56.10 tokens per sec
</details>
### 📦 LayerSkip 代碼庫<a name="custom"></a>
我們在 [github.com/facebookresearch/LayerSkip](https://github.com/facebookresearch/LayerSkip) 上的自推測解碼實現有一個優化版本,它不會消耗額外內存,並且在草稿和驗證階段都會複用早期層的權重和 KV 緩存。
要運行該模型,請執行以下操作:
```console
> git clone git@github.com:facebookresearch/LayerSkip.git
> cd LayerSkip
> conda create --name layer_skip python=3.10
> conda activate layer_skip
> pip install -r requirements.txt
> torchrun generate.py --model facebook/layerskip-llama2-7B --generation_strategy self_speculative --exit_layer 6 --num_speculations 4
你可以在 GitHub 倉庫中找到更多選項和腳本的詳細信息。
💨 gpt-fast
如果你想將我們的解決方案與其他優化(如 torch.compile()
和量化)相結合,我們還在 PyTorch 的 gpt-fast 的一個獨立分支 中實現了自推測解碼。我們的 gpt-fast 實現經過優化,不會消耗額外內存,並且在草稿和驗證階段都會複用早期層的權重和 KV 緩存。
要運行該模型,請執行以下操作:
> git clone git@github.com:pytorch-labs/gpt-fast.git -b LayerSkip
> cd gpt-fast
> conda create --name gpt_fast python=3.10
> conda activate gpt_fast
> # Install PyTorch (check [here](https://pytorch.org/get-started/locally/) for other hardwares and operating systems)
> pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
> pip install sentencepiece huggingface_hub tiktoken
> mkdir checkpoints
> MODEL_REPO=facebook/layerskip-llama2-7B
> ./scripts/prepare.sh $MODEL_REPO
> python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --top_k 100 --temperature 0.6 --self_speculative --early_exit 5 --speculate_k 3
📊 基準測試
- 自迴歸解碼: ```console > python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --top_k 100 --temperature 0.6 ========== Average tokens/sec: 110.50 Memory used: 13.88 GB ``` - 自推測解碼: ```console > python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth --top_k 100 --temperature 0.6 --self_speculative --early_exit 5 --speculate_k 3 ========== {'tokens_per_sec': [120.16508373150057, 141.77910376715855, 132.42363092761354, 138.73840444421148, 121.55019835742718], 'accept_counts': [[32, 15, 19, 20], [50, 23, 21, 10], [31, 22, 16, 19], [41, 19, 19, 16], [35, 20, 15, 20], [47, 32, 9, 16]]} Acceptance probs: [0.41622574955908287, 0.2310405643738977, 0.1746031746031746, 0.1781305114638448] Mean Accepted: 1.1146384479717812 Average tokens/sec: 130.93 Memory used: 13.91 GB ```🛠️ 訓練
我們的訓練實現仍在進行中。你可以查看這個 拉取請求 以獲取詳細信息和討論內容。
📈 評估
我們在模型卡片中提供了該模型在各種自然語言和編碼任務上的評估結果。你可以在屏幕右上角的側邊欄中查看這些結果。 模型卡片中報告的數值是使用 Eluether 評估工具 和 BigCode 評估工具 進行評估的,而我們論文中提供的數值是使用 Meta 的內部代碼庫進行評估的。
❗ 問題反饋
請通過以下方式之一報告任何軟件“漏洞”或模型相關的其他問題:
- 報告模型問題:https://github.com/facebookresearch/LayerSkip/issues
- 報告模型生成的風險內容:developers.facebook.com/llama_output_feedback
- 報告漏洞和安全問題:facebook.com/whitehat/info
📄 許可證
請參閱 LICENSE 文件。
⚠️ 重要提示
你需要與 Meta 共享聯繫信息才能訪問此模型。
FAIR 非商業研究許可證
最新更新日期:[2024 年 10 月 16 日]
“可接受使用政策”指適用於研究材料的 FAIR 可接受使用政策,該政策已納入本協議。
“協議”指本協議中規定的研究材料的使用、複製、分發和修改的條款和條件。
“文檔”指 Meta 分發的研究材料附帶的規格說明、手冊和文檔。
“被許可方”或“你”指你本人,或你的僱主,或任何其他個人或實體(如果你代表該個人或實體簽訂本協議),且該個人或實體已達到適用法律、規則或法規要求的提供法律同意的年齡,並且如果你代表其簽訂本協議,該個人或實體具有約束你的僱主或該其他個人或實體的法律權力。
“Meta”或“我們”指 Meta Platforms Ireland Limited(如果你位於歐洲經濟區或瑞士,或者如果你是一個實體,你的主要營業地位於歐洲經濟區或瑞士)和 Meta Platforms, Inc.(如果你位於歐洲經濟區或瑞士以外的地區)。
“非商業研究用途”指與研究、開發、教育、處理或分析相關的非商業研究用例,並且在每種情況下,其主要目的不是為你或他人帶來商業利益或金錢補償。
“研究材料”指文檔以及模型、軟件和算法的集合,包括機器學習模型代碼、訓練好的模型權重、推理啟用代碼、訓練啟用代碼、微調啟用代碼、演示材料以及 Meta 分發並根據本協議提供的上述各項的其他元素。
通過點擊下面的“我接受”,或者通過使用或分發研究材料的任何部分或元素,你同意受本協議的約束。
-
許可權利和再分發
- 權利授予:你被授予在 Meta 體現在研究材料中的知識產權或其他權利下的非排他性、全球性、不可轉讓且免版稅的有限許可,以使用、複製、分發、拷貝、創作衍生作品並對研究材料進行修改。
- 再分發和使用:
- 你不得將研究材料或研究材料的任何輸出或結果用於任何商業用途,或用於非商業研究用途以外的任何用途。
- 研究材料及其任何衍生作品的分發須遵守本協議的條款。如果你將研究材料或其任何衍生作品分發給第三方,你只能根據本協議的條款進行分發。你還應向該第三方提供本協議的副本。
- 如果你發表使用研究材料進行的研究結果,你必須在出版物中承認使用了研究材料。
- 你對研究材料的使用必須遵守適用的法律和法規(包括貿易管制法律),並遵守 FAIR 可接受使用政策,該政策特此通過引用納入本協議。
-
用戶支持:你對研究材料的非商業研究使用由你自行決定;Meta 不會處理任何與該使用相關的信息,也不會提供任何服務。Meta 沒有義務為研究材料提供任何支持服務。提供的任何支持均“按現狀”、“帶有所有缺陷”提供,且不提供任何形式的保證。
-
保證免責聲明:除非適用法律要求,研究材料及其任何輸出和結果均“按現狀”提供,不提供任何形式的保證,Meta 免除所有形式的明示或暗示保證,包括但不限於所有權、不侵權、適銷性或特定用途適用性的任何保證。你獨自負責確定使用或再分發研究材料的適當性,並承擔與你使用研究材料及其任何輸出和結果相關的任何風險。
-
責任限制:在任何情況下,Meta 或其關聯公司均不對因本協議引起的任何責任理論(無論是合同、侵權、疏忽、產品責任還是其他)下的任何利潤損失或任何直接或間接、特殊、後果性、偶發性、示範性或懲罰性損害承擔責任,即使 Meta 或其關聯公司已被告知存在此類損害的可能性。
-
知識產權
- 除 Meta 對研究材料及其由 Meta 或代表 Meta 製作的衍生作品的所有權外,就你製作的研究材料的任何衍生作品和修改而言,在你和 Meta 之間,你是此類衍生作品和修改的所有者。
- 如果你對 Meta 或任何實體提起訴訟或其他法律程序(包括在訴訟中的交叉索賠或反訴),聲稱研究材料、輸出或結果或上述任何部分構成侵犯你擁有或可許可的知識產權或其他權利,則本協議授予你的任何許可將自該訴訟或索賠提起之日起終止。你將賠償並使 Meta 免受任何第三方因你使用或分發研究材料而產生的或與之相關的任何索賠的損害。
-
期限和終止:本協議的期限將自你接受本協議或訪問研究材料時開始,並將持續有效,直至根據本協議的條款和條件終止。如果你違反本協議的任何條款和條件,Meta 可終止本協議。本協議終止後,你應刪除並停止使用研究材料。第 5、6 和 9 條在本協議終止後仍然有效。
-
適用法律和管轄權:本協議將受加利福尼亞州法律管轄並依其解釋,不考慮法律選擇原則,並且《聯合國國際貨物銷售合同公約》不適用於本協議。加利福尼亞州的法院對因本協議引起的任何爭議具有專屬管轄權。
-
修改和修訂:Meta 可不時通過在 https://huggingface.co/facebook/layerskip-llama2-7B/blob/main/LICENSE 上發佈修訂版本來修改本協議;前提是這些修訂在精神上與本協議的當前版本相似,但在細節上可能有所不同,以解決新的問題或擔憂。所有此類更改將立即生效。在本協議進行任何修改後,你繼續使用研究材料即表示你同意該修改。除非本協議另有規定,否則對本協議任何條款的修改或補充均不具有約束力,除非該修改或補充以書面形式作出,並由你和 Meta 的授權代表簽字。
FAIR 可接受使用政策
Meta 的基礎人工智能研究 (FAIR) 團隊致力於通過開放研究推動人工智能的發展,以造福所有人,從而進一步理解新的和現有的研究領域。
作為這一使命的一部分,Meta 提供某些研究材料供非商業研究使用。Meta 致力於促進對這些研究材料的安全和負責任的使用。
禁止使用情況
你同意不會使用或允許他人使用研究材料來:
- 違反法律或他人權利,包括:
- 從事、促進、生成、促成、鼓勵、策劃、煽動或推動非法或違法活動或內容,例如:
- 暴力或恐怖主義
- 對兒童的剝削或傷害,包括徵集、創建、獲取或傳播兒童剝削內容,或未報告兒童性虐待材料
- 人口販賣、剝削和性暴力
- 向未成年人非法分發信息或材料,包括淫穢材料,或未對此類信息或材料採用法律要求的年齡限制
- 性引誘
- 任何其他犯罪活動
- 從事、促進、煽動或便利對個人或群體的騷擾、虐待、威脅或欺凌
- 從事、促進、煽動或便利在就業、就業福利、信貸、住房、其他經濟利益或其他基本商品和服務的提供方面的歧視或其他非法或有害行為
- 從事未經授權或無執照的任何專業實踐,包括但不限於金融、法律、醫療/健康或相關專業實踐
- 在未獲得適用法律要求的權利和同意的情況下收集、處理、披露、生成或推斷個人的健康、人口統計或其他敏感個人或私人信息
- 從事或便利任何侵犯、盜用或以其他方式侵犯任何第三方權利的行動或生成任何內容,包括使用 FAIR 研究材料的任何技術的輸出或結果
- 創建、生成或便利創建惡意代碼、惡意軟件、計算機病毒,或做任何可能禁用、使負擔過重、干擾或損害網站或計算機系統的正常運行、完整性、操作或外觀的事情
- 從事、促進、生成、促成、鼓勵、策劃、煽動或推動非法或違法活動或內容,例如:
- 從事、促進、煽動、便利或協助策劃或開展對個人造成死亡或身體傷害風險的活動,包括使用與以下相關的研究工件:
- 軍事、戰爭、核工業或應用、間諜活動,以及受美國國務院維護的《國際武器貿易條例》(ITAR) 約束的材料或活動
- 槍支和非法武器(包括武器開發)
- 非法藥物和受管制/受控物質
- 關鍵基礎設施、運輸技術或重型機械的操作
- 自我傷害或傷害他人,包括自殺、自殘和飲食失調
- 任何旨在煽動或促進對個人的暴力、虐待或任何身體傷害的內容
- 故意欺騙或誤導他人,包括使用與以下相關的 FAIR 研究材料:
- 生成、促進或推動欺詐或創建或推廣虛假信息
- 生成、促進或推動誹謗性內容,包括創建誹謗性聲明、圖像或其他內容
- 生成、促進或進一步分發垃圾郵件
- 在未經同意、授權或合法權利的情況下冒充他人
- 聲稱 FAIR 研究材料的輸出或使用 FAIR 研究材料的技術的輸出是人類生成的
- 生成或促進虛假的在線互動,包括虛假評論和其他虛假在線互動方式
- 未能向最終用戶適當披露研究材料的任何已知危險。
請通過 此處 提交報告,報告任何違反本政策的行為或可能導致違反本政策的其他問題。
額外信息收集
屬性 | 詳情 |
---|---|
名字 | 文本輸入 |
姓氏 | 文本輸入 |
出生日期 | 日期選擇器 |
國家 | 國家選擇 |
所屬機構 | 文本輸入 |
地理位置 | IP 定位 |
通過點擊下面的“提交”,你接受許可證的條款,並確認你提供的信息將根據 Meta 隱私政策 進行收集、存儲、處理和共享。
模型指標
任務類型 | 數據集 | 指標 | 值 | 驗證狀態 |
---|---|---|---|---|
問答 | google/boolq (BoolQ) | acc | 0.776 | 未驗證 |
問答 | ybisk/piqa (PIQA) | acc | 0.775 | 未驗證 |
問答 | allenai/social_i_qa (SIQA) | acc | 0.454 | 未驗證 |
文本生成 | Rowan/hellaswag (HellaSwag) | acc | 0.567 | 未驗證 |
問答 | allenai/winogrande (WinoGrande) | acc | 0.701 | 未驗證 |
問答 | allenai/ai2_arc (ARC (Easy)) | acc | 0.765 | 未驗證 |
問答 | allenai/ai2_arc (ARC (Challenge)) | acc | 0.437 | 未驗證 |
問答 | allenai/openbookqa (OpenBookQA) | acc | 0.328 | 未驗證 |
問答 | ehovy/race (RACE) | acc | 0.389 | 未驗證 |
問答 | cais/mmlu (MMLU) | acc | 0.376 | 未驗證 |
文本生成 | google-research-datasets/nq_open (Natural Questions) | exact_match | 0.156 | 未驗證 |
問答 | sentence-transformers/trivia-qa (TriviaQA) | acc | 0.529 | 未驗證 |
文本生成 | openai/gsm8k (GSM8K) | exact_match | 0.121 | 未驗證 |
問答 | allenai/math_qa (MathQA) | acc | 0.276 | 未驗證 |
問答 | rajpurkar/squad_v2 (SQuAD2.0) | exact | 0.164 | 未驗證 |
文本分類 | toxigen/toxigen-data (ToxiGen) | acc | 0.428 | 未驗證 |
文本生成 | openai_humaneval (HumanEval) | pass@1 | 0.134 | 未驗證 |
文本生成 | mbpp (MBPP) | pass@1 | 0.19 | 未驗證 |



