From 9fede9bde58b8b362719c92fd76198d5d7b3deda Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com>
Date: Mon, 4 Nov 2024 20:02:37 +0800
Subject: [PATCH] polish(pu): polish resume_training in entry

---
 ding/entry/serial_entry.py                        |  4 ++--
 ding/entry/serial_entry_mbrl.py                   | 13 ++++++-------
 ding/entry/serial_entry_ngu.py                    |  4 ++--
 ding/entry/serial_entry_onpolicy.py               |  5 +++--
 ding/entry/serial_entry_onpolicy_ppg.py           |  4 ++--
 .../cartpole/config/cartpole_a2c_config.py        |  1 -
 .../cartpole/config/cartpole_pg_config.py         |  1 -
 .../cartpole/config/cartpole_ppo_stdim_config.py  |  1 -
 .../cartpole/config/cartpole_ppopg_config.py      |  1 -
 .../pendulum/config/pendulum_ppo_config.py        |  1 -
 .../config/ptz_simple_spread_mappo_config.py      |  1 -
 11 files changed, 15 insertions(+), 21 deletions(-)

diff --git a/ding/entry/serial_entry.py b/ding/entry/serial_entry.py
index f1826f354c..f7c039e494 100644
--- a/ding/entry/serial_entry.py
+++ b/ding/entry/serial_entry.py
@@ -54,7 +54,7 @@ def serial_pipeline(
         auto=True,
         create_cfg=create_cfg,
         save_cfg=True,
-        renew_dir=not cfg.policy.learn.resume_training
+        renew_dir=not cfg.policy.learn.get('resume_training', False)
     )
     # Create main components: env, policy
     if env_setting is None:
