diff --git a/ignite/metrics/precision.py b/ignite/metrics/precision.py
index 912ff813cd43..50142b7f78a3 100644
--- a/ignite/metrics/precision.py
+++ b/ignite/metrics/precision.py
@@ -20,14 +20,6 @@ def __init__(
         is_multilabel: bool = False,
         device: Union[str, torch.device] = torch.device("cpu"),
     ):
-        if idist.get_world_size() > 1:
-            if (not average) and is_multilabel:
-                warnings.warn(
-                    "Precision/Recall metrics do not work in distributed setting when average=False "
-                    "and is_multilabel=True. Results are not reduced across computing devices. Computed result "
-                    "corresponds to the local rank's (single process) result.",
-                    RuntimeWarning,
-                )
 
         self._average = average
         self.eps = 1e-20
@@ -53,12 +45,14 @@ def compute(self) -> Union[torch.Tensor, float]:
             raise NotComputableError(
                 f"{self.__class__.__name__} must have at least one example before it can be computed."
             )
-
-        if not (self._type == "multilabel" and not self._average):
-            if not self._is_reduced:
+        if not self._is_reduced:
+            if not (self._type == "multilabel" and not self._average):
                 self._true_positives = idist.all_reduce(self._true_positives)  # type: ignore[assignment]
                 self._positives = idist.all_reduce(self._positives)  # type: ignore[assignment]
-            self._is_reduced = True  # type: bool
+            else:
+                self._true_positives = cast(torch.Tensor, idist.all_gather(self._true_positives))
+                self._positives = cast(torch.Tensor, idist.all_gather(self._positives))
+            self._is_reduced = True  # type: bool
 
         result = self._true_positives / (self._positives + self.eps)
 
@@ -107,11 +101,6 @@ def thresholded_output_transform(output):
         as tensors before computing a metric. This can potentially lead to a memory error if the input data is
         larger than available RAM.
 
-    .. warning::
-
-        In multilabel cases, if average is False, current implementation does not work with distributed computations.
-        Results are not reduced across the GPUs. Computed result corresponds to the local rank's (single GPU) result.
-
     Args:
         output_transform (callable, optional): a callable that is used to transform the
diff --git a/ignite/metrics/recall.py b/ignite/metrics/recall.py
index 69e16155b6e6..a11cb7d583bf 100644
--- a/ignite/metrics/recall.py
+++ b/ignite/metrics/recall.py
@@ -48,11 +48,6 @@ def thresholded_output_transform(output):
         as tensors before computing a metric. This can potentially lead to a memory error if the input data is
         larger than available RAM.
 
-    .. warning::
-
-        In multilabel cases, if average is False, current implementation does not work with distributed computations.
-        Results are not reduced across the GPUs. Computed result corresponds to the local rank's (single GPU) result.
-
     Args:
         output_transform (callable, optional): a callable that is used to transform the
diff --git a/tests/ignite/metrics/test_precision.py b/tests/ignite/metrics/test_precision.py
index 3c5e4af784a5..a4bc56cdd92b 100644
--- a/tests/ignite/metrics/test_precision.py
+++ b/tests/ignite/metrics/test_precision.py
@@ -792,7 +792,7 @@ def update(engine, i):
 
         engine = Engine(update)
 
-        pr = Precision(average=average, is_multilabel=True)
+        pr = Precision(average=average, is_multilabel=True, device=metric_device)
         pr.attach(engine, "pr")
 
         data = list(range(n_iters))
@@ -808,13 +808,13 @@ def update(engine, i):
         else:
             assert res == res2
 
+        np_y_preds = to_numpy_multilabel(y_preds)
+        np_y_true = to_numpy_multilabel(y_true)
+        assert pr._type == "multilabel"
+        res = res if average else res.mean().item()
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", category=UndefinedMetricWarning)
-            true_res = precision_score(
-                to_numpy_multilabel(y_true), to_numpy_multilabel(y_preds), average="samples" if average else None
-            )
-
-            assert pytest.approx(res) == true_res
+            assert precision_score(np_y_true, np_y_preds, average="samples") == pytest.approx(res)
 
     metric_devices = ["cpu"]
     if device.type != "xla":
@@ -823,22 +823,16 @@
     for metric_device in metric_devices:
         _test(average=True, n_epochs=1, metric_device=metric_device)
         _test(average=True, n_epochs=2, metric_device=metric_device)
+        _test(average=False, n_epochs=1, metric_device=metric_device)
+        _test(average=False, n_epochs=2, metric_device=metric_device)
 
-    if idist.get_world_size() > 1:
-        with pytest.warns(
-            RuntimeWarning,
-            match="Precision/Recall metrics do not work in distributed setting when "
-            "average=False and is_multilabel=True",
-        ):
-            pr = Precision(average=False, is_multilabel=True)
-
-        y_pred = torch.randint(0, 2, size=(4, 3, 6, 8))
-        y = torch.randint(0, 2, size=(4, 3, 6, 8)).long()
-        pr.update((y_pred, y))
-        pr_compute1 = pr.compute()
-        pr_compute2 = pr.compute()
-        assert len(pr_compute1) == 4 * 6 * 8
-        assert (pr_compute1 == pr_compute2).all()
+    pr1 = Precision(is_multilabel=True, average=True)
+    pr2 = Precision(is_multilabel=True, average=False)
+    y_pred = torch.randint(0, 2, size=(10, 4, 20, 23))
+    y = torch.randint(0, 2, size=(10, 4, 20, 23)).long()
+    pr1.update((y_pred, y))
+    pr2.update((y_pred, y))
+    assert pr1.compute() == pytest.approx(pr2.compute().mean().item())
 
 
 def _test_distrib_accumulator_device(device):
diff --git a/tests/ignite/metrics/test_recall.py b/tests/ignite/metrics/test_recall.py
index 537cea38fb79..70c1ba2dd23b 100644
--- a/tests/ignite/metrics/test_recall.py
+++ b/tests/ignite/metrics/test_recall.py
@@ -808,13 +808,13 @@ def update(engine, i):
         else:
             assert res == res2
 
+        np_y_preds = to_numpy_multilabel(y_preds)
+        np_y_true = to_numpy_multilabel(y_true)
+        assert re._type == "multilabel"
+        res = res if average else res.mean().item()
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", category=UndefinedMetricWarning)
-            true_res = recall_score(
-                to_numpy_multilabel(y_true), to_numpy_multilabel(y_preds), average="samples" if average else None
-            )
-
-            assert pytest.approx(res) == true_res
+            assert recall_score(np_y_true, np_y_preds, average="samples") == pytest.approx(res)
 
     metric_devices = ["cpu"]
     if device.type != "xla":
@@ -823,22 +823,16 @@
     for metric_device in metric_devices:
         _test(average=True, n_epochs=1, metric_device=metric_device)
         _test(average=True, n_epochs=2, metric_device=metric_device)
+        _test(average=False, n_epochs=1, metric_device=metric_device)
+        _test(average=False, n_epochs=2, metric_device=metric_device)
 
-    if idist.get_world_size() > 1:
-        with pytest.warns(
-            RuntimeWarning,
-            match="Precision/Recall metrics do not work in distributed setting when "
-            "average=False and is_multilabel=True",
-        ):
-            re = Recall(average=False, is_multilabel=True)
-
-        y_pred = torch.randint(0, 2, size=(4, 3, 6, 8))
-        y = torch.randint(0, 2, size=(4, 3, 6, 8)).long()
-        re.update((y_pred, y))
-        re_compute1 = re.compute()
-        re_compute2 = re.compute()
-        assert len(re_compute1) == 4 * 6 * 8
-        assert (re_compute1 == re_compute2).all()
+    re1 = Recall(is_multilabel=True, average=True)
+    re2 = Recall(is_multilabel=True, average=False)
+    y_pred = torch.randint(0, 2, size=(10, 4, 20, 23))
+    y = torch.randint(0, 2, size=(10, 4, 20, 23)).long()
+    re1.update((y_pred, y))
+    re2.update((y_pred, y))
+    assert re1.compute() == pytest.approx(re2.compute().mean().item())
 
 
 def _test_distrib_accumulator_device(device):
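
Note on the reduction strategy in the compute() hunk above: in every configuration except multilabel with average=False, _true_positives and _positives hold scalar or per-class counts, so summing them across workers with idist.all_reduce is the correct reduction. With is_multilabel=True and average=False they hold one entry per sample, and each worker sees different samples, so the shards must be concatenated instead, which is what idist.all_gather does (ignite concatenates gathered tensors along dim 0). Below is a minimal single-process sketch of that distinction, assuming binarized inputs of shape (num_samples, num_labels); the helper name per_sample_counts is illustrative only, not part of ignite's API:

    import torch

    def per_sample_counts(y_pred, y):
        # One entry per sample: true positives and predicted positives over the labels.
        tp = (y_pred * y).sum(dim=1)
        positives = y_pred.sum(dim=1)
        return tp, positives

    eps = 1e-20
    y_pred = torch.randint(0, 2, size=(8, 5))
    y = torch.randint(0, 2, size=(8, 5))

    # Pretend two workers each saw half of the samples.
    tp0, p0 = per_sample_counts(y_pred[:4], y[:4])
    tp1, p1 = per_sample_counts(y_pred[4:], y[4:])

    # Summing the shards (what all_reduce does) would add counts belonging to
    # different samples; concatenating them (what all_gather does along dim 0)
    # keeps one precision entry per sample, matching a single-worker run.
    per_sample = torch.cat([tp0, tp1]) / (torch.cat([p0, p1]) + eps)

    tp_full, p_full = per_sample_counts(y_pred, y)
    assert torch.allclose(per_sample, tp_full / (p_full + eps))

The new block at the end of each test file checks a related identity: with is_multilabel=True, compute() under average=True should equal the mean of the per-sample vector returned under average=False.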