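🚀 PPO Agent Playing LunarLander-v2
This is a trained PPO agent that plays LunarLander-v2, built with the stable-baselines3 library. Its mean reward of 283.49 ± 13.74 is well above the 200 points at which a LunarLander episode is considered solved.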
🏷️ Tags
- LunarLander-v2
- Deep reinforcement learning
- Reinforcement learning
- stable-baselines3
📊 Model metrics

| Attribute | Details |
| --- | --- |
| Model | PPO |
| Task | Reinforcement learning |
| Environment (dataset) | LunarLander-v2 |
| Mean reward (± std) | 283.49 ± 13.74 |
🚀 Quick start
💻 Usage examples
Basic usage
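```python
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy

# Download the checkpoint from the Hugging Face Hub (repo id, file name)
checkpoint = load_from_hub("araffin/ppo-LunarLander-v2", "ppo-LunarLander-v2.zip")
# Load the trained PPO model
model = PPO.load(checkpoint)

# LunarLander needs Box2D: pip install "gymnasium[box2d]"
# Gymnasium 1.0+ registers LunarLander-v3 instead of v2; use that id if v2 is not found
env = make_vec_env("LunarLander-v2", n_envs=1)

# Evaluate the greedy policy over 20 episodes
print("Evaluating model")
mean_reward, std_reward = evaluate_policy(
    model,
    env,
    n_eval_episodes=20,
    deterministic=True,
)
print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")

# Watch the agent play (stable-baselines3 >= 2.0 / Gymnasium API):
# with render_mode="human" the environment draws a frame on every step
render_env = make_vec_env("LunarLander-v2", n_envs=1, env_kwargs=dict(render_mode="human"))
obs = render_env.reset()  # VecEnv.reset() returns only the observations
try:
    while True:
        action, _states = model.predict(obs, deterministic=True)
        # VecEnv.step() returns 4 values and resets finished episodes automatically
        obs, rewards, dones, infos = render_env.step(action)
except KeyboardInterrupt:
    pass
finally:
    render_env.close()
```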
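`evaluate_policy` reports the mean and the standard deviation (NumPy's default population form) of the per-episode returns, and the metrics table presents the model's score in the same mean ± std form. The sketch below reproduces that summary with the standard library and checks returns against the 200-point level at which a LunarLander episode counts as solved. `summarize_returns`, `SOLVED_THRESHOLD`, and the sample returns are hypothetical illustrations, not results from this model.

```python
import statistics

# An episode scoring at least 200 points counts as a solution in LunarLander
SOLVED_THRESHOLD = 200.0


def summarize_returns(episode_returns):
    """Return (mean, std) of per-episode returns, in the form evaluate_policy reports."""
    mean = statistics.fmean(episode_returns)
    # Population standard deviation, matching np.std's default (ddof=0)
    std = statistics.pstdev(episode_returns)
    return mean, std


if __name__ == "__main__":
    # Hypothetical per-episode returns, for illustration only
    returns = [270.0, 290.0, 280.0, 300.0]
    mean, std = summarize_returns(returns)
    print(f"Mean reward = {mean:.2f} +/- {std:.2f}")
    print("All episodes solved:", all(r >= SOLVED_THRESHOLD for r in returns))
```

Using the population standard deviation keeps the numbers comparable with what `evaluate_policy` prints; a sample standard deviation (`statistics.stdev`) would be slightly larger for small episode counts.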
Advanced usage (training code)
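```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import EvalCallback

env_id = "LunarLander-v2"  # Gymnasium 1.0+ registers LunarLander-v3 instead
n_envs = 16

# 16 parallel training environments and 5 separate evaluation environments
env = make_vec_env(env_id, n_envs=n_envs)
eval_envs = make_vec_env(env_id, n_envs=5)

# Evaluate roughly every 100k timesteps. EvalCallback counts env.step() calls,
# and each call advances all n_envs environments, so divide by n_envs.
eval_freq = int(1e5)
eval_freq = max(eval_freq // n_envs, 1)

# Run 10 evaluation episodes each time and save the best model to ./logs/best_model.zip
eval_callback = EvalCallback(
    eval_envs,
    best_model_save_path="./logs/",
    eval_freq=eval_freq,
    n_eval_episodes=10,
)

model = PPO(
    "MlpPolicy",
    env,
    n_steps=1024,     # steps collected per environment in each rollout
    batch_size=64,    # minibatch size for gradient updates
    gae_lambda=0.98,  # GAE lambda (bias/variance trade-off of advantages)
    gamma=0.999,      # discount factor
    n_epochs=4,       # passes over each rollout buffer
    ent_coef=0.01,    # entropy bonus coefficient (encourages exploration)
    verbose=1,
)

# Train for up to 5M timesteps; Ctrl+C stops early and keeps the best model saved so far
try:
    model.learn(total_timesteps=int(5e6), callback=eval_callback)
except KeyboardInterrupt:
    pass

# Reload the best checkpoint found by EvalCallback
model = PPO.load("logs/best_model.zip")
```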
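The training settings determine how much data each PPO update sees and how often the model is evaluated. The arithmetic sketch below assumes stable-baselines3's usual on-policy loop (collect `n_steps` per environment, run `n_epochs` passes over the buffer in minibatches of `batch_size`, repeat until `total_timesteps` is reached) and an `EvalCallback` that fires every `eval_freq` vectorized steps. `training_schedule` is a hypothetical helper, and the sketch ignores early interruption.

```python
import math


def training_schedule(n_envs, n_steps, batch_size, n_epochs, total_timesteps, eval_every):
    """Rollout and evaluation counts for an on-policy run (hypothetical helper)."""
    rollout_size = n_envs * n_steps  # transitions collected per rollout
    # Each epoch walks the whole buffer in minibatches of batch_size
    minibatches_per_epoch = math.ceil(rollout_size / batch_size)
    gradient_steps_per_rollout = minibatches_per_epoch * n_epochs
    # Full rollouts are collected until total_timesteps is reached or exceeded
    n_rollouts = -(-total_timesteps // rollout_size)  # integer ceiling division
    # EvalCallback counts vectorized env.step() calls, each advancing n_envs envs
    eval_freq = max(eval_every // n_envs, 1)
    n_evaluations = (n_rollouts * n_steps) // eval_freq
    return {
        "rollout_size": rollout_size,
        "minibatches_per_epoch": minibatches_per_epoch,
        "gradient_steps_per_rollout": gradient_steps_per_rollout,
        "n_rollouts": n_rollouts,
        "timesteps_collected": n_rollouts * rollout_size,
        "eval_freq": eval_freq,
        "n_evaluations": n_evaluations,
    }


# Settings from the training script
schedule = training_schedule(
    n_envs=16, n_steps=1024, batch_size=64, n_epochs=4,
    total_timesteps=5_000_000, eval_every=100_000,
)
for name, value in schedule.items():
    print(f"{name}: {value}")
```

With these settings each rollout holds 16,384 transitions, giving 256 minibatches per epoch and 1,024 gradient steps per rollout. Training runs 306 rollouts (5,013,504 timesteps, slightly past 5M because the last rollout is always completed) and evaluates every 6,250 vectorized steps, i.e. 50 times.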
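The `gamma` and `gae_lambda` arguments in the training script control how PPO turns rewards into advantage estimates through Generalized Advantage Estimation (GAE): $\delta_t = r_t + \gamma V(s_{t+1})(1 - d_t) - V(s_t)$ and $A_t = \delta_t + \gamma\lambda(1 - d_t)A_{t+1}$. Below is a minimal pure-Python sketch of that recursion for a single environment. It is not stable-baselines3's implementation (which works on batched NumPy arrays and handles time-limit truncation separately); `compute_gae` and its `dones` convention (1.0 if the episode ended after step t) are assumptions of this sketch.

```python
def compute_gae(rewards, values, dones, last_value, gamma=0.999, gae_lambda=0.98):
    """Generalized Advantage Estimation for one environment (illustrative sketch).

    rewards[t], values[t]: reward received and value estimate V(s_t) at step t
    dones[t]: 1.0 if the episode ended after step t, else 0.0
    last_value: V(s_T) for the state after the final step, used for bootstrapping
    """
    n = len(rewards)
    advantages = [0.0] * n
    gae = 0.0
    for t in reversed(range(n)):
        next_value = last_value if t == n - 1 else values[t + 1]
        next_non_terminal = 1.0 - dones[t]  # do not bootstrap across episode ends
        # TD error: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * next_value * next_non_terminal - values[t]
        # A_t = delta_t + gamma * lambda * A_{t+1}
        gae = delta + gamma * gae_lambda * next_non_terminal * gae
        advantages[t] = gae
    # Value targets ("returns") used to train the critic
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns


adv, ret = compute_gae(
    rewards=[1.0, 1.0, 1.0], values=[0.0, 0.0, 0.0],
    dones=[0.0, 0.0, 1.0], last_value=0.0, gamma=1.0, gae_lambda=1.0,
)
print(adv, ret)
```

Setting `gae_lambda=0` reduces each advantage to the one-step TD error, while `gae_lambda=1` gives the full discounted return minus the value baseline. With the card's settings the per-step decay factor is γλ = 0.999 × 0.98 ≈ 0.979.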