From 25f1d2fe5e2b0e49669851365255f6b26f2db669 Mon Sep 17 00:00:00 2001 From: xueyingyi <3245896298@qq.com> Date: Wed, 4 Dec 2024 21:43:04 +0800 Subject: [PATCH] feature(xyy):add HPT model and test_hpt --- .../lunarlander/entry/lunarlander_dqn_example.py | 12 ++++++++---- .../lunarlander/entry/lunarlander_hpt_example.py | 15 ++++++++++----- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py b/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py index 048554c310..a90344e6c0 100644 --- a/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py +++ b/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py @@ -19,16 +19,19 @@ def main(): logging.getLogger().setLevel(logging.INFO) - cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0) + cfg = compile_config(main_config, create_cfg=create_config, + auto=True, save_cfg=task.router.node_id == 0) ding_init(cfg) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = SubprocessEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper( + gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = SubprocessEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper( + gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) @@ -38,7 +41,8 @@ def main(): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = DQN(**cfg.policy.model).to(device) - buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + buffer_ = DequeBuffer( + size=cfg.policy.other.replay_buffer.replay_buffer_size) # Pass the model into Policy policy = DQNPolicy(cfg.policy, model=model) diff --git a/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py b/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py index f5da04eaf9..1fe52cffdd 100644 --- a/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py +++ b/dizoo/box2d/lunarlander/entry/lunarlander_hpt_example.py @@ -21,16 +21,19 @@ def main(): logging.getLogger().setLevel(logging.INFO) - cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0) + cfg = compile_config(main_config, create_cfg=create_config, + auto=True, save_cfg=task.router.node_id == 0) ding_init(cfg) with task.start(async_mode=False, ctx=OnlineRLContext()): collector_env = SubprocessEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)], + env_fn=[lambda: DingEnvWrapper( + gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager ) evaluator_env = SubprocessEnvManagerV2( - env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)], + env_fn=[lambda: DingEnvWrapper( + gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager ) @@ -39,8 +42,10 @@ def main(): # Migrating models to the GPU device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # HPT introduces a Policy Stem module, which processes the input features using Cross-Attention. - model = HPT(cfg.policy.model.obs_shape, cfg.policy.model.action_shape).to(device) - buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size) + model = HPT(cfg.policy.model.obs_shape, + cfg.policy.model.action_shape).to(device) + buffer_ = DequeBuffer( + size=cfg.policy.other.replay_buffer.replay_buffer_size) # Pass the model into Policy policy = DQNPolicy(cfg.policy, model=model)