Add possibility to provide distributed_available_fn (#1301)
justusschock authored Oct 31, 2022
1 parent a42a25a commit 5bbad47
Showing 3 changed files with 27 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added `TotalVariation` to image package ([#978](https://github.com/Lightning-AI/metrics/pull/978))

- Added option to pass `distributed_available_fn` to metrics, allowing availability checks for custom communication backends and making `dist_sync_fn` actually useful ([#1301](https://github.com/Lightning-AI/metrics/pull/1301))


### Changed

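The CHANGELOG entry above is easiest to read alongside a usage example. The following is a minimal sketch of the new option, run in a single process; `my_backend_is_up` and `my_all_gather` are hypothetical stand-ins for hooks into a custom communication backend and are not part of the library.

```python
import torch
from torchmetrics import Accuracy


def my_backend_is_up() -> bool:
    # Report whether the custom distributed backend is initialized.
    # Stubbed to True here; a real hook would query the backend.
    return True


def my_all_gather(tensor: torch.Tensor, group=None) -> list:
    # Gather a metric state tensor from every process.
    # Single-process stand-in: wrap the local tensor in a list.
    return [tensor]


# The new keyword argument makes the metric consult my_backend_is_up()
# before it calls my_all_gather() when compute() synchronizes states.
metric = Accuracy(
    dist_sync_fn=my_all_gather,
    distributed_available_fn=my_backend_is_up,
)

metric.update(torch.tensor([[1], [1], [1], [1]]), torch.tensor([[1], [1], [1], [1]]))
print(metric.compute())
```

Because `my_backend_is_up` returns `True`, `compute()` routes each metric state through `my_all_gather`, which is the same behaviour exercised by the new test at the bottom of this commit.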
11 changes: 9 additions & 2 deletions src/torchmetrics/metric.py
@@ -69,6 +69,8 @@ class Metric(Module, ABC):
- process_group: The process group on which the synchronization is called. Default is the world.
- dist_sync_fn: function that performs the allgather operation on the metric state. Default is a
custom implementation that calls ``torch.distributed.all_gather`` internally.
- distributed_available_fn: function that checks if the distributed backend is available.
Defaults to a check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
- sync_on_compute: If metric state should synchronize when ``compute`` is called. Default is ``True``.
"""

@@ -110,6 +112,8 @@ def __init__(
f"Expected keyword argument `dist_sync_fn` to be an callable function but got {self.dist_sync_fn}"
)

self.distributed_available_fn = kwargs.pop("distributed_available_fn", jit_distributed_available)

self.sync_on_compute = kwargs.pop("sync_on_compute", True)
if not isinstance(self.sync_on_compute, bool):
raise ValueError(
@@ -421,7 +425,7 @@ def sync(
dist_sync_fn: Optional[Callable] = None,
process_group: Optional[Any] = None,
should_sync: bool = True,
- distributed_available: Optional[Callable] = jit_distributed_available,
+ distributed_available: Optional[Callable] = None,
) -> None:
"""Sync function for manually controlling when metrics states should be synced across processes.
@@ -437,6 +441,9 @@
if self._is_synced and should_sync:
raise TorchMetricsUserError("The Metric has already been synced.")

if distributed_available is None and self.distributed_available_fn is not None:
distributed_available = self.distributed_available_fn

is_distributed = distributed_available() if callable(distributed_available) else None

if not should_sync or not is_distributed:
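Spelled out for readability, the two added lines plus the existing guard amount to the following resolution order. This is an editorial sketch rather than library source, and `resolve_is_distributed` is an illustrative name.

```python
from typing import Callable, Optional


def resolve_is_distributed(
    explicit_check: Optional[Callable],   # the distributed_available argument of sync()
    instance_check: Optional[Callable],   # self.distributed_available_fn set in __init__
) -> Optional[bool]:
    # An explicit argument to sync() takes precedence; otherwise fall back
    # to the check stored on the metric at construction time.
    check = explicit_check if explicit_check is not None else instance_check
    # Mirror the existing guard: only invoke the check if it is callable.
    return check() if callable(check) else None
```

When the resolved check returns a falsy value, `sync()` returns early and the states stay local; callers that pass nothing at all still get the old behaviour, because the constructor defaults `distributed_available_fn` to `jit_distributed_available`.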
@@ -481,7 +488,7 @@ def sync_context(
process_group: Optional[Any] = None,
should_sync: bool = True,
should_unsync: bool = True,
- distributed_available: Optional[Callable] = jit_distributed_available,
+ distributed_available: Optional[Callable] = None,
) -> Generator:
"""Context manager to synchronize the states between processes when running in a distributed setting and
restore the local cache states after yielding.
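`sync_context` gets the same treatment: its `distributed_available` default moves from `jit_distributed_available` to `None`, so the context manager now also falls back to the instance-level check. A small sketch, reusing the hypothetical `metric` and helpers from the first example; `dist_sync_fn` is forwarded explicitly here on the assumption that `sync_context` takes its own `dist_sync_fn` argument rather than reusing the constructor value.

```python
# Override the availability check for this call only: the lambda reports the
# backend as unavailable, so sync() is a no-op and the states stay local.
with metric.sync_context(should_sync=True, distributed_available=lambda: False):
    pass

# Leave distributed_available as None: the context manager falls back to the
# distributed_available_fn given to the constructor (my_backend_is_up above),
# and the states are gathered via my_all_gather inside the block.
with metric.sync_context(dist_sync_fn=my_all_gather, should_sync=True):
    pass
```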
17 changes: 16 additions & 1 deletion tests/unittests/bases/test_metric.py
@@ -14,6 +14,7 @@
import os
import pickle
from collections import OrderedDict
from unittest.mock import Mock

import cloudpickle
import numpy as np
@@ -23,7 +24,7 @@
from torch import Tensor, tensor
from torch.nn import Module

- from torchmetrics import PearsonCorrCoef
+ from torchmetrics import Accuracy, PearsonCorrCoef
from unittests.helpers import seed_all
from unittests.helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
from unittests.helpers.utilities import no_warning_call
@@ -449,3 +450,17 @@ def forward(self, *args, **kwargs):
match="Torchmetrics v0.9 introduced a new argument class property called.*",
):
UnsetProperty()


def test_custom_availability_check_and_sync_fn():
dummy_availability_check = Mock(return_value=True)
dummy_dist_sync_fn = Mock(wraps=lambda x, group: [x])
acc = Accuracy(dist_sync_fn=dummy_dist_sync_fn, distributed_available_fn=dummy_availability_check)

acc.update(torch.tensor([[1], [1], [1], [1]]), torch.tensor([[1], [1], [1], [1]]))
dummy_dist_sync_fn.assert_not_called()
dummy_availability_check.assert_not_called()

acc.compute()
dummy_availability_check.assert_called_once()
assert dummy_dist_sync_fn.call_count == 4 # tp, fp, tn, fn
