Skip to content

Commit

Permalink
polish(pu): adapt learn log to unizero_singletask_ddp
Browse files (browse the repository at this point in the history)
  • Loading branch information
puyuan1996 committed Nov 12, 2024
1 parent 7a66b76 commit 3ea48ca
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
12 changes: 6 additions & 6 deletions ding/worker/learner/base_learner.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -107,9 +107,9 @@ def __init__(
self._logger, _ = build_logger(
'./{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
)
- # self._tb_logger = None
+ self._tb_logger = None
  # ========== TODO: unizero_multitask ddp_v2 ========
- self._tb_logger = tb_logger
+ # self._tb_logger = tb_logger


self._log_buffer = {
Expand Down Expand Up @@ -436,10 +436,10 @@ def policy(self, _policy: 'Policy') -> None: # noqa
Policy variable monitor is set alongside with policy, because variables are determined by specific policy.
"""
self._policy = _policy
- # if self._rank == 0:
- #     self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10)
-
- self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10)
+ if self._rank == 0:
+     self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10)
+ # ========== TODO: unizero_multitask ddp_v2 ========
+ # self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10)

if self._cfg.log_policy:
self.info(self._policy.info())
Expand Down
9 changes: 5 additions & 4 deletions ding/worker/learner/learner_hook.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -201,11 +201,12 @@ def __call__(self, engine: 'BaseLearner') -> None: # noqa
Arguments:
- engine (:obj:`BaseLearner`): the BaseLearner
"""
+ # ========== TODO: unizero_multitask ddp_v2 ========
  # Only show log for rank 0 learner
- # if engine.rank != 0:
- #     for k in engine.log_buffer:
- #         engine.log_buffer[k].clear()
- #     return
+ if engine.rank != 0:
+     for k in engine.log_buffer:
+         engine.log_buffer[k].clear()
+     return

# For 'scalar' type variables: log_buffer -> tick_monitor -> monitor_time.step
for k, v in engine.log_buffer['scalar'].items():
Expand Down

0 comments on commit 3ea48ca

Please sign in to comment.