模型简介
模型特点
模型能力
使用案例
🚀 Llama-3.1-8B-tldr
Llama-3.1-8B-tldr 是基于 meta-llama/Llama-3.1-8B 在 trl-lib/tldr 数据集上微调得到的模型,可高效总结 Reddit 帖子,在测试集上取得了 0.366 的 BERTScore。
🚀 快速开始
本模型可使用 vLLM 进行高效部署,示例如下。
运行以下命令启动 vLLM 服务器:
vllm serve RedHatAI/Llama-3.1-8B-tldr
服务器启动后,你可以使用 OpenAI API 查询模型:
from openai import OpenAI
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
post="""
SUBREDDIT: r/AI
TITLE: Training sparse LLMs
POST: Now you can use the llm-compressor integration to axolotl to train sparse LLMs!
It's super easy to use. See the example in https://huggingface.co/RedHatAI/Sparse-Llama-3.1-8B-tldr-2of4.
And there's more. You can run 2:4 sparse models on vLLM and get significant speedupts on Hopper GPUs!
"""
prompt = f"Give a TL;DR of the following Reddit post.\n<|user|>{post}\nTL;DR:\n<|assistant|>\n"
completion = client.completions.create(
model="RedHatAI/Llama-3.1-8B-tldr",
prompt=prompt,
max_tokens=256,
)
print("Completion result:", completion)
✨ 主要特性
- 高效总结:能够对 Reddit 帖子进行高效总结,在测试集上取得了 0.366 的 BERTScore。
- 易于部署:可以使用 vLLM 进行高效部署。
📦 安装指南
本模型可使用 vLLM 进行部署,运行以下命令启动 vLLM 服务器:
vllm serve RedHatAI/Llama-3.1-8B-tldr
💻 使用示例
基础用法
from openai import OpenAI
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
post="""
SUBREDDIT: r/AI
TITLE: Training sparse LLMs
POST: Now you can use the llm-compressor integration to axolotl to train sparse LLMs!
It's super easy to use. See the example in https://huggingface.co/RedHatAI/Sparse-Llama-3.1-8B-tldr-2of4.
And there's more. You can run 2:4 sparse models on vLLM and get significant speedupts on Hopper GPUs!
"""
prompt = f"Give a TL;DR of the following Reddit post.\n<|user|>{post}\nTL;DR:\n<|assistant|>\n"
completion = client.completions.create(
model="RedHatAI/Llama-3.1-8B-tldr",
prompt=prompt,
max_tokens=256,
)
print("Completion result:", completion)
📚 详细文档
模型概述
- 模型架构:LlamaForCausalLM
- 输入:文本
- 输出:文本
- 发布日期:2025 年 5 月 29 日
- 版本:1.0
- 预期用例:该模型经过微调,可按照 Reddit 帖子的风格总结文本。
- 适用范围外:禁止以任何违反适用法律法规(包括贸易合规法律)的方式使用。禁止以《可接受使用政策》和《Llama 3.1 社区许可证》禁止的任何其他方式使用。
- 模型开发者:Red Hat (Neural Magic)
本模型是 meta-llama/Llama-3.1-8B 在 trl-lib/tldr 数据集上的微调版本。该模型在 trl-lib/tldr 测试集上的 BERTScore 为 0.366。
训练
查看 axolotl 配置
axolotl 版本:0.10.0.dev0
base_model: meta-llama/Llama-3.1-8B
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: trl-lib/tldr
type:
system_prompt: "Give a TL;DR of the following Reddit post."
field_system: system
field_instruction: prompt
field_output: completion
format: "<|user|>\n{instruction}\n<|assistant|>\n"
no_input_format: "<|user|>\n{instruction}\n<|assistant|>\n"
split: train
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: true
torch.compile: true
gradient_accumulation_steps: 1
micro_batch_size: 4
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 1e-5
max_grad_norm: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
train_on_inputs: false
bf16: auto
fp16:
tf32: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.05
evals_per_epoch: 4
val_set_size: 0.05
save_strategy: "best"
save_total_limit: 1
metric_for_best_model: "loss"
debug:
deepspeed:
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
seed: 0
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
训练超参数
训练期间使用了以下超参数:
- 学习率:1e-05
- 训练批次大小:4
- 评估批次大小:4
- 随机种子:0
- 分布式类型:多 GPU
- 设备数量:8
- 总训练批次大小:32
- 总评估批次大小:32
- 优化器:使用 adamw_bnb_8bit,β=(0.9, 0.999),ε=1e-08,无额外优化器参数
- 学习率调度器类型:余弦
- 学习率调度器热身步数:49
- 训练轮数:3.0
训练结果
训练损失 | 轮数 | 步数 | 验证损失 |
---|---|---|---|
2.2572 | 0.0031 | 1 | 2.2288 |
1.7865 | 0.2508 | 82 | 1.7680 |
1.7257 | 0.5015 | 164 | 1.7567 |
1.7343 | 0.7523 | 246 | 1.7489 |
1.7688 | 1.0031 | 328 | 1.7441 |
1.6822 | 1.2538 | 410 | 1.7493 |
1.6085 | 1.5046 | 492 | 1.7480 |
1.6627 | 1.7554 | 574 | 1.7444 |
1.729 | 2.0061 | 656 | 1.7426 |
1.6149 | 2.2569 | 738 | 1.7540 |
1.6002 | 2.5076 | 820 | 1.7537 |
1.6573 | 2.7584 | 902 | 1.7526 |
框架版本
- Transformers 4.51.3
- Pytorch 2.7.0+cu126
- Datasets 3.5.1
- Tokenizers 0.21.1
评估
该模型使用 lm-evaluation-harness(tldr 分支)的 Neural Magic 分支,在 trl-lib/tldr 数据集的测试分割上进行了评估。可以使用以下命令重现这些结果:
lm_eval --model vllm --model_args "pretrained=RedHatAI/Llama-3.1-8B-tldr,dtype=auto,add_bos_token=True" --batch-size auto --tasks tldr
指标 | Llama-3.1-8B-Instruct | Llama-3.1-8B-tldr | Sparse-Llama-3.1-8B-tldr-2of4(本模型) |
---|---|---|---|
BERTScore | -0.230 | 0.366 | 0.366 |
ROUGE-1 | 0.059 | 0.362 | 0.357 |
ROUGE-2 | 0.018 | 0.144 | 0.141 |
ROUGE-Lsum | 0.051 | 0.306 | 0.304 |
推理性能
我们使用 trl-lib/tldr 数据集训练集的前 1000 个样本评估了该模型的推理性能。基准测试使用 vLLM 版本 0.9.0.1
和 GuideLLM 版本 0.2.1
进行。
下图展示了不同请求速率下的每个请求的平均端到端延迟。结果显示了本模型以及两个变体的情况:
重现说明
要重现基准测试:
- 生成包含前 1000 个训练样本的 JSON 文件:
from datasets import load_dataset
ds = load_dataset("trl-lib/tldr", split="train").take(1000)
ds.to_json("tldr_1000.json")
- 使用目标模型启动 vLLM 服务器:
vllm serve RedHatAI/Llama-3.1-8B-tldr
- 使用 GuideLLM 运行基准测试:
GUIDELLM__OPENAI__MAX_OUTPUT_TOKENS=128 guidellm benchmark --target "http://localhost:8000" --rate-type sweep --data tldr_1000.json
平均输出长度约为每个样本 30 个标记。我们将生成限制为 128 个标记,以减少罕见的、异常冗长的完成对性能的影响。
🔧 技术细节
本模型基于 LlamaForCausalLM 架构,在 trl-lib/tldr 数据集上进行微调。训练过程中使用了 axolotl 框架,具体配置和超参数可参考上述训练部分的详细文档。
📄 许可证
本模型使用 llama3.1 许可证。



