Supporting Adding DDP Communication Hooks #6736

Merged on Apr 7, 2021
Changes from 53 commits
Commits (56)
89f284d
Fix some test errors
Mar 23, 2021
80cfbff
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 23, 2021
536c132
checkpoint consolidation
Mar 24, 2021
f172101
Update ddp_spawn.py
shuyingsunshine21 Mar 24, 2021
bf70e43
Update test_metric_result_integration.py
shuyingsunshine21 Mar 24, 2021
ea74906
Update test_results.py
shuyingsunshine21 Mar 24, 2021
a9aae99
Update utils.py
shuyingsunshine21 Mar 24, 2021
70fe5da
Update utils.py
shuyingsunshine21 Mar 24, 2021
0d23d75
Update test_all_gather_grad.py
shuyingsunshine21 Mar 24, 2021
ca6f98b
Update test_all_gather_grad.py
shuyingsunshine21 Mar 24, 2021
c5053da
Merge pull request #1 from shuyingsunshine21/shuyingsunshine21-checkp…
shuyingsunshine21 Mar 24, 2021
9d4a2b8
Update test_results.py
shuyingsunshine21 Mar 24, 2021
7635b4f
Revert "Update test_results.py"
shuyingsunshine21 Mar 24, 2021
d64f90c
Revert "Merge pull request #1 from shuyingsunshine21/shuyingsunshine2…
shuyingsunshine21 Mar 24, 2021
dcdcd29
Revert "Update test_all_gather_grad.py"
shuyingsunshine21 Mar 24, 2021
8651d54
Revert "Update utils.py"
shuyingsunshine21 Mar 24, 2021
15f4b9e
Revert "Update utils.py"
shuyingsunshine21 Mar 24, 2021
250d0aa
Revert "Update test_results.py"
shuyingsunshine21 Mar 24, 2021
6c095b2
Revert "Update test_metric_result_integration.py"
shuyingsunshine21 Mar 24, 2021
8222dc9
Revert "Update ddp_spawn.py"
shuyingsunshine21 Mar 24, 2021
3a9fde9
Revert "checkpoint consolidation"
shuyingsunshine21 Mar 24, 2021
7a369f4
Revert "Revert "checkpoint consolidation""
shuyingsunshine21 Mar 24, 2021
b4a0b9e
Revert "Revert "Revert "checkpoint consolidation"""
shuyingsunshine21 Mar 24, 2021
5cf1db1
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 24, 2021
0ce7e05
Revert "Revert "Update ddp_spawn.py""
shuyingsunshine21 Mar 24, 2021
fe9736d
Revert "Revert "Update test_metric_result_integration.py""
shuyingsunshine21 Mar 24, 2021
c314ef6
Revert "Revert "Update test_results.py""
shuyingsunshine21 Mar 24, 2021
c3feda0
Revert "Revert "Update utils.py""
shuyingsunshine21 Mar 24, 2021
c759477
Revert "Revert "Update test_all_gather_grad.py""
shuyingsunshine21 Mar 24, 2021
7a8e540
Merge branch 'master' of https://github.com/shuyingsunshine21/pytorch…
Mar 24, 2021
ab8b849
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 24, 2021
4e67db2
modify distributed environment to make test pass
Mar 24, 2021
67b6188
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Mar 25, 2021
1e41d5b
add DDP communication hook
Mar 30, 2021
6833b87
remove test related setting
Mar 30, 2021
f856d31
remove more test related setting
Mar 30, 2021
14a0a1b
fix ddp comm hook util import issue
Mar 30, 2021
8998469
comments
Mar 30, 2021
a17947b
one more fix for test_custom_plugin
Mar 30, 2021
91a945a
fix ddp spwan
Mar 30, 2021
78c6925
fix sgd
Mar 30, 2021
443f223
address comments and add tests
Mar 30, 2021
f8d0603
1. add is gpu checking 2. modify test a bit 3. formatting
Mar 31, 2021
f06285f
formatting nit
Mar 31, 2021
b607ebd
fix conda 3.7 1.7 issue for no torch.distributed.algorithms module
Mar 31, 2021
6cc9dfa
need at least 1.8.0
Apr 1, 2021
b12a16b
minor fix
Apr 1, 2021
25ccb82
modify changelog
Apr 1, 2021
35d49bc
changelog should link to PR number instead of issue number
Apr 1, 2021
dc5c55c
refine a bit on doc for register_ddp_comm_hook function, like ddp_com…
Apr 1, 2021
fb184b2
move single device checking before call register_ddp_comm_hook
Apr 1, 2021
bf44378
formatting
Apr 2, 2021
d529985
comments
Apr 5, 2021
b8105be
typo
Apr 5, 2021
e32a11d
pre-commit formatting
Apr 6, 2021
2275b45
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
Apr 6, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -63,6 +63,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595))

- Added support for DDP communication hooks ([#6736](https://github.com/PyTorchLightning/pytorch-lightning/issues/6736))

### Changed

36 changes: 35 additions & 1 deletion pytorch_lightning/plugins/training_type/ddp.py
@@ -29,14 +29,21 @@
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, rank_zero_warn
from pytorch_lightning.utilities import (
_HYDRA_AVAILABLE,
_TORCH_GREATER_EQUAL_1_7,
_TORCH_GREATER_EQUAL_1_8,
rank_zero_warn,
)
from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything

if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path
if _TORCH_GREATER_EQUAL_1_8:
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook

log = logging.getLogger(__name__)

@@ -58,6 +65,9 @@ def __init__(
num_nodes: int = 1,
cluster_environment: ClusterEnvironment = None,
sync_batchnorm: bool = False,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
**kwargs: Union[Any, Dict[str, Any]],
) -> None:
super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
@@ -70,6 +80,9 @@ def __init__(
self.task_idx = None
self.node_rank = 0
self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices
self._ddp_comm_state = ddp_comm_state
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper

@property
def root_device(self):
@@ -80,6 +93,10 @@ def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
return distributed_sampler_kwargs

@property
def _is_single_process_single_device(self) -> bool:
return True

def setup_environment(self):
# start the other scripts
if not self.cluster_environment.creates_children() and os.environ.get("PL_IN_DDP_SUBPROCESS", "0") != "1":
@@ -218,13 +235,30 @@ def pre_configure_ddp(self):
)
self._ddp_kwargs["find_unused_parameters"] = True

def _register_ddp_hooks(self) -> None:
# currently, DDP communication hooks only work with the NCCL backend and SPSD (single process single device) mode
# https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/
# torch/nn/parallel/distributed.py#L1040
if (
_TORCH_GREATER_EQUAL_1_8
and self.on_gpu
and self._is_single_process_single_device
):
register_ddp_comm_hook(
model=self._model,
ddp_comm_state=self._ddp_comm_state,
ddp_comm_hook=self._ddp_comm_hook,
ddp_comm_wrapper=self._ddp_comm_wrapper,
)

def configure_ddp(self):
self.pre_configure_ddp()
self._model = DistributedDataParallel(
LightningDistributedModule(self.model),
device_ids=self.determine_ddp_device_ids(),
**self._ddp_kwargs,
)
self._register_ddp_hooks()

def determine_ddp_device_ids(self):
if self.root_device.type == "cpu":
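For reference, a minimal usage sketch of the new plugin arguments (not part of the diff; it assumes torch>=1.8, a machine with at least two GPUs, and the Trainer/DDPPlugin API of this Lightning release):

# Sketch: pass a built-in fp16 gradient-compression hook through the new
# DDPPlugin arguments; Lightning then registers it in _register_ddp_hooks().
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

trainer = Trainer(
    gpus=2,
    accelerator="ddp",
    plugins=[DDPPlugin(ddp_comm_hook=default.fp16_compress_hook)],
)
# trainer.fit(model) then runs DDP with fp16-compressed gradient all-reduce.
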
4 changes: 4 additions & 0 deletions pytorch_lightning/plugins/training_type/ddp2.py
@@ -59,6 +59,10 @@ def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=self.num_nodes, rank=self.global_rank)
return distributed_sampler_kwargs

@property
def _is_single_process_single_device(self) -> bool:
return False

def set_world_ranks(self):
self.local_rank = self.task_idx
self.node_rank = self.cluster_environment.node_rank()
34 changes: 31 additions & 3 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -28,11 +28,13 @@
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7, _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, ReduceOp, sync_ddp_if_available
from pytorch_lightning.utilities.seed import seed_everything
if _TORCH_GREATER_EQUAL_1_8:
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook

log = logging.getLogger(__name__)

@@ -47,16 +49,22 @@ def __init__(
num_nodes: int = 1,
cluster_environment: ClusterEnvironment = None,
sync_batchnorm: bool = False,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
**kwargs: Union[Any, Dict[str, Any]],
):
super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
self.num_nodes = num_nodes
self.sync_batchnorm = sync_batchnorm
self._ddp_kwargs = kwargs
self.dist = LightningDistributed()
self.num_processes = len(parallel_devices)
self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
self.node_rank = 0
self.mp_queue = None
self._ddp_comm_state = ddp_comm_state
self._ddp_comm_hook = ddp_comm_hook
self._ddp_comm_wrapper = ddp_comm_wrapper

def __getstate__(self):
""" Makes this plugin pickleable without destroying the queue in the current process. """
@@ -76,9 +84,12 @@ def distributed_sampler_kwargs(self):
distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
return distributed_sampler_kwargs

@property
def _is_single_process_single_device(self):
return True

def setup(self, model):
os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())

# pass in a state q
smp = mp.get_context("spawn")
self.mp_queue = smp.SimpleQueue()
@@ -181,13 +192,30 @@ def pre_configure_ddp(self):
)
self._ddp_kwargs["find_unused_parameters"] = True

def _register_ddp_hooks(self) -> None:
# currently, DDP communication hooks only work with the NCCL backend and SPSD (single process single device) mode
# https://github.com/pytorch/pytorch/blob/e6779d4357ae94cc9f9fedb83a87eb6126016769/
# torch/nn/parallel/distributed.py#L1040
if (
_TORCH_GREATER_EQUAL_1_8
and self.on_gpu
and self._is_single_process_single_device
):
register_ddp_comm_hook(
model=self._model,
ddp_comm_state=self._ddp_comm_state,
ddp_comm_hook=self._ddp_comm_hook,
ddp_comm_wrapper=self._ddp_comm_wrapper,
)

def configure_ddp(self):
self.pre_configure_ddp()
self._model = DistributedDataParallel(
LightningDistributedModule(self.model),
device_ids=self.determine_ddp_device_ids(),
**self._ddp_kwargs,
)
self._register_ddp_hooks()

def init_ddp_connection(self, global_rank: int, world_size: int) -> None:
# TODO: this code is duplicated in DDP and DDPSpawn, make this a function
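DDPSpawnPlugin exposes the same three arguments as DDPPlugin. A sketch (using the non-spawn plugin to keep the example simple) combining the PowerSGD hook with the fp16 compression wrapper, mirroring the docstring examples in pytorch_lightning/utilities/distributed.py; the wrapper reference requires torch>=1.9, so on 1.8 drop ddp_comm_wrapper and only the hook is registered:

# Sketch: PowerSGD gradient compression via the plugin arguments, optionally
# wrapped with fp16 compression (the wrapper is applied only on torch>=1.9).
from torch.distributed.algorithms.ddp_comm_hooks import (
    default_hooks as default,
    powerSGD_hook as powerSGD,
)

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

plugin = DDPPlugin(
    ddp_comm_state=powerSGD.PowerSGDState(
        process_group=None,
        matrix_approximation_rank=1,
        start_powerSGD_iter=5000,
    ),
    ddp_comm_hook=powerSGD.powerSGD_hook,
    ddp_comm_wrapper=default.fp16_compress_wrapper,
)
trainer = Trainer(gpus=2, accelerator="ddp", plugins=[plugin])
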
2 changes: 2 additions & 0 deletions pytorch_lightning/utilities/__init__.py
@@ -39,6 +39,8 @@
_RPC_AVAILABLE,
_TORCH_GREATER_EQUAL_1_6,
_TORCH_GREATER_EQUAL_1_7,
_TORCH_GREATER_EQUAL_1_8,
_TORCH_GREATER_EQUAL_1_9,
_TORCH_LOWER_EQUAL_1_4,
_TORCH_QUANTIZE_AVAILABLE,
_TORCHTEXT_AVAILABLE,
110 changes: 110 additions & 0 deletions pytorch_lightning/utilities/distributed.py
@@ -17,9 +17,15 @@
import warnings
from functools import wraps
from typing import Any, Optional, Union
from pytorch_lightning.utilities.imports import (
_TORCH_GREATER_EQUAL_1_8,
_TORCH_GREATER_EQUAL_1_9,
)

import torch

from torch.nn.parallel.distributed import DistributedDataParallel

log = logging.getLogger(__name__)

if torch.distributed.is_available():
@@ -197,3 +203,107 @@ def all_gather_ddp_if_available(
with torch.no_grad():
return AllGatherGrad.apply(tensor, group)
return tensor


def register_ddp_comm_hook(
model: DistributedDataParallel,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
) -> None:
"""
Function to register communication hook for DDP model
https://pytorch.org/docs/master/ddp_comm_hooks.html

Args:
model:
DDP model
ddp_comm_state:
state is passed to the hook and can be used to maintain
and update any state information that users would like to
maintain as part of the training process. Examples: error
feedback in gradient compression, peers to communicate with
next in GossipGrad etc.
ddp_comm_hook:
hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future

This callable function is called once the bucket is ready. The
hook can perform whatever processing is needed and return
a Future indicating completion of any async work (ex: allreduce).
If the hook doesn't perform any communication, it can also
just return a completed Future. The Future should hold the
new value of grad bucket's tensors. Once a bucket is ready,
c10d reducer would call this hook and use the tensors returned
by the Future and copy grads to individual parameters.
ddp_comm_wrapper:
communication hook wrapper to support a communication hook such
as FP16 compression as a wrapper, which can be combined with
ddp_comm_hook

.. warning ::
DDP communication hook needs pytorch version at least 1.8.0

.. warning ::
DDP communication wrapper needs pytorch version at least 1.9.0

Example:

from torch.distributed.algorithms.ddp_comm_hooks import (
default_hooks as default,
powerSGD_hook as powerSGD,
)

# fp16_compress_hook for compress gradients
register_ddp_comm_hook(
model=ddp_model,
ddp_comm_hook=default.fp16_compress_hook,
)

# powerSGD_hook
register_ddp_comm_hook(
model=ddp_model,
ddp_comm_state=powerSGD.PowerSGDState(
process_group=None,
matrix_approximation_rank=1,
start_powerSGD_iter=5000,
),
ddp_comm_hook=powerSGD.powerSGD_hook,
)

# fp16_compress_wrapper combined with other communication hook
register_ddp_comm_hook(
model=ddp_model,
ddp_comm_state=powerSGD.PowerSGDState(
process_group=None,
matrix_approximation_rank=1,
start_powerSGD_iter=5000,
),
ddp_comm_hook=powerSGD.powerSGD_hook,
ddp_comm_wrapper=default.fp16_compress_wrapper,
)
"""
if not _TORCH_GREATER_EQUAL_1_8:
Review comment (Contributor): Technically it's also available in 1.7.0, right? But protected with an underscore. Do we want to include it, or were important improvements made between 1.7 and 1.8?

Reply (Contributor Author): I encountered an import issue when trying to import torch.distributed.algorithms on 1.7.0: the conda (3.7, 1.7) tests fail with ModuleNotFoundError: No module named 'torch.distributed.algorithms'. Also, PowerSGD was introduced later.

rank_zero_warn(
"Not registering DDP comm hook. "
"To use communication hooks, please use pytorch>=1.8.0."
)
return
if ddp_comm_hook is None:
return
if ddp_comm_wrapper is not None:
if not _TORCH_GREATER_EQUAL_1_9:
rank_zero_warn(
"Not applying DDP comm wrapper. "
"To use communication wrapper, please use pytorch>=1.9.0."
)
else:
rank_zero_info(
f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
)
ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)

rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
model.register_comm_hook(
state=ddp_comm_state,
hook=ddp_comm_hook,
)
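
To illustrate the hook signature documented above, here is a short sketch of a custom hook (hypothetical user code; ddp_model stands for an already-constructed DistributedDataParallel instance). It delegates the communication to the built-in all-reduce hook rather than touching the GradBucket API directly, which keeps it agnostic to GradBucket method renames across torch releases:

# Custom hook sketch: log each invocation, then delegate to the built-in
# all-reduce hook and return its Future.
import logging

from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default

log = logging.getLogger(__name__)


def logging_allreduce_hook(state, bucket):
    # For the built-in default hooks, `state` is a process group (None means the default group).
    log.debug("DDP comm hook invoked for a gradient bucket")
    return default.allreduce_hook(state, bucket)


register_ddp_comm_hook(model=ddp_model, ddp_comm_hook=logging_allreduce_hook)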
1 change: 1 addition & 0 deletions pytorch_lightning/utilities/imports.py
@@ -69,6 +69,7 @@ def _compare_version(package: str, op, version) -> bool:
_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
_TORCH_GREATER_EQUAL_1_8 = _compare_version("torch", operator.ge, "1.8.0")
_TORCH_GREATER_EQUAL_1_9 = _compare_version("torch", operator.ge, "1.9.0")

_KINETO_AVAILABLE = torch.profiler.kineto_available() if _TORCH_GREATER_EQUAL_1_8 else False
_APEX_AVAILABLE = _module_available("apex.amp")
1 change: 0 additions & 1 deletion tests/plugins/test_custom_plugin.py
@@ -18,7 +18,6 @@


class CustomParallelPlugin(DDPPlugin):

def __init__(self, **kwargs):
super().__init__(**kwargs)
# Set to None so it will be overwritten by the accelerator connector.