模型概述
模型特點
模型能力
使用案例
🚀 Llama-3.1-8B-tldr
Llama-3.1-8B-tldr 是基於 meta-llama/Llama-3.1-8B 在 trl-lib/tldr 數據集上微調得到的模型,可高效總結 Reddit 帖子,在測試集上取得了 0.366 的 BERTScore。
🚀 快速開始
本模型可使用 vLLM 進行高效部署,示例如下。
運行以下命令啟動 vLLM 服務器:
vllm serve RedHatAI/Llama-3.1-8B-tldr
服務器啟動後,你可以使用 OpenAI API 查詢模型:
from openai import OpenAI
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
post="""
SUBREDDIT: r/AI
TITLE: Training sparse LLMs
POST: Now you can use the llm-compressor integration to axolotl to train sparse LLMs!
It's super easy to use. See the example in https://huggingface.co/RedHatAI/Sparse-Llama-3.1-8B-tldr-2of4.
And there's more. You can run 2:4 sparse models on vLLM and get significant speedupts on Hopper GPUs!
"""
prompt = f"Give a TL;DR of the following Reddit post.\n<|user|>{post}\nTL;DR:\n<|assistant|>\n"
completion = client.completions.create(
model="RedHatAI/Llama-3.1-8B-tldr",
prompt=prompt,
max_tokens=256,
)
print("Completion result:", completion)
✨ 主要特性
- 高效總結:能夠對 Reddit 帖子進行高效總結,在測試集上取得了 0.366 的 BERTScore。
- 易於部署:可以使用 vLLM 進行高效部署。
📦 安裝指南
本模型可使用 vLLM 進行部署,運行以下命令啟動 vLLM 服務器:
vllm serve RedHatAI/Llama-3.1-8B-tldr
💻 使用示例
基礎用法
from openai import OpenAI
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
post="""
SUBREDDIT: r/AI
TITLE: Training sparse LLMs
POST: Now you can use the llm-compressor integration to axolotl to train sparse LLMs!
It's super easy to use. See the example in https://huggingface.co/RedHatAI/Sparse-Llama-3.1-8B-tldr-2of4.
And there's more. You can run 2:4 sparse models on vLLM and get significant speedupts on Hopper GPUs!
"""
prompt = f"Give a TL;DR of the following Reddit post.\n<|user|>{post}\nTL;DR:\n<|assistant|>\n"
completion = client.completions.create(
model="RedHatAI/Llama-3.1-8B-tldr",
prompt=prompt,
max_tokens=256,
)
print("Completion result:", completion)
📚 詳細文檔
模型概述
- 模型架構:LlamaForCausalLM
- 輸入:文本
- 輸出:文本
- 發佈日期:2025 年 5 月 29 日
- 版本:1.0
- 預期用例:該模型經過微調,可按照 Reddit 帖子的風格總結文本。
- 適用範圍外:禁止以任何違反適用法律法規(包括貿易合規法律)的方式使用。禁止以《可接受使用政策》和《Llama 3.1 社區許可證》禁止的任何其他方式使用。
- 模型開發者:Red Hat (Neural Magic)
本模型是 meta-llama/Llama-3.1-8B 在 trl-lib/tldr 數據集上的微調版本。該模型在 trl-lib/tldr 測試集上的 BERTScore 為 0.366。
訓練
查看 axolotl 配置
axolotl 版本:0.10.0.dev0
base_model: meta-llama/Llama-3.1-8B
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: trl-lib/tldr
type:
system_prompt: "Give a TL;DR of the following Reddit post."
field_system: system
field_instruction: prompt
field_output: completion
format: "<|user|>\n{instruction}\n<|assistant|>\n"
no_input_format: "<|user|>\n{instruction}\n<|assistant|>\n"
split: train
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: true
torch.compile: true
gradient_accumulation_steps: 1
micro_batch_size: 4
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 1e-5
max_grad_norm: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
train_on_inputs: false
bf16: auto
fp16:
tf32: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.05
evals_per_epoch: 4
val_set_size: 0.05
save_strategy: "best"
save_total_limit: 1
metric_for_best_model: "loss"
debug:
deepspeed:
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
seed: 0
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
訓練超參數
訓練期間使用了以下超參數:
- 學習率:1e-05
- 訓練批次大小:4
- 評估批次大小:4
- 隨機種子:0
- 分佈式類型:多 GPU
- 設備數量:8
- 總訓練批次大小:32
- 總評估批次大小:32
- 優化器:使用 adamw_bnb_8bit,β=(0.9, 0.999),ε=1e-08,無額外優化器參數
- 學習率調度器類型:餘弦
- 學習率調度器熱身步數:49
- 訓練輪數:3.0
訓練結果
訓練損失 | 輪數 | 步數 | 驗證損失 |
---|---|---|---|
2.2572 | 0.0031 | 1 | 2.2288 |
1.7865 | 0.2508 | 82 | 1.7680 |
1.7257 | 0.5015 | 164 | 1.7567 |
1.7343 | 0.7523 | 246 | 1.7489 |
1.7688 | 1.0031 | 328 | 1.7441 |
1.6822 | 1.2538 | 410 | 1.7493 |
1.6085 | 1.5046 | 492 | 1.7480 |
1.6627 | 1.7554 | 574 | 1.7444 |
1.729 | 2.0061 | 656 | 1.7426 |
1.6149 | 2.2569 | 738 | 1.7540 |
1.6002 | 2.5076 | 820 | 1.7537 |
1.6573 | 2.7584 | 902 | 1.7526 |
框架版本
- Transformers 4.51.3
- Pytorch 2.7.0+cu126
- Datasets 3.5.1
- Tokenizers 0.21.1
評估
該模型使用 lm-evaluation-harness(tldr 分支)的 Neural Magic 分支,在 trl-lib/tldr 數據集的測試分割上進行了評估。可以使用以下命令重現這些結果:
lm_eval --model vllm --model_args "pretrained=RedHatAI/Llama-3.1-8B-tldr,dtype=auto,add_bos_token=True" --batch-size auto --tasks tldr
指標 | Llama-3.1-8B-Instruct | Llama-3.1-8B-tldr | Sparse-Llama-3.1-8B-tldr-2of4(本模型) |
---|---|---|---|
BERTScore | -0.230 | 0.366 | 0.366 |
ROUGE-1 | 0.059 | 0.362 | 0.357 |
ROUGE-2 | 0.018 | 0.144 | 0.141 |
ROUGE-Lsum | 0.051 | 0.306 | 0.304 |
推理性能
我們使用 trl-lib/tldr 數據集訓練集的前 1000 個樣本評估了該模型的推理性能。基準測試使用 vLLM 版本 0.9.0.1
和 GuideLLM 版本 0.2.1
進行。
下圖展示了不同請求速率下的每個請求的平均端到端延遲。結果顯示了本模型以及兩個變體的情況:
重現說明
要重現基準測試:
- 生成包含前 1000 個訓練樣本的 JSON 文件:
from datasets import load_dataset
ds = load_dataset("trl-lib/tldr", split="train").take(1000)
ds.to_json("tldr_1000.json")
- 使用目標模型啟動 vLLM 服務器:
vllm serve RedHatAI/Llama-3.1-8B-tldr
- 使用 GuideLLM 運行基準測試:
GUIDELLM__OPENAI__MAX_OUTPUT_TOKENS=128 guidellm benchmark --target "http://localhost:8000" --rate-type sweep --data tldr_1000.json
平均輸出長度約為每個樣本 30 個標記。我們將生成限制為 128 個標記,以減少罕見的、異常冗長的完成對性能的影響。
🔧 技術細節
本模型基於 LlamaForCausalLM 架構,在 trl-lib/tldr 數據集上進行微調。訓練過程中使用了 axolotl 框架,具體配置和超參數可參考上述訓練部分的詳細文檔。
📄 許可證
本模型使用 llama3.1 許可證。



