diff --git a/oslo/torch/nn/parallel/data_parallel/__init__.py b/oslo/torch/nn/parallel/data_parallel/__init__.py
index 2d993d1d..d75274c4 100644
--- a/oslo/torch/nn/parallel/data_parallel/__init__.py
+++ b/oslo/torch/nn/parallel/data_parallel/__init__.py
@@ -1 +1,6 @@
+from oslo.torch.nn.parallel.data_parallel.data_parallel import (
+    DistributedDataParallel,
+)
 from oslo.torch.nn.parallel.data_parallel.zero import *
+
+__all__ = ["DistributedDataParallel", "ZeroRedundancyOptimizer"]
diff --git a/oslo/torch/nn/parallel/data_parallel/_reducer.py b/oslo/torch/nn/parallel/data_parallel/_reducer.py
new file mode 100644
index 00000000..455e00f1
--- /dev/null
+++ b/oslo/torch/nn/parallel/data_parallel/_reducer.py
@@ -0,0 +1,116 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+import functools
+from typing import Callable, Dict, List, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+from torch.distributed import ProcessGroup
+
+
+class Bucket:
+    def __init__(
+        self, size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup
+    ):
+        self.buffer = torch.zeros(size, dtype=dtype, device=device)
+        self.group = group
+        self.offset = 0
+        self.callbacks: List[Callable] = []
+
+    def flush(self) -> None:
+        """Flush the content of the bucket."""
+        if self.offset == 0:
+            assert len(self.callbacks) == 0
+            return
+        # all-reduce the populated portion of the bucket
+        dist.all_reduce(self.buffer[: self.offset], group=self.group)
+
+        # execute post-reduction callbacks
+        for callback_fn in self.callbacks:
+            callback_fn()
+        # reset the bucket and allocate a fresh buffer
+        self.offset = 0
+        self.callbacks.clear()
+        self.buffer = torch.zeros_like(self.buffer)
+
+    def alloc(self) -> None:
+        """Allocate the bucket's underlying storage if it has been freed."""
+        if self.buffer.storage().size() == 0:
+            self.buffer.storage().resize_(self.buffer.numel())
+
+    def free(self) -> None:
+        """Release the bucket's underlying storage."""
+        assert self.offset == 0 and self.callbacks == [], "Bucket must be flushed before it can be freed"
+        self.buffer.storage().resize_(0)
+
+    def append(self, tensor: Tensor, callback_fn: Callable):
+        tensor_size = tensor.numel()
+        offset = self.offset
+        self.buffer[offset : offset + tensor_size].copy_(tensor.flatten())
+        self.offset += tensor_size
+
+        # the callback will be given a view of the reduced result
+        if callback_fn is not None:
+            result_view = self.buffer[offset : offset + tensor_size].view(tensor.shape)
+            self.callbacks.append(functools.partial(callback_fn, result_view))
+
+    @property
+    def avail_size(self) -> int:
+        return self.buffer.size(0) - self.offset
+
+
+class Reducer:
+    def __init__(self, bucket_size_mb: int = 25):
+        self.bucket_size_mb = bucket_size_mb
+        self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}
+
+    @torch.no_grad()
+    def all_reduce_async(
+        self,
+        tensor: Tensor,
+        group: ProcessGroup,
+        callback_fn: Optional[Callable] = None,
+    ) -> None:
+        bucket_size = self._get_bucket_size(tensor.element_size())
+
+        if tensor.numel() >= bucket_size:
+            dist.all_reduce(tensor, group=group)
+            if callback_fn is not None:
+                callback_fn(tensor)
+            return
+
+        bucket = self._get_bucket(tensor, group)
+        if tensor.numel() > bucket.avail_size:
+            # not enough space remaining in the bucket, flush it now
+            bucket.flush()
+        bucket.append(tensor, callback_fn)
+
+    @torch.no_grad()
+    def flush(self) -> None:
+        for bucket in self.buckets.values():
+            bucket.flush()
+
+    @torch.no_grad()
+    def free(self) -> None:
+        for bucket in self.buckets.values():
+            bucket.free()
+
+    @functools.lru_cache()
+    def _get_bucket_size(self, element_size: int) -> int:
+        if self.bucket_size_mb <= 0:  # Values <= 0 disable bucketing.
+            return 0
+        MB = 1024 * 1024
+        bucket_size = self.bucket_size_mb * MB / element_size
+        return int(bucket_size)
+
+    def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
+        key = (tensor.dtype, tensor.device, group)
+        if key not in self.buckets:
+            bucket_size = self._get_bucket_size(tensor.element_size())
+            self.buckets[key] = Bucket(bucket_size, tensor.dtype, tensor.device, group)
+        self.buckets[key].alloc()
+        return self.buckets[key]
diff --git a/oslo/torch/nn/parallel/data_parallel/_utils.py b/oslo/torch/nn/parallel/data_parallel/_utils.py
new file mode 100644
index 00000000..2155097f
--- /dev/null
+++ b/oslo/torch/nn/parallel/data_parallel/_utils.py
@@ -0,0 +1,24 @@
+from typing import Iterable
+
+import torch
+
+
+def is_ddp_ignored(p):
+    return getattr(p, "_ddp_to_ignore", False)
+
+
+def set_params_to_ignore(params_to_ignore: Iterable[torch.Tensor]) -> None:
+    """Sets parameters to be ignored by DDP.
+    This method must be called before initializing DistributedDataParallel.
+    Example:
+        >>> params_to_ignore = []
+        >>> for p in module.parameters():
+        ...     if should_ignore(p):
+        ...         params_to_ignore.append(p)
+        >>> set_params_to_ignore(params_to_ignore)
+        >>> module = DistributedDataParallel(module)
+    Args:
+        params_to_ignore (Iterable[torch.Tensor]): A list of parameters to be ignored.
+    """
+    for p in params_to_ignore:
+        p._ddp_to_ignore = True
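For reference, a minimal single-process sketch of how the Reducer added above can be exercised in isolation. The Reducer/Bucket API comes from the diff; the one-rank gloo group, the init_method address, and the tensor size are illustrative assumptions only.

    # Sketch: bucketed asynchronous all-reduce with the Reducer (single rank, gloo).
    import torch
    import torch.distributed as dist

    from oslo.torch.nn.parallel.data_parallel._reducer import Reducer

    # illustrative single-process group; in practice this is one rank per worker
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )

    reducer = Reducer(bucket_size_mb=25)
    grad = torch.ones(1024)  # small tensor, so it lands in a bucket instead of a direct all-reduce
    results = []

    # the callback receives a view of the reduced slice of the bucket buffer
    reducer.all_reduce_async(
        grad, group=dist.group.WORLD, callback_fn=lambda t: results.append(t.clone())
    )

    reducer.flush()  # all-reduce the pending bucket and fire the callbacks
    assert torch.equal(results[0], torch.ones(1024))  # world_size == 1, so the value is unchanged
    reducer.free()   # release bucket storage, as the DDP wrapper does when rebuild_bucket=True
    dist.destroy_process_group()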
diff --git a/oslo/torch/nn/parallel/data_parallel/data_parallel.py b/oslo/torch/nn/parallel/data_parallel/data_parallel.py
new file mode 100644
index 00000000..a4b08fc1
--- /dev/null
+++ b/oslo/torch/nn/parallel/data_parallel/data_parallel.py
@@ -0,0 +1,179 @@
+import copy
+from functools import partial
+
+import torch
+import torch.nn as nn
+from torch.autograd import Variable
+import torch.distributed as dist
+
+try:
+    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
+except ImportError:
+    _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
+
+from oslo.torch.distributed.parallel_context import ParallelContext
+from oslo.torch.distributed.parallel_mode import ParallelMode
+from oslo.torch.nn.parallel.utils import (
+    get_parallel_context,
+    add_wrapper,
+    OsloParallelWrapper,
+)
+from oslo.torch.nn.parallel.data_parallel._reducer import Reducer
+from oslo.torch.nn.parallel.data_parallel._utils import is_ddp_ignored
+
+
+def free_storage(data: torch.Tensor) -> None:
+    """Free the underlying storage of a Tensor."""
+    if data.storage().size() > 0:
+        # Since we're modifying the Tensor's Storage directly, make sure the Tensor
+        # is the sole occupant of the Storage.
+        assert data.storage_offset() == 0
+        data.storage().resize_(0)
+
+
+class BackwardFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, module, inputs):
+        ctx.module = module
+        return inputs
+
+    @staticmethod
+    def backward(ctx, *grad_outputs):
+        ctx.module._pre_backward()
+        # Enqueue a callback to flush the reducer.
+        # The callback fires after all gradient computation has completed.
+        Variable._execution_engine.queue_callback(ctx.module._post_backward)
+        return (None,) + grad_outputs
+
+
+def DistributedDataParallel(
+    module: nn.Module,
+    parallel_context: ParallelContext,
+    bucket_cap_mb: int = 25,
+    rebuild_bucket: bool = True,
+):
+    ddp = _DistributedDataParallel(
+        module=module,
+        parallel_context=parallel_context,
+        bucket_cap_mb=bucket_cap_mb,
+        rebuild_bucket=rebuild_bucket,
+    )
+
+    add_wrapper(
+        module,
+        mode=ParallelMode.DATA,
+        wrapper=ddp,
+        parallel_context=parallel_context,
+    )
+    return module
+
+
+class _DistributedDataParallel(OsloParallelWrapper):
+    """Distributed data parallel wrapper for Oslo.
+    Example:
+        >>> from oslo.torch.nn.parallel import DistributedDataParallel as DDP
+        >>> model = torch.nn.Linear(20, 1)
+        >>> model = DDP(model, parallel_context)
+        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+        >>> oslo.ready(model, parallel_context)
+        >>> model.zero_grad()
+        >>> logits = model(x)
+        >>> loss = criterion(logits, labels)
+        >>> loss.backward()
+        >>> optimizer.step()
+    Args:
+        module (nn.Module): PyTorch module object
+        parallel_context (ParallelContext): parallel context that provides the data parallel process group
+    """
+
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        parallel_context: ParallelContext = None,
+        bucket_cap_mb: int = 25,
+        rebuild_bucket: bool = True,
+    ) -> None:
+        super().__init__(parallelism_priority=99)
+        self.module = module
+
+        self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
+        assert parallel_context
+        self.parallel_context = get_parallel_context(module, parallel_context)
+        self.dp_world_size = self.parallel_context.get_world_size(ParallelMode.DATA)
+
+        self.reducer = Reducer(bucket_cap_mb)
+        self.rebuild_bucket = rebuild_bucket
+
+        for p in module.parameters():
+            if is_ddp_ignored(p):
+                continue
+            if p.requires_grad:
+                p.register_hook(partial(self.grad_handle, p))
+
+    def parallelize(self):
+        self._forward = copy.copy(self.module.forward)
+        self.module.zero_grad = self.zero_grad
+
+    def forward(self, *args, **kwargs):
+        return BackwardFunction.apply(self, self._forward(*args, **kwargs))
+
+    def _pre_backward(self):
+        pass
+
+    def _post_backward(self):
+        with torch.cuda.stream(self.comm_stream):
+            self.reducer.flush()
+        torch.cuda.current_stream().wait_stream(self.comm_stream)
+        if self.rebuild_bucket:
+            self.reducer.free()
+        for p in self.module.parameters():
+            if is_ddp_ignored(p):
+                continue
+            if p.grad.device.type != "cpu":
+                p.grad = p._saved_grad
+
+    def grad_handle(self, p, grad):
+        if grad.device.type != "cpu":
+            empty_grad = torch.empty_like(grad)
+            free_storage(empty_grad)
+            if self.dp_world_size > 1:
+                grad = grad / self.dp_world_size
+                self.comm_stream.wait_stream(torch.cuda.current_stream())
+                with torch.cuda.stream(self.comm_stream):
+                    self.reducer.all_reduce_async(
+                        grad,
+                        group=self.parallel_context.get_group(ParallelMode.DATA),
+                        callback_fn=partial(self._save_grad, p),
+                    )
+                grad.record_stream(self.comm_stream)
+            else:
+                _DistributedDataParallel._save_grad(p, grad)

+            return empty_grad

+        else:
+            # To run on the CPU, call model.to('cpu') after oslo.ready().
+            dist.all_reduce(
+                grad, group=self.parallel_context.get_cpu_group(ParallelMode.DATA)
+            )
+            return grad
+
+    @staticmethod
+    def _save_grad(p, grad):
+        if hasattr(p, "_saved_grad"):
+            p._saved_grad.add_(grad)
+        else:
+            p._saved_grad = grad
+
+    def zero_grad(self, set_to_none: bool = False) -> None:
+        super().zero_grad(set_to_none=True)
+        for p in self.module.parameters():
+            if getattr(p, "_saved_grad", None) is not None:
+                if set_to_none:
+                    p._saved_grad = None
+                else:
+                    if p._saved_grad.grad_fn is not None:
+                        p._saved_grad.detach_()
+                    else:
+                        p._saved_grad.requires_grad_(False)
+                    p._saved_grad.zero_()
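For reference, a minimal end-to-end training-loop sketch for the wrapper added above, following its docstring. `DistributedDataParallel` and `oslo.ready` are taken from this diff and its docstring; `ParallelContext.from_torch`, the `data_parallel_size` value, the model, data shapes, and a one-process-per-GPU launch (e.g. via torchrun) are assumptions for illustration only.

    # Sketch: wrapping a module, running forward/backward, and stepping the optimizer.
    import torch
    import oslo
    from oslo.torch.distributed.parallel_context import ParallelContext
    from oslo.torch.nn.parallel import DistributedDataParallel as DDP

    parallel_context = ParallelContext.from_torch(data_parallel_size=2)  # assumed helper

    model = torch.nn.Linear(20, 1).cuda()
    model = DDP(model, parallel_context)              # attaches the _DistributedDataParallel wrapper
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    oslo.ready(model, parallel_context)               # finalizes parallelization, per the docstring

    criterion = torch.nn.MSELoss()
    x = torch.randn(16, 20).cuda()
    labels = torch.randn(16, 1).cuda()

    model.zero_grad()                                 # clears module grads and any _saved_grad buffers
    loss = criterion(model(x), labels)
    loss.backward()                                   # grads are bucketed and all-reduced asynchronously
    optimizer.step()                                  # _post_backward copied the reduced grads back to p.grad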
diff --git a/oslo/torch/nn/parallel/data_parallel/zero/__init__.py b/oslo/torch/nn/parallel/data_parallel/zero/__init__.py
index bd05561e..8ebd1b5e 100644
--- a/oslo/torch/nn/parallel/data_parallel/zero/__init__.py
+++ b/oslo/torch/nn/parallel/data_parallel/zero/__init__.py
@@ -2,4 +2,4 @@
     ZeroRedundancyOptimizer,
 )
 
-__all__ = ["ZeroRedundancyOptimizer"]
+__ALL__ = ["ZeroRedundancyOptimizer"]
diff --git a/oslo/torch/nn/parallel/data_parallel/zero/sharded_optim/__init__.py b/oslo/torch/nn/parallel/data_parallel/zero/sharded_optim/__init__.py
index bd05561e..8ebd1b5e 100644
--- a/oslo/torch/nn/parallel/data_parallel/zero/sharded_optim/__init__.py
+++ b/oslo/torch/nn/parallel/data_parallel/zero/sharded_optim/__init__.py
@@ -2,4 +2,4 @@
     ZeroRedundancyOptimizer,
 )
 
-__all__ = ["ZeroRedundancyOptimizer"]
+__ALL__ = ["ZeroRedundancyOptimizer"]
diff --git a/oslo/torch/nn/parallel/data_parallel/zero/sharded_optim/bookkeeping/__init__.py b/oslo/torch/nn/parallel/data_parallel/zero/sharded_optim/bookkeeping/__init__.py
index ba750fb7..499f9aee 100644
--- a/oslo/torch/nn/parallel/data_parallel/zero/sharded_optim/bookkeeping/__init__.py
+++ b/oslo/torch/nn/parallel/data_parallel/zero/sharded_optim/bookkeeping/__init__.py
@@ -3,4 +3,4 @@
 from .parameter_store import ParameterStore
 from .tensor_store import TensorBucket
 
-_all__ = ["BucketStore", "GradientStore", "ParameterStore", "TensorBucket"]
+__all__ = ["BucketStore", "GradientStore", "ParameterStore", "TensorBucket"]