🚀 PPOエージェントでSeaquestNoFrameskip-v4をプレイ
このプロジェクトは、訓練済みのPPOエージェントモデルを使用し、stable-baselines3ライブラリを使ってSeaquestNoFrameskip-v4ゲームをプレイするものです。このモデルは、深度強化学習がAtariゲームにおいてどのように機能するかを示しています。
✨ 主な機能
- 深度強化学習のPPOアルゴリズムをベースにしており、SeaquestNoFrameskip-v4ゲームで良好な結果を得ています。
- stable-baselines3ライブラリを使用して訓練されており、このライブラリはモデルの訓練や評価に便利なツールやインターフェースを提供します。
- 訓練過程はWandBを通じて記録・監視されており、詳細な訓練レポートを確認することができます。
📦 インストール
gym==0.19
を使用する必要があります。これはAtariゲームのROMを含んでいるためです。
- このゲームで使用されるのは可能なアクションのみであるため、アクション空間は6です。
💻 使用例
基本的な使用法
エージェントと環境の相互作用を観察するには、次のようにします。
import os
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from huggingface_sb3 import load_from_hub, push_to_hub
checkpoint = load_from_hub("ThomasSimonini/ppo-SeaquestNoFrameskip-v4", "ppo-SeaquestNoFrameskip-v4.zip")
custom_objects = {
"learning_rate": 0.0,
"lr_schedule": lambda _: 0.0,
"clip_range": lambda _: 0.0,
}
model= PPO.load(checkpoint, custom_objects=custom_objects)
env = make_atari_env('SeaquestNoFrameskip-v4', n_envs=1)
env = VecFrameStack(env, n_stack=4)
obs = env.reset()
while True:
action, _states = model.predict(obs)
obs, rewards, dones, info = env.step(action)
env.render()
高度な使用法
モデルを訓練するコード例は次の通りです。
import wandb
import gym
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecVideoRecorder
from stable_baselines3.common.callbacks import CheckpointCallback
from wandb.integration.sb3 import WandbCallback
from huggingface_sb3 import load_from_hub, push_to_hub
config = {
"env_name": "SeaquestNoFrameskip-v4",
"num_envs": 8,
"total_timesteps": int(10e6),
"seed": 2862830927,
}
run = wandb.init(
project="HFxSB3",
config = config,
sync_tensorboard = True,
monitor_gym = True,
save_code = True,
)
env = make_atari_env(config["env_name"], n_envs=config["num_envs"], seed=config["seed"])
print("ENV ACTION SPACE: ", env.action_space.n)
env = VecFrameStack(env, n_stack=4)
env = VecVideoRecorder(env, "videos", record_video_trigger=lambda x: x % 100000 == 0, video_length=2000)
model = PPO(policy = "CnnPolicy",
env = env,
batch_size = 256,
clip_range = 0.1,
ent_coef = 0.01,
gae_lambda = 0.9,
gamma = 0.99,
learning_rate = 2.5e-4,
max_grad_norm = 0.5,
n_epochs = 4,
n_steps = 128,
vf_coef = 0.5,
tensorboard_log = f"runs",
verbose=1,
)
model.learn(
total_timesteps = config["total_timesteps"],
callback = [
WandbCallback(
gradient_save_freq = 1000,
model_save_path = f"models/{run.id}",
),
CheckpointCallback(save_freq=10000, save_path='./seaquest',
name_prefix=config["env_name"]),
]
)
model.save("ppo-SeaquestNoFrameskip-v4.zip")
push_to_hub(repo_id="ThomasSimonini/ppo-SeaquestNoFrameskip-v4",
filename="ppo-SeaquestNoFrameskip-v4.zip",
commit_message="Added Seaquest trained agent")
📚 ドキュメント
評価結果
平均報酬:1820.00 +/- 20.0
訓練レポート:https://wandb.ai/simoninithomas/HFxSB3/reports/Atari-HFxSB3-Benchmark--VmlldzoxNjI3NTIy
情報テーブル
属性 |
詳細 |
モデルタイプ |
PPOエージェント |
訓練データ |
SeaquestNoFrameskip-v4ゲーム環境 |
注意事項
⚠️ 重要な注意
gym==0.19
を使用する必要があります。これはAtariゲームのROMを含んでいるためです。
💡 使用上の提案
ColabでPython 3.7を使用している場合、このエージェントはPython 3.8で訓練されているため、Pickleエラーを避けるためにcustom_objects
を設定する必要があります。