diff --git a/CHANGELOG.md b/CHANGELOG.md index 42d78b40b0332..1f4defbd5cc30 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `fsspec` to tuner ([#4458](https://github.com/PyTorchLightning/pytorch-lightning/pull/4458)) +- Added metrics aggregation in Horovod and fixed early stopping ([#3775](https://github.com/PyTorchLightning/pytorch-lightning/pull/3775)) + + ### Changed diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index e69addf234a36..408c979d1ff2e 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -14,7 +14,7 @@ import os import math from enum import Enum -from typing import Any, Optional +from typing import Any, Optional, Union import torch @@ -30,6 +30,12 @@ except ImportError: amp = None +if torch.distributed.is_available(): + from torch.distributed import ReduceOp +else: + class ReduceOp: + SUM = None + EPSILON = 1e-6 EPSILON_FP16 = 1e-5 @@ -209,6 +215,22 @@ def init_ddp_connection( torch_backend, rank=global_rank, world_size=world_size ) + def sync_tensor(self, + tensor: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: + """ + Function to reduce a tensor from several distributed processes to one aggregated tensor. + Args: + tensor: the tensor to sync and reduce + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to sum. + Can also be a string of 'avg', 'mean' to calculate the mean during reduction. + Return: + reduced value + """ + raise NotImplementedError() + def __getstate__(self): return { 'trainer': self.trainer, diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index b9f01b5ddc167..b127fdd40c934 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -18,17 +18,18 @@ import sys from os.path import abspath from time import sleep -from typing import Optional, List +from typing import Any, Optional, List, Union import numpy as np from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.distributed import find_free_network_port from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.distributed import sync_ddp_if_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.seed import seed_everything from torch.nn.parallel import DistributedDataParallel @@ -298,3 +299,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) return model + + def sync_tensor(self, + tensor: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: + return sync_ddp_if_available(tensor, group, reduce_op) diff --git a/pytorch_lightning/accelerators/ddp_cpu_slurm_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_slurm_accelerator.py index 2aad005a07847..c80e8a4ec355c 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_slurm_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_slurm_accelerator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License import os -from typing import List, Optional +from typing import Any, List, Optional, Union import torch import torch.distributed as torch_distrib @@ -20,10 +20,11 @@ from torch.nn.parallel import DistributedDataParallel from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.distributed import sync_ddp_if_available from pytorch_lightning.distributed.dist import LightningDistributed @@ -199,3 +200,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) return model + + def sync_tensor(self, + tensor: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: + return sync_ddp_if_available(tensor, group, reduce_op) diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index f1813361c5eec..64e326b7ee0fc 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License import os -from typing import List, Optional +from typing import Any, List, Optional, Union import torch import torch.distributed as torch_distrib @@ -21,11 +21,11 @@ from torch.nn.parallel import DistributedDataParallel from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn -from pytorch_lightning.utilities.distributed import find_free_network_port +from pytorch_lightning.utilities.distributed import find_free_network_port, sync_ddp_if_available from pytorch_lightning.distributed.dist import LightningDistributed try: @@ -229,3 +229,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) return model + + def sync_tensor(self, + tensor: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: + return sync_ddp_if_available(tensor, group, reduce_op) diff --git a/pytorch_lightning/accelerators/ddp_cpu_torchelastic_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_torchelastic_accelerator.py index 6b27e7da330ea..a90d7750eaeea 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_torchelastic_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_torchelastic_accelerator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License import os -from typing import List, Optional +from typing import Any, List, Optional, Union import torch import torch.distributed as torch_distrib @@ -20,11 +20,12 @@ from torch.nn.parallel import DistributedDataParallel from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.distributed import sync_ddp_if_available try: from hydra.utils import to_absolute_path, get_original_cwd @@ -198,3 +199,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) return model + + def sync_tensor(self, + tensor: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: + return sync_ddp_if_available(tensor, group, reduce_op) diff --git a/pytorch_lightning/accelerators/ddp_slurm_accelerator.py b/pytorch_lightning/accelerators/ddp_slurm_accelerator.py index 8a6326d3d5cb8..4960445edd27d 100644 --- a/pytorch_lightning/accelerators/ddp_slurm_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_slurm_accelerator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License import os -from typing import List +from typing import Any, List, Optional, Union import torch import torch.distributed as torch_distrib @@ -20,11 +20,11 @@ from torch.nn.parallel import DistributedDataParallel from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.utilities import AMPType -from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available from pytorch_lightning.utilities.seed import seed_everything try: @@ -205,3 +205,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) return model + + def sync_tensor(self, + tensor: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: + return sync_ddp_if_available(tensor, group, reduce_op) diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index b204494773362..2e0bac46c4c20 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -13,7 +13,7 @@ # limitations under the License import os import re -from typing import List, Optional +from typing import Any, List, Optional, Union import torch import torch.multiprocessing as mp @@ -22,11 +22,12 @@ from torch.nn.parallel import DistributedDataParallel from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.cloud_io import atomic_save, load as pl_load from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, find_free_network_port +from pytorch_lightning.utilities.distributed import sync_ddp_if_available from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.distributed.dist import LightningDistributed @@ -254,3 +255,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) return model + + def sync_tensor(self, + tensor: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: + return sync_ddp_if_available(tensor, group, reduce_op) diff --git a/pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py b/pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py index 8a9e6ac77e574..e54ad905de80e 100644 --- a/pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_torchelastic_accelerator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License import os -from typing import List, Optional +from typing import Any, List, Optional, Union import torch import torch.distributed as torch_distrib @@ -20,11 +20,12 @@ from torch.nn.parallel import DistributedDataParallel from pytorch_lightning import _logger as log -from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.distributed import rank_zero_only +from pytorch_lightning.utilities.distributed import sync_ddp_if_available try: @@ -201,3 +202,9 @@ def configure_sync_batchnorm(self, model: LightningModule) -> LightningModule: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=None) return model + + def sync_tensor(self, + tensor: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: + return sync_ddp_if_available(tensor, group, reduce_op) diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py b/pytorch_lightning/accelerators/horovod_accelerator.py index 91a5400999f6e..e5314a983f9db 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import ExitStack -from typing import Optional +from typing import Any, Optional, Union import torch from torch.optim.lr_scheduler import _LRScheduler -from pytorch_lightning.accelerators.accelerator import Accelerator +from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp from pytorch_lightning.utilities import AMPType from pytorch_lightning.utilities.distributed import rank_zero_only @@ -161,3 +161,41 @@ def barrier(self, name: Optional[str] = None): def broadcast(self, obj, src=0): obj = hvd.broadcast_object(obj, src) return obj + + def gather_all_tensors(self, result: Union[torch.Tensor], group: Optional[Any] = None): + if group is not None: + raise ValueError( + "Horovod does not support allgather using a subcommunicator at this time. " + "Unset `group`." + ) + + if len(result.shape) == 0: + # Convert scalars to single dimension tensors + result = result.reshape(1) + + # sync and gather all + hvd.join() + gathered = hvd.allgather(result) + gathered_result = list(gathered.split(1, dim=0)) + return gathered_result + + def sync_tensor(self, + tensor: Union[torch.Tensor], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor: + if group is not None: + raise ValueError( + "Horovod does not support allreduce using a subcommunicator at this time. " + "Unset `group`." + ) + + if reduce_op is None or reduce_op == "sum": + reduce_op = hvd.Sum + elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"): + reduce_op = hvd.Average + else: + raise ValueError(f"unrecognized `reduce_op`: {reduce_op}") + + # sync all processes before reduction + hvd.join() + return hvd.allreduce(tensor, op=reduce_op) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 78b9c7025711a..05eb8ee86be63 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -258,6 +258,8 @@ def log( raise MisconfigurationException( f"Logged key: {name} should not contain information about dataloader_idx.") + accelerator = self.trainer.accelerator_backend + self._results.log( name, value, @@ -272,6 +274,7 @@ def log( sync_dist, sync_dist_op, sync_dist_group, + accelerator.sync_tensor, self._current_dataloader_idx, ) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 059c724aa75a9..0eca72095e0e0 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -124,15 +124,17 @@ def log( sync_dist: bool = False, sync_dist_op: Union[Any, str] = 'mean', sync_dist_group: Optional[Any] = None, + sync_fn: Callable = None, dataloader_idx: Optional[int] = None, ): # no metrics should be logged with graphs if not enable_graph and isinstance(value, torch.Tensor): value = value.detach() - # sync across ddp + # sync across workers when using distributed training + sync_fn = sync_fn or sync_ddp_if_available if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)): - value = sync_ddp_if_available(value, group=sync_dist_group, reduce_op=sync_dist_op) + value = sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op) if 'meta' not in self: self.__setitem__('meta', {}) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index c7255c6e4497e..0f01fb9813407 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -50,6 +50,9 @@ class Accuracy(Metric): before returning the value at the step. default: False process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. default: None Example: @@ -67,11 +70,13 @@ def __init__( compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, ): super().__init__( compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, + dist_sync_fn=dist_sync_fn, ) self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum") diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py index b716817427230..1a568bab37209 100644 --- a/pytorch_lightning/metrics/metric.py +++ b/pytorch_lightning/metrics/metric.py @@ -24,7 +24,7 @@ from torch import nn from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.distributed import gather_all_tensors_if_available +from pytorch_lightning.utilities.distributed import gather_all_tensors from pytorch_lightning.metrics.utils import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum @@ -53,21 +53,26 @@ class Metric(nn.Module, ABC): Forward only calls ``update()`` and returns None if this is set to False. default: True dist_sync_on_step: Synchronize metric state across processes at each ``forward()`` - before returning the value at the step. default: False + before returning the value at the step. process_group: Specify the process group on which synchronization is called. default: None (which selects the entire world) + dist_sync_fn: + Callback that performs the allgather operation on the metric state. When `None`, DDP + will be used to perform the allgather. default: None """ def __init__( self, compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, ): super().__init__() self.dist_sync_on_step = dist_sync_on_step self.compute_on_step = compute_on_step self.process_group = process_group + self.dist_sync_fn = dist_sync_fn self._to_sync = True self.update = self._wrap_update(self.update) @@ -174,12 +179,12 @@ def forward(self, *args, **kwargs): return self._forward_cache - def _sync_dist(self): + def _sync_dist(self, dist_sync_fn=gather_all_tensors): input_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()} output_dict = apply_to_collection( input_dict, torch.Tensor, - gather_all_tensors_if_available, + dist_sync_fn, group=self.process_group, ) @@ -208,12 +213,15 @@ def wrapped_func(*args, **kwargs): if self._computed is not None: return self._computed - if ( - self._to_sync - and torch.distributed.is_available() # noqa: W503 - and torch.distributed.is_initialized() # noqa: W503 - ): - self._sync_dist() + dist_sync_fn = self.dist_sync_fn + if (dist_sync_fn is None + and torch.distributed.is_available() + and torch.distributed.is_initialized()): + # User provided a bool, so we assume DDP if available + dist_sync_fn = gather_all_tensors + + if self._to_sync and dist_sync_fn is not None: + self._sync_dist(dist_sync_fn) self._computed = compute(*args, **kwargs) self.reset() diff --git a/pytorch_lightning/metrics/regression/explained_variance.py b/pytorch_lightning/metrics/regression/explained_variance.py index 79fc8b4c4e183..f59ce0b67de62 100644 --- a/pytorch_lightning/metrics/regression/explained_variance.py +++ b/pytorch_lightning/metrics/regression/explained_variance.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch -from typing import Any, Optional +from typing import Any, Callable, Optional from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_warn @@ -74,11 +74,13 @@ def __init__( compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, ): super().__init__( compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, + dist_sync_fn=dist_sync_fn, ) allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted') if multioutput not in allowed_multioutput: diff --git a/pytorch_lightning/metrics/regression/mean_absolute_error.py b/pytorch_lightning/metrics/regression/mean_absolute_error.py index 89cb56d431ad4..ba6d2c6d79a08 100644 --- a/pytorch_lightning/metrics/regression/mean_absolute_error.py +++ b/pytorch_lightning/metrics/regression/mean_absolute_error.py @@ -49,11 +49,13 @@ def __init__( compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, ): super().__init__( compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, + dist_sync_fn=dist_sync_fn, ) self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum") diff --git a/pytorch_lightning/metrics/regression/mean_squared_error.py b/pytorch_lightning/metrics/regression/mean_squared_error.py index 87c1fddf2674c..6da6d55d5dd1c 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_error.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch -from typing import Any, Optional +from typing import Any, Callable, Optional from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.metrics.functional.mean_squared_error import ( @@ -50,11 +50,13 @@ def __init__( compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, ): super().__init__( compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, + dist_sync_fn=dist_sync_fn, ) self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum") diff --git a/pytorch_lightning/metrics/regression/mean_squared_log_error.py b/pytorch_lightning/metrics/regression/mean_squared_log_error.py index 256fac20365af..696ad01ca829d 100644 --- a/pytorch_lightning/metrics/regression/mean_squared_log_error.py +++ b/pytorch_lightning/metrics/regression/mean_squared_log_error.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch -from typing import Any, Optional +from typing import Any, Callable, Optional from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.metrics.functional.mean_squared_log_error import ( @@ -50,11 +50,13 @@ def __init__( compute_on_step: bool = True, dist_sync_on_step: bool = False, process_group: Optional[Any] = None, + dist_sync_fn: Callable = None, ): super().__init__( compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group, + dist_sync_fn=dist_sync_fn, ) self.add_state("sum_squared_log_error", default=torch.tensor(0.0), dist_reduce_fx="sum") diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index a29fd3e5a1059..98d322ce0a3a2 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -73,7 +73,7 @@ def find_free_network_port() -> int: return port -def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional[Any] = None): +def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None): """ Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes @@ -85,26 +85,41 @@ def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional Return: gathered_result: list with size equal to the process group where gathered_result[i] corresponds to result tensor from process i - """ - if torch.distributed.is_available() and torch.distributed.is_initialized(): - if group is None: - group = torch.distributed.group.WORLD + if group is None: + group = torch.distributed.group.WORLD - world_size = torch.distributed.get_world_size(group) + world_size = torch.distributed.get_world_size(group) - gathered_result = [torch.zeros_like(result) for _ in range(world_size)] + gathered_result = [torch.zeros_like(result) for _ in range(world_size)] - # sync and broadcast all - torch.distributed.barrier(group=group) - torch.distributed.all_gather(gathered_result, result, group) + # sync and broadcast all + torch.distributed.barrier(group=group) + torch.distributed.all_gather(gathered_result, result, group) - result = gathered_result - return result + return gathered_result def sync_ddp_if_available( result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None +) -> torch.Tensor: + """ + Function to reduce a tensor across worker processes during distributed training + Args: + result: the value to sync and reduce (typically tensor or number) + group: the process group to gather results from. Defaults to all processes (world) + reduce_op: the reduction operation. Defaults to sum. + Can also be a string of 'avg', 'mean' to calculate the mean during reduction. + Return: + reduced value + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + return sync_ddp(result, group=group, reduce_op=reduce_op) + return result + + +def sync_ddp( + result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None ) -> torch.Tensor: """ Function to reduce the tensors from several ddp processes to one master process @@ -118,24 +133,22 @@ def sync_ddp_if_available( Return: reduced value """ + divide_by_world_size = False - if torch.distributed.is_available() and torch.distributed.is_initialized(): - divide_by_world_size = False - - if group is None: - group = torch.distributed.group.WORLD + if group is None: + group = torch.distributed.group.WORLD - if reduce_op is None: - reduce_op = torch.distributed.ReduceOp.SUM - elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"): - reduce_op = torch.distributed.ReduceOp.SUM - divide_by_world_size = True + if reduce_op is None: + reduce_op = torch.distributed.ReduceOp.SUM + elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"): + reduce_op = torch.distributed.ReduceOp.SUM + divide_by_world_size = True - # sync all processes before reduction - torch.distributed.barrier(group=group) - torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False) + # sync all processes before reduction + torch.distributed.barrier(group=group) + torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False) - if divide_by_world_size: - result = result / torch.distributed.get_world_size(group) + if divide_by_world_size: + result = result / torch.distributed.get_world_size(group) return result diff --git a/requirements/extra.txt b/requirements/extra.txt index dbd5f7515109e..be21317a1d826 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -1,8 +1,7 @@ # extended list of package dependencies to reach full functionality matplotlib>=3.1.1 -# no need to install with [pytorch] as pytorch is already installed and torchvision is required only for Horovod examples -horovod>=0.20.1 # v0.20.0 has problem with building the wheel/installation +horovod>=0.20.2 # no need to install with [pytorch] as pytorch is already installed omegaconf>=2.0.0 # scipy>=0.13.3 scikit-learn>=0.22.2 diff --git a/tests/README.md b/tests/README.md index 7fd3c90c0241e..8ef006c4d879a 100644 --- a/tests/README.md +++ b/tests/README.md @@ -30,7 +30,7 @@ To test models that require GPU make sure to run the above command on a GPU mach The GPU machine must have: 1. At least 2 GPUs. 2. [NVIDIA-apex](https://github.com/NVIDIA/apex#linux) installed. -3. [Horovod with NCCL](https://horovod.readthedocs.io/en/stable/gpus_include.html) support: `HOROVOD_GPU_ALLREDUCE=NCCL HOROVOD_GPU_BROADCAST=NCCL pip install horovod` +3. [Horovod with NCCL](https://horovod.readthedocs.io/en/stable/gpus_include.html) support: `HOROVOD_GPU_OPERATIONS=NCCL pip install horovod` ## Running Coverage diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py index d09d9387ea485..d0ae17d8fee5d 100644 --- a/tests/models/test_horovod.py +++ b/tests/models/test_horovod.py @@ -17,19 +17,25 @@ import shlex import subprocess import sys -from unittest.mock import patch +import numpy as np import pytest import torch +from sklearn.metrics import accuracy_score + import tests.base.develop_pipelines as tpipes import tests.base.develop_utils as tutils from pytorch_lightning import Trainer +from pytorch_lightning.accelerators.horovod_accelerator import HorovodAccelerator +from pytorch_lightning.core.step_result import Result, TrainResult, EvalResult +from pytorch_lightning.metrics.classification.accuracy import Accuracy from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE from tests.base import EvalModelTemplate from tests.base.models import BasicGAN try: + import horovod from horovod.common.util import nccl_built except ImportError: HOROVOD_AVAILABLE = False @@ -235,6 +241,111 @@ def get_optimizer_params(optimizer): assert get_model_params(model.generator) == get_optimizer_params(trainer.optimizers[0]) assert get_model_params(model.discriminator) == get_optimizer_params(trainer.optimizers[1]) + +@pytest.mark.skipif(not HOROVOD_AVAILABLE, reason="Horovod is unavailable") +@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") +def test_result_reduce_horovod(tmpdir): + """Make sure result logging works with Horovod. + + This test mirrors tests/core/test_results.py::_ddp_test_fn + """ + tutils.reset_seed() + tutils.set_random_master_port() + + def hvd_test_fn(): + path_here = os.path.abspath(os.path.dirname(__file__)) + path_root = os.path.abspath(os.path.join(path_here, '..', '..')) + sys.path.insert(0, os.path.abspath(path_root)) + + from tests.base.boring_model import BoringModel + + import horovod.torch as hvd + + class TestModel(BoringModel): + def training_step(self, batch, batch_idx): + self.training_step_called = True + + tensor = torch.tensor([1.0]) + self.log("test_tensor", tensor, sync_dist=True, sync_dist_op='sum', + on_step=True, on_epoch=True) + + res = self._results + + # Check that `tensor` is summed across all ranks automatically + assert res["test_tensor"].item() == hvd.size(), \ + "Result-Log does not work properly with Horovod and Tensors" + + def training_epoch_end(self, outputs) -> None: + assert len(outputs) == 0 + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + log_every_n_steps=1, + weights_summary=None, + ) + + trainer.fit(model) + + horovod.run(hvd_test_fn, np=2) + + +@pytest.mark.skipif(not HOROVOD_AVAILABLE, reason="Horovod is unavailable") +@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") +def test_accuracy_metric_horovod(): + num_batches = 10 + batch_size = 16 + threshold = 0.5 + + def sk_metric(preds, target): + sk_preds = (preds.view(-1).numpy() >= threshold).astype(np.uint8) + sk_target = target.view(-1).numpy() + return accuracy_score(y_true=sk_target, y_pred=sk_preds) + + preds = torch.rand(num_batches, batch_size) + target = torch.randint(high=2, size=(num_batches, batch_size)) + + def _compute_batch(): + import horovod.torch as hvd + + trainer = Trainer( + fast_dev_run=True, + distributed_backend='horovod', + ) + + accelerator_backend = trainer.accelerator_connector.select_accelerator() + assert isinstance(accelerator_backend, HorovodAccelerator) + + metric = Accuracy(compute_on_step=True, + dist_sync_on_step=True, + dist_sync_fn=accelerator_backend.gather_all_tensors, + threshold=threshold) + + for i in range(hvd.rank(), num_batches, hvd.size()): + batch_result = metric(preds[i], target[i]) + if hvd.rank() == 0: + dist_preds = torch.stack([preds[i + r] for r in range(hvd.size())]) + dist_target = torch.stack([target[i + r] for r in range(hvd.size())]) + sk_batch_result = sk_metric(dist_preds, dist_target) + assert np.allclose(batch_result.numpy(), sk_batch_result) + + # check on all batches on all ranks + result = metric.compute() + assert isinstance(result, torch.Tensor) + + total_preds = torch.stack([preds[i] for i in range(num_batches)]) + total_target = torch.stack([target[i] for i in range(num_batches)]) + sk_result = sk_metric(total_preds, total_target) + + assert np.allclose(result.numpy(), sk_result) + + horovod.run(_compute_batch, np=2) + # @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows") # def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir): # hparams = EvalModelTemplate.get_default_hparams()