diff --git a/returnn/torch/distributed.py b/returnn/torch/distributed.py
index 3bb2cf825..58ac0dbe3 100644
--- a/returnn/torch/distributed.py
+++ b/returnn/torch/distributed.py
@@ -3,16 +3,15 @@
 """
 
 from __future__ import annotations
-from typing import Optional, Any, Dict
+import logging
 import os
 import socket
-import logging
+from typing import Callable, Optional, Any, Dict
 
 import torch
 from torch.nn.parallel import DistributedDataParallel
 
-from returnn.config import Config
-from returnn.util.basic import CollectionReadCheckCovered
+from returnn.util.basic import CollectionReadCheckCovered, get_fwd_compat_kwargs
 
 _logger = logging.getLogger("returnn.torch.distributed")
 
@@ -42,8 +41,12 @@ def __init__(self, options: Dict[str, Any]):
             % (socket.gethostname(), os.getpid(), self._rank, self._size, self._local_rank, self._local_size)
         )
 
+        self._custom_step_after_param_update: Optional[Callable] = self._opts.get(
+            "custom_step_after_param_update", None
+        )
         self._reduce_type = self._opts.get("reduce_type", "grad")
         self._param_sync_step: Optional[int] = self._opts.get("param_sync_step", None)
+
         if self._reduce_type == "param":
             assert isinstance(self._param_sync_step, int) and self._param_sync_step > 0, (
                 f"reduce_type param: param_sync_step must be a positive int,"
@@ -52,6 +55,10 @@ def __init__(self, options: Dict[str, Any]):
             _logger.info(f"reduce_type param: param_sync_step {self._param_sync_step}")
         elif self._reduce_type == "grad":
             _logger.info("reduce_type grad")
+        elif self._reduce_type == "custom_step_after_param_update":
+            if not isinstance(self._custom_step_after_param_update, Callable):
+                raise ValueError(f"custom step callback must be a callable, not {self._custom_step_after_param_update}")
+            _logger.info("reduce_type custom_step_after_param_update")
         else:
             raise ValueError(f"invalid reduce_type {self._reduce_type!r}")
 
@@ -70,6 +77,8 @@ def _check_no_unknown_opts(self):
             self._opts.get("options")
         if self._reduce_type == "param":
             self._opts.get("sync_on_cpu")
+        if self._reduce_type == "custom_step_after_param_update":
+            self._opts.get("custom_step_after_param_update")
 
         self._opts.assert_all_read()
 
@@ -100,9 +109,10 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis
         :param module: original module
         :return: potentially wrapped module
         """
-        if self._reduce_type == "param":
+        if self._reduce_type in ["param", "custom_step_after_param_update"]:
             return None
         assert self._reduce_type == "grad"
+
         cls = self._opts.get("class", DistributedDataParallel)
         if cls is not DistributedDataParallel:
             _logger.warning(f"Using custom class {cls} instead of DistributedDataParallel, might be unsupported.")
@@ -115,7 +125,11 @@ def maybe_make_distributed_module(self, module: torch.nn.Module) -> Optional[Dis
 
     def step_after_param_update(self, *, module: torch.nn.Module, epoch_step_idx: int):
         """one train step"""
-        if self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)):
+        if self._reduce_type == "custom_step_after_param_update":
+            self._custom_step_after_param_update(
+                module=module, train_step_idx=epoch_step_idx, **get_fwd_compat_kwargs()
+            )
+        elif self._reduce_type == "param" and ((epoch_step_idx % self._param_sync_step) == (self._param_sync_step - 1)):
             _sync_params_avg(module=module, sync_on_cpu=self._opts.get("sync_on_cpu", False))
 
@@ -127,7 +141,7 @@ def get_ctx(config=None) -> Optional[DistributedContext]:
     """
     :param Config|None config:
     :returns: the global context if Torch distributed is enabled, or None otherwise.
-        If we did not setup the context yet, it will automatically create it.
+        If we did not set up the context yet, it will automatically create it.
     """
     global _is_set_up, _ctx
     if _is_set_up:
@@ -155,7 +169,7 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False):
 
     if sync_on_cpu:
         for param in module.parameters():
-            # Separately move each param to CPU (instead of the whole module), to safe CPU memory.
+            # Separately move each param to CPU (instead of the whole module), to save CPU memory.
             param_cpu = param.to(torch.device("cpu"))
             # On CPU, we are likely using Gloo, and Gloo does not support AVG
             dist.all_reduce(param_cpu.data, op=dist.ReduceOp.SUM)
@@ -166,12 +180,11 @@ def _sync_params_avg(*, module: torch.nn.Module, sync_on_cpu: bool = False):
     if dist.get_backend() == "gloo":
         # Gloo does not support AVG
         reduce_op = dist.ReduceOp.SUM
+    elif hasattr(dist.ReduceOp, "AVG"):
+        reduce_op = dist.ReduceOp.AVG
     else:
-        if hasattr(dist.ReduceOp, "AVG"):
-            reduce_op = dist.ReduceOp.AVG
-        else:
-            # Older PyTorch versions do not have ReduceOp.AVG.
-            reduce_op = dist.ReduceOp.SUM
+        # Older PyTorch versions do not have ReduceOp.AVG.
+        reduce_op = dist.ReduceOp.SUM
 
     for param in module.parameters():
         dist.all_reduce(param.data, op=reduce_op)
diff --git a/returnn/util/basic.py b/returnn/util/basic.py
index 26fb24484..a696d7242 100644
--- a/returnn/util/basic.py
+++ b/returnn/util/basic.py
@@ -4586,3 +4586,15 @@ def override_env_var(var_name: str, value: str):
             os.environ[var_name] = cur_val
         else:
             os.environ.pop(var_name)
+
+
+_fwd_compat_rng = np.random.default_rng()
+
+
+def get_fwd_compat_kwargs() -> Dict[str, Any]:
+    """
+    Returns a dictionary suitable for passing as kwargs for any RETURNN userland
+    function where forwards compatibility wrt. additional arguments must be
+    ensured.
+    """
+    return {f"fwd_compatible_random_kwarg_{_fwd_compat_rng.integers(0, 100)}": None}
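
For reviewers, a minimal usage sketch of the new reduce type. It assumes the options dict is provided via the usual `torch_distributed` config entry; the callback name `my_param_sync` and the every-100-steps averaging are illustrative only, not part of this patch. The callback has to accept `module`, `train_step_idx`, and `**kwargs`, since `step_after_param_update` passes `**get_fwd_compat_kwargs()` (a randomly named dummy kwarg) to enforce forward compatibility with future arguments.

```python
# Hypothetical RETURNN config snippet, not part of this patch.
import torch
import torch.distributed as dist


def my_param_sync(*, module: torch.nn.Module, train_step_idx: int, **kwargs):
    """
    Called by RETURNN after every parameter update when
    reduce_type == "custom_step_after_param_update".
    `**kwargs` absorbs the randomly named dummy kwarg injected via
    get_fwd_compat_kwargs(), so the callback stays forward compatible.
    """
    if (train_step_idx % 100) != 99:
        return  # e.g. only average parameters every 100 steps
    for param in module.parameters():
        dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
        param.data /= dist.get_world_size()


torch_distributed = {
    "reduce_type": "custom_step_after_param_update",
    "custom_step_after_param_update": my_param_sync,
}
```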