モデル概要
モデル特徴
モデル能力
使用事例
🚀 Llama-3.1-8B-tldr
このモデルは、Reddit投稿のスタイルでテキストを要約するようにファインチューニングされたモデルです。meta-llama/Llama-3.1-8B を trl-lib/tldr データセットでファインチューニングしたもので、trl-lib/tldrのテストセットで0.366のBERTScoreを達成しています。
🚀 クイックスタート
このモデルは、vLLM を使用して効率的にデプロイできます。以下の手順でモデルを使用することができます。
サーバーの起動
次のコマンドを実行して、vLLMサーバーを起動します。
vllm serve RedHatAI/Llama-3.1-8B-tldr
モデルのクエリ
サーバーが起動したら、OpenAI APIを使用してモデルにクエリを投げることができます。
from openai import OpenAI
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
post="""
SUBREDDIT: r/AI
TITLE: Training sparse LLMs
POST: Now you can use the llm-compressor integration to axolotl to train sparse LLMs!
It's super easy to use. See the example in https://huggingface.co/RedHatAI/Sparse-Llama-3.1-8B-tldr-2of4.
And there's more. You can run 2:4 sparse models on vLLM and get significant speedupts on Hopper GPUs!
"""
prompt = f"Give a TL;DR of the following Reddit post.\n<|user|>{post}\nTL;DR:\n<|assistant|>\n"
completion = client.completions.create(
model="RedHatAI/Llama-3.1-8B-tldr",
prompt=prompt,
max_tokens=256,
)
print("Completion result:", completion)
✨ 主な機能
- モデルアーキテクチャ: LlamaForCausalLM
- 入力: テキスト
- 出力: テキスト
- リリース日: 2025年5月29日
- バージョン: 1.0
- 想定使用ケース: Reddit投稿のスタイルでテキストを要約するためにファインチューニングされています。
- 適用外: 適用される法律や規制(貿易コンプライアンス法を含む)に違反する方法での使用。許容使用ポリシーおよびLlama 3.1コミュニティライセンスによって禁止されている他の方法での使用。
- モデル開発者: Red Hat (Neural Magic)
📦 インストール
このモデルを使用するには、vLLMをインストールする必要があります。vLLMのインストール方法については、公式ドキュメント を参照してください。
💻 使用例
基本的な使用法
from openai import OpenAI
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"
client = OpenAI(
api_key=openai_api_key,
base_url=openai_api_base,
)
post="""
SUBREDDIT: r/AI
TITLE: Training sparse LLMs
POST: Now you can use the llm-compressor integration to axolotl to train sparse LLMs!
It's super easy to use. See the example in https://huggingface.co/RedHatAI/Sparse-Llama-3.1-8B-tldr-2of4.
And there's more. You can run 2:4 sparse models on vLLM and get significant speedupts on Hopper GPUs!
"""
prompt = f"Give a TL;DR of the following Reddit post.\n<|user|>{post}\nTL;DR:\n<|assistant|>\n"
completion = client.completions.create(
model="RedHatAI/Llama-3.1-8B-tldr",
prompt=prompt,
max_tokens=256,
)
print("Completion result:", completion)
📚 ドキュメント
トレーニング
axolotl設定を表示
axolotlバージョン: 0.10.0.dev0
base_model: meta-llama/Llama-3.1-8B
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: trl-lib/tldr
type:
system_prompt: "Give a TL;DR of the following Reddit post."
field_system: system
field_instruction: prompt
field_output: completion
format: "<|user|>\n{instruction}\n<|assistant|>\n"
no_input_format: "<|user|>\n{instruction}\n<|assistant|>\n"
split: train
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: true
torch.compile: true
gradient_accumulation_steps: 1
micro_batch_size: 4
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 1e-5
max_grad_norm: 1
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
train_on_inputs: false
bf16: auto
fp16:
tf32: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_ratio: 0.05
evals_per_epoch: 4
val_set_size: 0.05
save_strategy: "best"
save_total_limit: 1
metric_for_best_model: "loss"
debug:
deepspeed:
weight_decay: 0.0
special_tokens:
pad_token: "<|end_of_text|>"
seed: 0
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_layer_norm: true
liger_fused_linear_cross_entropy: true
トレーニングハイパーパラメータ
トレーニング中に使用されたハイパーパラメータは以下の通りです。
- learning_rate: 1e-05
- train_batch_size: 4
- eval_batch_size: 4
- seed: 0
- distributed_type: multi-GPU
- num_devices: 8
- total_train_batch_size: 32
- total_eval_batch_size: 32
- optimizer: betas=(0.9,0.999) およびepsilon=1e-08のadamw_bnb_8bitを使用。追加のオプティマイザ引数はありません。
- lr_scheduler_type: cosine
- lr_scheduler_warmup_steps: 49
- num_epochs: 3.0
トレーニング結果
トレーニング損失 | エポック | ステップ | 検証損失 |
---|---|---|---|
2.2572 | 0.0031 | 1 | 2.2288 |
1.7865 | 0.2508 | 82 | 1.7680 |
1.7257 | 0.5015 | 164 | 1.7567 |
1.7343 | 0.7523 | 246 | 1.7489 |
1.7688 | 1.0031 | 328 | 1.7441 |
1.6822 | 1.2538 | 410 | 1.7493 |
1.6085 | 1.5046 | 492 | 1.7480 |
1.6627 | 1.7554 | 574 | 1.7444 |
1.729 | 2.0061 | 656 | 1.7426 |
1.6149 | 2.2569 | 738 | 1.7540 |
1.6002 | 2.5076 | 820 | 1.7537 |
1.6573 | 2.7584 | 902 | 1.7526 |
フレームワークバージョン
- Transformers 4.51.3
- Pytorch 2.7.0+cu126
- Datasets 3.5.1
- Tokenizers 0.21.1
評価
このモデルは、trl-lib/tldr のテストスプリットで、lm-evaluation-harness (tldrブランチ) のNeural Magicフォークを使用して評価されました。以下のコマンドを使用することで、これらの結果を再現することができます。
lm_eval --model vllm --model_args "pretrained=RedHatAI/Llama-3.1-8B-tldr,dtype=auto,add_bos_token=True" --batch-size auto --tasks tldr
メトリック | Llama-3.1-8B-Instruct | Llama-3.1-8B-tldr | Sparse-Llama-3.1-8B-tldr-2of4 (このモデル) |
---|---|---|---|
BERTScore | -0.230 | 0.366 | 0.366 |
ROUGE-1 | 0.059 | 0.362 | 0.357 |
ROUGE-2 | 0.018 | 0.144 | 0.141 |
ROUGE-Lsum | 0.051 | 0.306 | 0.304 |
推論パフォーマンス
このモデルの推論パフォーマンスは、trl-lib/tldr データセットのトレーニングセットの最初の1,000サンプルを使用して評価されました。ベンチマークは、vLLM バージョン 0.9.0.1
および GuideLLM バージョン 0.2.1
で実行されました。
下の図は、さまざまなリクエストレートにおける リクエストあたりの平均エンドツーエンドレイテンシー を示しています。このモデルと2つのバリアントの結果が表示されています。
- Dense-quantized: Llama-3.1-8B-tldr-FP8-dynamic
- Sparse-quantized: Sparse-Llama-3.1-8B-tldr-2of4-FP8-dynamic
再現手順
ベンチマークを再現するには、以下の手順を実行します。
- 最初の1,000個のトレーニングサンプルを含むJSONファイルを生成します。
from datasets import load_dataset
ds = load_dataset("trl-lib/tldr", split="train").take(1000)
ds.to_json("tldr_1000.json")
- ターゲットモデルを使用してvLLMサーバーを起動します。
vllm serve RedHatAI/Llama-3.1-8B-tldr
- GuideLLMでベンチマークを実行します。
GUIDELLM__OPENAI__MAX_OUTPUT_TOKENS=128 guidellm benchmark --target "http://localhost:8000" --rate-type sweep --data tldr_1000.json
平均出力長はサンプルあたり約30トークンです。稀な異常に冗長な完了によるパフォーマンスの歪みを減らすために、生成を128トークンで制限しました。
📄 ライセンス
このモデルは、llama3.1ライセンスの下で提供されています。
その他の情報
属性 | 詳情 |
---|---|
モデルタイプ | LlamaForCausalLM |
学習データ | trl-lib/tldr |



