diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index 9127d9b62f6854..2a4b4d0d25eb45 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -10,7 +10,7 @@
 from torch import nn
 
 from pytorch_lightning.utilities.apply_func import apply_to_collection
-from pytorch_lightning.utilities.distributed import gather_all_tensors_if_available
+from pytorch_lightning.utilities.distributed import gather_all_tensors_if_available, is_distributed
 from pytorch_lightning.metrics.utils import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum
 
 
@@ -179,11 +179,7 @@ def wrapped_func(*args, **kwargs):
             if self._computed is not None:
                 return self._computed
 
-            if (
-                self._to_sync
-                and torch.distributed.is_available()  # noqa: W503
-                and torch.distributed.is_initialized()  # noqa: W503
-            ):
+            if self._to_sync and is_distributed():
                 self._sync_dist()
 
             self._computed = compute(*args, **kwargs)
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 3b03ac2772021e..426babfa7b2e8a 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -81,6 +81,11 @@ def find_free_network_port() -> int:
     return port
 
 
+def is_distributed() -> bool:
+    return (torch.distributed.is_available() and torch.distributed.is_initialized()) or \
+        (HOROVOD_AVAILABLE and hvd.is_initialized())
+
+
 def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional[Any] = None):
     """
     Function to gather all tensors from several distributed processes onto a list that
@@ -124,9 +129,15 @@ def gather_horovod(result: Union[torch.Tensor], group: Optional[Any] = None):
             "Unset `group`."
         )
 
+    if len(result.shape) == 0:
+        # Convert scalars to single dimension tensors
+        result = result.reshape(1)
+
     # sync and gather all
     hvd.join()
-    return hvd.allgather(result)
+    gathered = hvd.allgather(result)
+    gathered_result = list(gathered.split(1, dim=0))
+    return gathered_result
 
 
 def sync_dist_if_available(
@@ -182,6 +193,8 @@
     if divide_by_world_size:
         result = result / torch.distributed.get_world_size(group)
 
+    return result
+
 
 def sync_horovod(
     result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
diff --git a/tests/models/test_horovod.py b/tests/models/test_horovod.py
index 54e427897f7154..054472da90d0a0 100644
--- a/tests/models/test_horovod.py
+++ b/tests/models/test_horovod.py
@@ -5,13 +5,17 @@
 import subprocess
 import sys
 
+import numpy as np
 import pytest
 import torch
 
+from sklearn.metrics import accuracy_score
+
 import tests.base.develop_pipelines as tpipes
 import tests.base.develop_utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.core.step_result import Result, TrainResult, EvalResult
+from pytorch_lightning.metrics.classification.accuracy import Accuracy
 from tests.base import EvalModelTemplate
 from tests.base.models import BasicGAN
 
@@ -198,6 +202,7 @@ def get_optimizer_params(optimizer):
 
 
 @pytest.mark.parametrize("result_cls", [Result, TrainResult, EvalResult])
+@pytest.mark.skipif(not HOROVOD_AVAILABLE, reason="Horovod is unavailable")
 @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
 def test_result_reduce_horovod(result_cls):
     """Make sure result logging works with Horovod."""
@@ -217,6 +222,50 @@ def hvd_test_fn():
     horovod.run(hvd_test_fn, np=2)
 
 
+@pytest.mark.skipif(not HOROVOD_AVAILABLE, reason="Horovod is unavailable")
+@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
+def test_accuracy_metric_horovod():
+    num_batches = 10
+    batch_size = 16
+    threshold = 0.5
+
+    def sk_metric(preds, target):
+        sk_preds = (preds.view(-1).numpy() >= threshold).astype(np.uint8)
+        sk_target = target.view(-1).numpy()
+        return accuracy_score(y_true=sk_target, y_pred=sk_preds)
+
+    preds = torch.rand(num_batches, batch_size)
+    target = torch.randint(high=2, size=(num_batches, batch_size))
+
+    def _compute_batch():
+        import horovod.torch as hvd
+        hvd.init()
+
+        metric = Accuracy(compute_on_step=True,
+                          dist_sync_on_step=True,
+                          threshold=threshold)
+
+        for i in range(hvd.rank(), num_batches, hvd.size()):
+            batch_result = metric(preds[i], target[i])
+            if hvd.rank() == 0:
+                dist_preds = torch.stack([preds[i + r] for r in range(hvd.size())])
+                dist_target = torch.stack([target[i + r] for r in range(hvd.size())])
+                sk_batch_result = sk_metric(dist_preds, dist_target)
+                assert np.allclose(batch_result.numpy(), sk_batch_result)
+
+        # check on all batches on all ranks
+        result = metric.compute()
+        assert isinstance(result, torch.Tensor)
+
+        total_preds = torch.stack([preds[i] for i in range(num_batches)])
+        total_target = torch.stack([target[i] for i in range(num_batches)])
+        sk_result = sk_metric(total_preds, total_target)
+
+        assert np.allclose(result.numpy(), sk_result)
+
+    horovod.run(_compute_batch, np=2)
+
+
 # @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
 # def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir):
 #     hparams = EvalModelTemplate.get_default_hparams()
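
As a usage note, here is a minimal sketch of what the patched Horovod gather path does, runnable outside the test suite. Assumptions: a working `horovod.torch` install and a launch via `horovodrun -np 2 python sketch.py`; `gather_sketch` is a hypothetical stand-in for the patched `gather_horovod`, not part of the library API.

```python
import torch
import horovod.torch as hvd


def gather_sketch(result: torch.Tensor):
    # promote 0-dim scalars to shape (1,) so allgather has a dim 0 to concatenate on
    if len(result.shape) == 0:
        result = result.reshape(1)
    hvd.join()  # wait for all ranks before gathering
    gathered = hvd.allgather(result)  # ranks' tensors concatenated along dim 0
    # split back into a list of single-row tensors, one per dim-0 slice
    return list(gathered.split(1, dim=0))


if __name__ == "__main__":
    hvd.init()
    out = gather_sketch(torch.tensor(float(hvd.rank())))
    # with np=2, every rank sees [tensor([0.]), tensor([1.])]
    assert len(out) == hvd.size()
```

Returning a list of tensors rather than the raw concatenated tensor matches what the DDP-based `gather_all_tensors_if_available` path produces, which appears to be why `Metric._sync_dist` can stay backend-agnostic. Note that for inputs whose first dimension is larger than 1, `split(1, dim=0)` yields one entry per dim-0 slice, not one per rank.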