torch distributed: add support for user-specified parameter synchronization #1612

Status: Open. Wants to merge 7 commits into base: master.
39 changes: 26 additions & 13 deletions returnn/torch/distributed.py
@@ -3,16 +3,15 @@
"""

from __future__ import annotations
-from typing import Optional, Any, Dict
-import logging
import os
import socket
+import logging
+from typing import Callable, Optional, Any, Dict

import torch
from torch.nn.parallel import DistributedDataParallel

from returnn.config import Config
-from returnn.util.basic import CollectionReadCheckCovered
+from returnn.util.basic import CollectionReadCheckCovered, get_fwd_compat_kwargs

_logger = logging.getLogger("returnn.torch.distributed")

@@ -42,8 +41,12 @@ def __init__(self, options: Dict[str, Any]):
% (socket.gethostname(), os.getpid(), self._rank, self._size, self._local_rank, self._local_size)
)

self._custom_step_after_param_update: Optional[Callable] = self._opts.get(
"custom_step_after_param_update", None
)
self._reduce_type = self._opts.get("reduce_type", "grad")
self._param_sync_step: Optional[int] = self._opts.get("param_sync_step", None)

if self._reduce_type == "param":
assert isinstance(self._param_sync_step, int) and self._param_sync_step > 0, (
f"reduce_type param: param_sync_step must be a positive int,"
@@ -52,6 +55,10 @@ def __init__(self, options: Dict[str, Any]):
_logger.info(f"reduce_type param: param_sync_step {self._param_sync_step}")
elif self._reduce_type == "grad":
_logger.info("reduce_type grad")
elif self._reduce_type == "custom_step_after_param_update":
if not isinstance(self._custom_step_after_param_update, Callable):
raise ValueError(f"custom step callback must be a callable, not {self._custom_step_after_param_update}")
_logger.info("reduce_type custom_step_after_param_update")
else:
raise ValueError(f"invalid reduce_type {self._reduce_type!r}")

@@ -70,6 +77,8 @@ def _check_no_unknown_opts(self):
self._opts.get("options")
if self._reduce_type == "param":
self._opts.get("sync_on_cpu")
if self._reduce_type == "custom":
self._opts.get("synchronizer")

self._opts.assert_all_read()

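Note on the conditional get() calls above: self._opts is a CollectionReadCheckCovered, so this method only needs to "read" every option that is legal for the chosen reduce_type; assert_all_read() then flags any unknown or misspelled keys. A minimal sketch of that behavior, assuming the class wraps a plain dict as shown:

# Sketch of the option checking, assuming CollectionReadCheckCovered takes the raw dict.
from returnn.util.basic import CollectionReadCheckCovered

opts = CollectionReadCheckCovered({"reduce_type": "grad", "sync_on_cpuu": True})  # note the typo
assert opts.get("reduce_type", "grad") == "grad"
opts.assert_all_read()  # fails because "sync_on_cpuu" was never read, so the typo is caught
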
@@ -100,9 +109,10 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis
:param module: original module
:return: potentially wrapped module
"""
if self._reduce_type == "param":
if self._reduce_type in ["param", "custom_step_after_param_update"]:
return None
assert self._reduce_type == "grad"

cls = self._opts.get("class", DistributedDataParallel)
if cls is not DistributedDataParallel:
_logger.warning(f"Using custom class {cls} instead of DistributedDataParallel, might be unsupported.")
@@ -115,7 +125,11 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis

def step_after_param_update(self, *, module: torch.nn.Module, epoch_step_idx: int):
"""one train step"""
if self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)):
if self._reduce_type == "custom_step_after_param_update":
self._custom_step_after_param_update(
module=module, train_step_idx=epoch_step_idx, **get_fwd_compat_kwargs()
)
elif self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)):
_sync_params_avg(module=module, sync_on_cpu=self._opts.get("sync_on_cpu", False))

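The modular check in the "param" branch fires on the last step of each window of param_sync_step steps, counting step indices from 0. A small self-contained illustration (100 is an example value, not a default):

# With param_sync_step = 100, parameters are averaged after steps 99, 199, 299, ...
param_sync_step = 100
sync_steps = [i for i in range(300) if i % param_sync_step == param_sync_step - 1]
assert sync_steps == [99, 199, 299]
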

@@ -127,7 +141,7 @@ def get_ctx(config=None) -> Optional[DistributedContext]:
"""
:param Config|None config:
:returns: the global context if Torch distributed is enabled, or None otherwise.
-If we did not setup the context yet, it will automatically create it.
+If we did not set up the context yet, it will automatically create it.
"""
global _is_set_up, _ctx
if _is_set_up:
@@ -155,7 +169,7 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False):

if sync_on_cpu:
for param in module.parameters():
-# Separately move each param to CPU (instead of the whole module), to safe CPU memory.
+# Separately move each param to CPU (instead of the whole module), to save CPU memory.
param_cpu = param.to(torch.device("cpu"))
# On CPU, we are likely using Gloo, and Gloo does not support AVG
dist.all_reduce(param_cpu.data, op=dist.ReduceOp.SUM)
@@ -166,12 +180,11 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False):
if dist.get_backend() == "gloo":
# Gloo does not support AVG
reduce_op = dist.ReduceOp.SUM
+elif hasattr(dist.ReduceOp, "AVG"):
+reduce_op = dist.ReduceOp.AVG
else:
-if hasattr(dist.ReduceOp, "AVG"):
-reduce_op = dist.ReduceOp.AVG
-else:
-# Older PyTorch versions do not have ReduceOp.AVG.
-reduce_op = dist.ReduceOp.SUM
+# Older PyTorch versions do not have ReduceOp.AVG.
+reduce_op = dist.ReduceOp.SUM

for param in module.parameters():
dist.all_reduce(param.data, op=reduce_op)
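Usage sketch for the new reduce type (illustrative, not part of the diff): the config dict name torch_distributed and the callback body below are assumptions; only the option keys and the module=/train_step_idx= keyword arguments follow from the code above. The callback must accept **kwargs, since RETURNN also passes the forward-compatibility kwarg from get_fwd_compat_kwargs().

# Hypothetical RETURNN config snippet enabling the user-defined synchronization step.
import torch
import torch.distributed as dist


def my_param_sync(*, module: torch.nn.Module, train_step_idx: int, **_kwargs):
    """Called on every rank after each parameter update; here: average params every 10 steps."""
    if train_step_idx % 10 != 9:
        return
    for param in module.parameters():
        dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
        param.data /= dist.get_world_size()


torch_distributed = {
    "reduce_type": "custom_step_after_param_update",
    "custom_step_after_param_update": my_param_sync,
}

As with reduce_type "param", the module is not wrapped in DistributedDataParallel in this mode (maybe_make_distributed_module returns None), so no gradient all-reduce happens and synchronization is entirely up to the callback.
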
12 changes: 12 additions & 0 deletions returnn/util/basic.py
@@ -4586,3 +4586,15 @@ def override_env_var(var_name: str, value: str):
os.environ[var_name] = cur_val
else:
os.environ.pop(var_name)


_fwd_compat_rng = np.random.default_rng()


def get_fwd_compat_kwargs() -> Dict[str, Any]:
"""
Returns a dictionary suitable for passing as kwargs to any RETURNN userland
function where forward compatibility w.r.t. additional arguments must be
ensured.
"""
return {f"fwd_compatible_random_kwarg_{_fwd_compat_rng.integers(0, 100)}": None}