Skip to content

Commit

Permalink
polish(pu): adapt learner to unizero_multitask_ddp_v2
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Nov 12, 2024
1 parent e6a18ba commit a42c85b
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
6 changes: 3 additions & 3 deletions ding/worker/learner/base_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ def __init__(
self._logger, _ = build_logger(
'./{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
)
self._tb_logger = None

self._tb_logger = None
# self._tb_logger = None
# ========== TODO: unizero_multitask ddp_v2 ========
self._tb_logger = tb_logger


self._log_buffer = {
Expand Down
16 changes: 15 additions & 1 deletion ding/worker/learner/learner_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,22 @@ def aggregate(data):
Returns:
- new_data (:obj:`dict`): data after reduce
"""
# if isinstance(data, dict):
# new_data = {k: aggregate(v) for k, v in data.items()}

def should_reduce(key):
# 检查 key 是否以 "noreduce_" 前缀开头
return not key.startswith("noreduce_")

if isinstance(data, dict):
new_data = {k: aggregate(v) for k, v in data.items()}
new_data = {}
for k, v in data.items():
if should_reduce(k):
new_data[k] = aggregate(v) # 对需要 reduce 的数据执行 allreduce
else:
new_data[k] = v # 不需要 reduce 的数据直接保留


elif isinstance(data, list) or isinstance(data, tuple):
new_data = [aggregate(t) for t in data]
elif isinstance(data, torch.Tensor):
Expand Down

0 comments on commit a42c85b

Please sign in to comment.