@@ -94,7 +94,7 @@ def serial_pipeline(
     # ==========
     # Learner's before_run hook.
     learner.call_hook('before_run')
-    if cfg.policy.learn.resume_training:
+    if cfg.policy.learn.get('resume_training', False):
         collector.envstep = learner.collector_envstep
 
     # Accumulate plenty of data at the beginning of training.
diff --git a/ding/entry/serial_entry_mbrl.py b/ding/entry/serial_entry_mbrl.py
index 42e82a3c77..426e4ca667 100644
--- a/ding/entry/serial_entry_mbrl.py
+++ b/ding/entry/serial_entry_mbrl.py
@@ -37,7 +37,7 @@ def mbrl_entry_setup(
         auto=True,
         create_cfg=create_cfg,
         save_cfg=True,
-        renew_dir=not cfg.policy.learn.resume_training
+        renew_dir=not cfg.policy.learn.get('resume_training', False)
     )
 
     if env_setting is None:
@@ -79,8 +79,7 @@ def mbrl_entry_setup(
     )
 
     return (
-        cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger,
-        resume_training
+        cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger
     )
 
 
@@ -125,13 +124,13 @@ def serial_pipeline_dyna(
     Returns:
         - policy (:obj:`Policy`): Converged policy.
     """
-    cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger, resume_training = \
+    cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \
         mbrl_entry_setup(input_cfg, seed, env_setting, model)
 
     img_buffer = create_img_buffer(cfg, input_cfg, world_model, tb_logger)
 
     learner.call_hook('before_run')
-    if cfg.policy.learn.resume_training:
+    if cfg.policy.learn.get('resume_training', False):
         collector.envstep = learner.collector_envstep
 
     if cfg.policy.get('random_collect_size', 0) > 0:
@@ -200,11 +199,11 @@ def serial_pipeline_dream(
     Returns:
         - policy (:obj:`Policy`): Converged policy.
     """
-    cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger, resume_training = \
+    cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \
         mbrl_entry_setup(input_cfg, seed, env_setting, model)
 
     learner.call_hook('before_run')
-    if cfg.policy.learn.resume_training:
+    if cfg.policy.learn.get('resume_training', False):
         collector.envstep = learner.collector_envstep
 
     if cfg.policy.get('random_collect_size', 0) > 0:
diff --git a/ding/entry/serial_entry_ngu.py b/ding/entry/serial_entry_ngu.py
index 5d78d9bd3f..1fcce53dc7 100644
--- a/ding/entry/serial_entry_ngu.py
+++ b/ding/entry/serial_entry_ngu.py
@@ -54,7 +54,7 @@ def serial_pipeline_ngu(
         auto=True,
         create_cfg=create_cfg,
         save_cfg=True,
-        renew_dir=not cfg.policy.learn.resume_training
+        renew_dir=not cfg.policy.learn.get('resume_training', False)
     )
     # Create main components: env, policy
     if env_setting is None:
@@ -97,7 +97,7 @@ def serial_pipeline_ngu(
     # ==========
     # Learner's before_run hook.
     learner.call_hook('before_run')
-    if cfg.policy.learn.resume_training:
+    if cfg.policy.learn.get('resume_training', False):
         collector.envstep = learner.collector_envstep
 
     # Accumulate plenty of data at the beginning of training.
diff --git a/ding/entry/serial_entry_onpolicy.py b/ding/entry/serial_entry_onpolicy.py
index da4beef68f..713fbcac58 100644
--- a/ding/entry/serial_entry_onpolicy.py
+++ b/ding/entry/serial_entry_onpolicy.py
@@ -52,8 +52,9 @@ def serial_pipeline_onpolicy(
         auto=True,
         create_cfg=create_cfg,
         save_cfg=True,
-        renew_dir=not cfg.policy.learn.resume_training
+        renew_dir=not cfg.policy.learn.get('resume_training', False)
     )
+
     # Create main components: env, policy
     if env_setting is None:
         env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -88,7 +89,7 @@ def serial_pipeline_onpolicy(
     # ==========
     # Learner's before_run hook.
     learner.call_hook('before_run')
-    if cfg.policy.learn.resume_training:
+    if cfg.policy.learn.get('resume_training', False):
         collector.envstep = learner.collector_envstep
 
     while True:
diff --git a/ding/entry/serial_entry_onpolicy_ppg.py b/ding/entry/serial_entry_onpolicy_ppg.py
index 1ce40131f5..90f31891ab 100644
--- a/ding/entry/serial_entry_onpolicy_ppg.py
+++ b/ding/entry/serial_entry_onpolicy_ppg.py
@@ -52,7 +52,7 @@ def serial_pipeline_onpolicy_ppg(
         auto=True,
         create_cfg=create_cfg,
         save_cfg=True,
-        renew_dir=not cfg.policy.learn.resume_training
+        renew_dir=not cfg.policy.learn.get('resume_training', False)
     )
     # Create main components: env, policy
     if env_setting is None:
@@ -88,7 +88,7 @@ def serial_pipeline_onpolicy_ppg(
     # ==========
     # Learner's before_run hook.
     learner.call_hook('before_run')
-    if cfg.policy.learn.resume_training:
+    if cfg.policy.learn.get('resume_training', False):
         collector.envstep = learner.collector_envstep
 
     while True:
diff --git a/dizoo/classic_control/cartpole/config/cartpole_a2c_config.py b/dizoo/classic_control/cartpole/config/cartpole_a2c_config.py
index f68d8503d8..ec6f93cd6e 100644
--- a/dizoo/classic_control/cartpole/config/cartpole_a2c_config.py
+++ b/dizoo/classic_control/cartpole/config/cartpole_a2c_config.py
@@ -21,7 +21,6 @@
             learning_rate=0.001,
             # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
             entropy_weight=0.01,
-            resume_training=False,
         ),
         collect=dict(
             # (int) collect n_sample data, train model n_iteration times
diff --git a/dizoo/classic_control/cartpole/config/cartpole_pg_config.py b/dizoo/classic_control/cartpole/config/cartpole_pg_config.py
index 808b70501d..af3ee5ba04 100644
--- a/dizoo/classic_control/cartpole/config/cartpole_pg_config.py
+++ b/dizoo/classic_control/cartpole/config/cartpole_pg_config.py
@@ -18,7 +18,6 @@
             batch_size=64,
             learning_rate=0.001,
             entropy_weight=0.001,
-            resume_training=False,
         ),
         collect=dict(n_episode=80, unroll_len=1, discount_factor=0.9),
         eval=dict(evaluator=dict(eval_freq=100, ), ),
diff --git a/dizoo/classic_control/cartpole/config/cartpole_ppo_stdim_config.py b/dizoo/classic_control/cartpole/config/cartpole_ppo_stdim_config.py
index 6735f980cd..3f6060797c 100644
--- a/dizoo/classic_control/cartpole/config/cartpole_ppo_stdim_config.py
+++ b/dizoo/classic_control/cartpole/config/cartpole_ppo_stdim_config.py
@@ -37,7 +37,6 @@
             entropy_weight=0.01,
             clip_ratio=0.2,
             learner=dict(hook=dict(save_ckpt_after_iter=100)),
-            resume_training=False,
         ),
         collect=dict(
             n_sample=256,
diff --git a/dizoo/classic_control/cartpole/config/cartpole_ppopg_config.py b/dizoo/classic_control/cartpole/config/cartpole_ppopg_config.py
index 4ee8462575..623a3b5048 100644
--- a/dizoo/classic_control/cartpole/config/cartpole_ppopg_config.py
+++ b/dizoo/classic_control/cartpole/config/cartpole_ppopg_config.py
@@ -19,7 +19,6 @@
             batch_size=64,
             learning_rate=0.001,
             entropy_weight=0.001,
-            resume_training=False,
         ),
         collect=dict(n_episode=80, unroll_len=1, discount_factor=0.9, collector=dict(get_train_sample=True)),
         eval=dict(evaluator=dict(eval_freq=100, ), ),
diff --git a/dizoo/classic_control/pendulum/config/pendulum_ppo_config.py b/dizoo/classic_control/pendulum/config/pendulum_ppo_config.py
index 6c58e78e64..151455aec1 100644
--- a/dizoo/classic_control/pendulum/config/pendulum_ppo_config.py
+++ b/dizoo/classic_control/pendulum/config/pendulum_ppo_config.py
@@ -35,7 +35,6 @@
             adv_norm=True,
             value_norm=True,
             ignore_done=True,
-            resume_training=False,
         ),
         collect=dict(
             n_sample=5000,
diff --git a/dizoo/petting_zoo/config/ptz_simple_spread_mappo_config.py b/dizoo/petting_zoo/config/ptz_simple_spread_mappo_config.py
index 910a8bb997..5eb1095a5a 100644
--- a/dizoo/petting_zoo/config/ptz_simple_spread_mappo_config.py
+++ b/dizoo/petting_zoo/config/ptz_simple_spread_mappo_config.py
@@ -53,7 +53,6 @@
             grad_clip_type='clip_norm',
             grad_clip_value=10,
             ignore_done=False,
-            resume_training=False,
         ),
         collect=dict(
             n_sample=3200,