diff --git a/ignite/contrib/metrics/regression/_base.py b/ignite/contrib/metrics/regression/_base.py
index ba08441a4574..b08cf655e6f5 100644
--- a/ignite/contrib/metrics/regression/_base.py
+++ b/ignite/contrib/metrics/regression/_base.py
@@ -1,8 +1,10 @@
 from abc import abstractmethod
+from typing import Callable, Union
 
 import torch
 
 from ignite.metrics import EpochMetric, Metric
+from ignite.metrics.metric import reinit__is_reduced
 
 
 def _check_output_shapes(output):
@@ -33,10 +35,11 @@ class _BaseRegression(Metric):
     # `update` method check the shapes and call internal overloaded
     # method `_update`.
 
+    @reinit__is_reduced
     def update(self, output):
         _check_output_shapes(output)
         _check_output_types(output)
-        y_pred, y = output
+        y_pred, y = output[0].detach(), output[1].detach()
 
         if y_pred.ndimension() == 2 and y_pred.shape[1] == 1:
             y_pred = y_pred.squeeze(dim=-1)
diff --git a/ignite/contrib/metrics/regression/canberra_metric.py b/ignite/contrib/metrics/regression/canberra_metric.py
index 188f5a4cd493..3436e24c2949 100644
--- a/ignite/contrib/metrics/regression/canberra_metric.py
+++ b/ignite/contrib/metrics/regression/canberra_metric.py
@@ -1,6 +1,9 @@
+from typing import Callable, Union
+
 import torch
 
 from ignite.contrib.metrics.regression._base import _BaseRegression
+from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce
 
 
 class CanberraMetric(_BaseRegression):
@@ -22,13 +25,21 @@ class CanberraMetric(_BaseRegression):
 
     """
 
+    def __init__(
+        self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
+    ):
+        self._sum_of_errors = None
+        super(CanberraMetric, self).__init__(output_transform, device)
+
+    @reinit__is_reduced
     def reset(self):
-        self._sum_of_errors = 0.0
+        self._sum_of_errors = torch.tensor(0.0, device=self._device)
 
     def _update(self, output):
         y_pred, y = output
-        errors = torch.abs(y.view_as(y_pred) - y_pred) / (torch.abs(y_pred) + torch.abs(y.view_as(y_pred)))
-        self._sum_of_errors += torch.sum(errors).item()
+        errors = torch.abs(y - y_pred) / (torch.abs(y_pred) + torch.abs(y))
+        self._sum_of_errors += torch.sum(errors).to(self._device)
 
+    @sync_all_reduce("_sum_of_errors")
     def compute(self):
-        return self._sum_of_errors
+        return self._sum_of_errors.item()
diff --git a/tests/ignite/contrib/metrics/regression/test_canberra_metric.py b/tests/ignite/contrib/metrics/regression/test_canberra_metric.py
index 31150f601b81..56c730a964cd 100644
--- a/tests/ignite/contrib/metrics/regression/test_canberra_metric.py
+++ b/tests/ignite/contrib/metrics/regression/test_canberra_metric.py
@@ -1,8 +1,11 @@
+import os
+
 import numpy as np
 import pytest
 import torch
 from sklearn.neighbors import DistanceMetric
 
+import ignite.distributed as idist
 from ignite.contrib.metrics.regression import CanberraMetric
 
 
@@ -58,3 +61,97 @@ def test_compute():
         v1 = np.hstack([v1, d])
         v2 = np.hstack([v2, ground_truth])
         assert canberra.pairwise([v1, v2])[0][1] == pytest.approx(np_sum)
+
+
+def _test_distrib_compute(device):
+    rank = idist.get_rank()
+
+    canberra = DistanceMetric.get_metric("canberra")
+
+    def _test(metric_device):
+        metric_device = torch.device(metric_device)
+        m = CanberraMetric(device=metric_device)
+        torch.manual_seed(10 + rank)
+
+        y_pred = torch.randint(0, 10, size=(10,), device=device).float()
+        y = torch.randint(0, 10, size=(10,), device=device).float()
+
+        m.update((y_pred, y))
+
+        # gather y_pred, y
+        y_pred = idist.all_gather(y_pred)
+        y = idist.all_gather(y)
+
+        np_y_pred = y_pred.cpu().numpy()
+        np_y = y.cpu().numpy()
+        res = m.compute()
+        assert canberra.pairwise([np_y_pred, np_y])[0][1] == pytest.approx(res)
+
+    for _ in range(3):
+        _test("cpu")
+        if device.type != "xla":
+            _test(idist.device())
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+def test_distrib_gpu(distributed_context_single_node_nccl):
+    device = torch.device("cuda:{}".format(distributed_context_single_node_nccl["local_rank"]))
+    _test_distrib_compute(device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+def test_distrib_cpu(distributed_context_single_node_gloo):
+
+    device = torch.device("cpu")
+    _test_distrib_compute(device)
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
+@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
+def test_distrib_hvd(gloo_hvd_executor):
+
+    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
+    nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()
+
+    gloo_hvd_executor(_test_distrib_compute, (device,), np=nproc, do_init=True)
+
+
+@pytest.mark.multinode_distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
+def test_multinode_distrib_cpu(distributed_context_multi_node_gloo):
+    device = torch.device("cpu")
+    _test_distrib_compute(device)
+
+
+@pytest.mark.multinode_distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
+def test_multinode_distrib_gpu(distributed_context_multi_node_nccl):
+    device = torch.device("cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]))
+    _test_distrib_compute(device)
+
+
+@pytest.mark.tpu
+@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
+@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
+def test_distrib_single_device_xla():
+    device = idist.device()
+    _test_distrib_compute(device)
+
+
+def _test_distrib_xla_nprocs(index):
+    device = idist.device()
+    _test_distrib_compute(device)
+
+
+@pytest.mark.tpu
+@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
+@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
+def test_distrib_xla_nprocs(xmp_executor):
+    n = int(os.environ["NUM_TPU_WORKERS"])
+    xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)
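
Usage note (not part of the patch): a minimal sketch of how the device-aware CanberraMetric introduced here could be attached in a distributed run. The `eval_step` function, the synthetic `data`, and the metric name "canberra" are illustrative assumptions; `Engine`, `Metric.attach`, and `idist.device()` are existing ignite APIs.

# Illustrative sketch, assuming a distributed context is already initialized
# (falls back to a single CPU process otherwise).
import torch

import ignite.distributed as idist
from ignite.contrib.metrics.regression import CanberraMetric
from ignite.engine import Engine


def eval_step(engine, batch):
    # Placeholder step: return (y_pred, y) so the default output_transform applies.
    y_pred, y = batch
    return y_pred, y


evaluator = Engine(eval_step)

# Keep the running sum on this process's device; compute() all-reduces
# "_sum_of_errors" across processes before returning a Python float.
metric = CanberraMetric(device=idist.device())
metric.attach(evaluator, "canberra")

data = [(torch.rand(8), torch.rand(8)) for _ in range(4)]
state = evaluator.run(data)
print(state.metrics["canberra"])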