polish(pu): polish resume_training in entry
puyuan1996 committed Nov 4, 2024
1 parent f165c59 commit 9fede9b
Showing 11 changed files with 15 additions and 21 deletions.
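The change is mechanical: every entry pipeline now reads the optional resume_training flag with a dict-style get() and a False default instead of attribute access, so example configs no longer need to declare the key. A minimal sketch of the difference, assuming the config behaves like an EasyDict/dict (as DI-engine configs do); the variable names below are illustrative only:

    from easydict import EasyDict

    # Hypothetical learn config that never declares resume_training.
    learn_cfg = EasyDict(dict(batch_size=64, learning_rate=0.001))

    # Attribute access would fail here:
    #   learn_cfg.resume_training  -> AttributeError
    # Dict-style access falls back to a default instead:
    resume_training = learn_cfg.get('resume_training', False)  # -> False
    renew_dir = not resume_training  # keep the existing experiment dir only when resuming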
4 changes: 2 additions & 2 deletions ding/entry/serial_entry.py
@@ -54,7 +54,7 @@ def serial_pipeline(
         auto=True,
         create_cfg=create_cfg,
         save_cfg=True,
-        renew_dir=not cfg.policy.learn.resume_training
+        renew_dir=not cfg.policy.learn.get('resume_training', False)
     )
     # Create main components: env, policy
     if env_setting is None:
@@ -94,7 +94,7 @@ def serial_pipeline(
     # ==========
     # Learner's before_run hook.
     learner.call_hook('before_run')
-    if cfg.policy.learn.resume_training:
+    if cfg.policy.learn.get('resume_training', False):
         collector.envstep = learner.collector_envstep

     # Accumulate plenty of data at the beginning of training.
13 changes: 6 additions & 7 deletions ding/entry/serial_entry_mbrl.py
@@ -37,7 +37,7 @@ def mbrl_entry_setup(
         auto=True,
         create_cfg=create_cfg,
         save_cfg=True,
-        renew_dir=not cfg.policy.learn.resume_training
+        renew_dir=not cfg.policy.learn.get('resume_training', False)
     )

     if env_setting is None:
@@ -79,8 +79,7 @@ def mbrl_entry_setup(
     )

     return (
-        cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger,
-        resume_training
+        cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger
     )


@@ -125,13 +124,13 @@ def serial_pipeline_dyna(
     Returns:
         - policy (:obj:`Policy`): Converged policy.
     """
-    cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger, resume_training = \
+    cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \
         mbrl_entry_setup(input_cfg, seed, env_setting, model)

     img_buffer = create_img_buffer(cfg, input_cfg, world_model, tb_logger)

     learner.call_hook('before_run')
-    if cfg.policy.learn.resume_training:
+    if cfg.policy.learn.get('resume_training', False):
         collector.envstep = learner.collector_envstep

     if cfg.policy.get('random_collect_size', 0) > 0:
@@ -200,11 +199,11 @@ def serial_pipeline_dream(
     Returns:
         - policy (:obj:`Policy`): Converged policy.
     """
-    cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger, resume_training = \
+    cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \
         mbrl_entry_setup(input_cfg, seed, env_setting, model)

     learner.call_hook('before_run')
-    if cfg.policy.learn.resume_training:
+    if cfg.policy.learn.get('resume_training', False):
         collector.envstep = learner.collector_envstep

     if cfg.policy.get('random_collect_size', 0) > 0:
4 changes: 2 additions & 2 deletions ding/entry/serial_entry_ngu.py
@@ -54,7 +54,7 @@ def serial_pipeline_ngu(
         auto=True,
         create_cfg=create_cfg,
         save_cfg=True,
-        renew_dir=not cfg.policy.learn.resume_training
+        renew_dir=not cfg.policy.learn.get('resume_training', False)
     )
     # Create main components: env, policy
     if env_setting is None:
@@ -97,7 +97,7 @@ def serial_pipeline_ngu(
     # ==========
     # Learner's before_run hook.
     learner.call_hook('before_run')
-    if cfg.policy.learn.resume_training:
+    if cfg.policy.learn.get('resume_training', False):
         collector.envstep = learner.collector_envstep

     # Accumulate plenty of data at the beginning of training.
5 changes: 3 additions & 2 deletions ding/entry/serial_entry_onpolicy.py
@@ -52,8 +52,9 @@ def serial_pipeline_onpolicy(
         auto=True,
         create_cfg=create_cfg,
         save_cfg=True,
-        renew_dir=not cfg.policy.learn.resume_training
+        renew_dir=not cfg.policy.learn.get('resume_training', False)
     )
+
     # Create main components: env, policy
     if env_setting is None:
         env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
@@ -88,7 +89,7 @@ def serial_pipeline_onpolicy(
     # ==========
     # Learner's before_run hook.
     learner.call_hook('before_run')
-    if cfg.policy.learn.resume_training:
+    if cfg.policy.learn.get('resume_training', False):
         collector.envstep = learner.collector_envstep

     while True:
4 changes: 2 additions & 2 deletions ding/entry/serial_entry_onpolicy_ppg.py
@@ -52,7 +52,7 @@ def serial_pipeline_onpolicy_ppg(
         auto=True,
         create_cfg=create_cfg,
         save_cfg=True,
-        renew_dir=not cfg.policy.learn.resume_training
+        renew_dir=not cfg.policy.learn.get('resume_training', False)
     )
     # Create main components: env, policy
     if env_setting is None:
@@ -88,7 +88,7 @@ def serial_pipeline_onpolicy_ppg(
     # ==========
     # Learner's before_run hook.
     learner.call_hook('before_run')
-    if cfg.policy.learn.resume_training:
+    if cfg.policy.learn.get('resume_training', False):
         collector.envstep = learner.collector_envstep

     while True:
1 change: 0 additions & 1 deletion
@@ -21,7 +21,6 @@
             learning_rate=0.001,
             # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
             entropy_weight=0.01,
-            resume_training=False,
         ),
         collect=dict(
             # (int) collect n_sample data, train model n_iteration times
1 change: 0 additions & 1 deletion
@@ -18,7 +18,6 @@
             batch_size=64,
             learning_rate=0.001,
             entropy_weight=0.001,
-            resume_training=False,
         ),
         collect=dict(n_episode=80, unroll_len=1, discount_factor=0.9),
         eval=dict(evaluator=dict(eval_freq=100, ), ),
1 change: 0 additions & 1 deletion
@@ -37,7 +37,6 @@
             entropy_weight=0.01,
             clip_ratio=0.2,
             learner=dict(hook=dict(save_ckpt_after_iter=100)),
-            resume_training=False,
         ),
         collect=dict(
             n_sample=256,
1 change: 0 additions & 1 deletion
@@ -19,7 +19,6 @@
             batch_size=64,
             learning_rate=0.001,
             entropy_weight=0.001,
-            resume_training=False,
         ),
         collect=dict(n_episode=80, unroll_len=1, discount_factor=0.9, collector=dict(get_train_sample=True)),
         eval=dict(evaluator=dict(eval_freq=100, ), ),
1 change: 0 additions & 1 deletion
@@ -35,7 +35,6 @@
             adv_norm=True,
             value_norm=True,
             ignore_done=True,
-            resume_training=False,
         ),
         collect=dict(
             n_sample=5000,
1 change: 0 additions & 1 deletion dizoo/petting_zoo/config/ptz_simple_spread_mappo_config.py
@@ -53,7 +53,6 @@
             grad_clip_type='clip_norm',
             grad_clip_value=10,
             ignore_done=False,
-            resume_training=False,
         ),
         collect=dict(
             n_sample=3200,
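
With the explicit resume_training=False defaults removed from these configs, a run that should resume now opts in from the user config. An illustrative fragment, not part of this commit; the checkpoint hook and the surrounding keys are assumptions for the example:

    policy=dict(
        learn=dict(
            batch_size=64,
            learning_rate=0.001,
            # Opt in to resuming; configs that omit this key now fall back to False
            # through cfg.policy.learn.get('resume_training', False).
            resume_training=True,
            # Hypothetical checkpoint-loading hook for the example; the diff above
            # does not show how the checkpoint path is configured.
            learner=dict(hook=dict(load_ckpt_before_run='./exp/ckpt/iteration_10000.pth.tar')),
        ),
    ),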
