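🚀 PPO Agent Playing LunarLander-v2
This repository contains a trained PPO agent that plays LunarLander-v2, built with the stable-baselines3 library. In the evaluation reported under Model Metrics, it reaches a mean reward of 283.49 ± 13.74.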
🏷️ Tags
- LunarLander-v2
- Deep reinforcement learning
- Reinforcement learning
- stable-baselines3
📊 Model Metrics
| Property | Details |
|---|---|
| Model name | PPO |
| Task type | Reinforcement learning |
| Dataset | LunarLander-v2 |
| Mean reward | 283.49 ± 13.74 |
🚀 Quick Start
💻 Usage Examples
Basic usage
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy

# Download the checkpoint from the Hugging Face Hub and load the agent
checkpoint = load_from_hub("araffin/ppo-LunarLander-v2", "ppo-LunarLander-v2.zip")
model = PPO.load(checkpoint)

# Evaluate the agent on a single vectorized environment
env = make_vec_env("LunarLander-v2", n_envs=1)

print("Evaluating model")
mean_reward, std_reward = evaluate_policy(
    model,
    env,
    n_eval_episodes=20,
    deterministic=True,
)
print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")

# Watch the agent play; the vectorized env resets itself when an episode ends
obs = env.reset()
try:
    while True:
        action, _states = model.predict(obs, deterministic=True)
        obs, rewards, dones, info = env.step(action)
        env.render()
except KeyboardInterrupt:
    # Press Ctrl+C to stop the rendering loop
    pass
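The `Mean reward = ... +/- ...` line printed above is a mean and standard deviation over per-episode returns. As a minimal sketch of that arithmetic, assuming the standard deviation is the population one (NumPy's default for `np.std`), the same summary can be computed with the standard library alone. The helper name `summarize_returns` and the episode returns are hypothetical and made up for illustration; they are not results from this model.

```python
import statistics


def summarize_returns(episode_returns):
    """Return (mean, std) of per-episode returns.

    Uses the population standard deviation (divide by N), which is an
    assumption chosen to match NumPy's default np.std behaviour.
    """
    mean = statistics.fmean(episode_returns)
    std = statistics.pstdev(episode_returns)
    return mean, std


# Hypothetical per-episode returns, invented for this example only.
example_returns = [270.0, 290.0, 280.0, 300.0, 260.0]

mean_reward, std_reward = summarize_returns(example_returns)
print(f"Mean reward = {mean_reward:.2f} +/- {std_reward:.2f}")
```

With only 20 evaluation episodes the spread can shift noticeably between runs, so the reported figure is best read as an estimate rather than an exact value.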
Advanced usage (training code)
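from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import EvalCallback

# Train on 16 parallel environments
env_id = "LunarLander-v2"
n_envs = 16
env = make_vec_env(env_id, n_envs=n_envs)

# Separate environments used only for periodic evaluation
eval_envs = make_vec_env(env_id, n_envs=5)

# Evaluate roughly every 1e5 timesteps; the callback counts calls to
# env.step(), and each call advances n_envs timesteps
eval_freq = int(1e5)
eval_freq = max(eval_freq // n_envs, 1)

# Save the best-scoring model to ./logs/best_model.zip
eval_callback = EvalCallback(
    eval_envs,
    best_model_save_path="./logs/",
    eval_freq=eval_freq,
    n_eval_episodes=10,
)

model = PPO(
    "MlpPolicy",
    env,
    n_steps=1024,
    batch_size=64,
    gae_lambda=0.98,
    gamma=0.999,
    n_epochs=4,
    ent_coef=0.01,
    verbose=1,
)

# Train for up to 5e6 timesteps; Ctrl+C stops training early
try:
    model.learn(total_timesteps=int(5e6), callback=eval_callback)
except KeyboardInterrupt:
    pass

# Reload the best model saved by the evaluation callback
model = PPO.load("logs/best_model.zip")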
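To make the training hyperparameters above more concrete, the following sketch works out how they combine into a training budget. It uses only the values from the training code; the derived variable names are hypothetical. It assumes that each PPO update consumes one rollout of `n_envs * n_steps` transitions, that training continues until at least `total_timesteps` have been collected, and that the evaluation callback fires every `eval_freq` calls to `env.step()`.

```python
import math

# Values copied from the training code above
n_envs = 16
n_steps = 1024
batch_size = 64
n_epochs = 4
total_timesteps = int(5e6)
eval_every_timesteps = int(1e5)

# Transitions collected before each policy update
rollout_size = n_envs * n_steps

# Minibatches per pass over the rollout, and gradient steps per update
minibatches_per_epoch = rollout_size // batch_size
gradient_steps_per_update = minibatches_per_epoch * n_epochs

# Rollouts needed to reach the timestep budget (assumes the last one is completed)
n_rollouts = math.ceil(total_timesteps / rollout_size)

# Callback frequency in env.step() calls, and the resulting number of evaluations
eval_freq = max(eval_every_timesteps // n_envs, 1)
n_evaluations = (n_rollouts * n_steps) // eval_freq

print(f"rollout size: {rollout_size}")
print(f"gradient steps per update: {gradient_steps_per_update}")
print(f"rollouts: {n_rollouts}")
print(f"eval_freq (step calls): {eval_freq}")
print(f"evaluations: {n_evaluations}")
```

Under these assumptions, dividing `eval_freq` by `n_envs` keeps evaluations about 1e5 timesteps apart regardless of how many parallel environments are used.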