10 changes: 6 additions & 4 deletions ignite/contrib/metrics/regression/_base.py
@@ -1,8 +1,9 @@
from abc import abstractmethod
from typing import Callable, Union

import torch

from ignite.metrics import EpochMetric, Metric
from ignite.metrics import EpochMetric, Metric, reinit__is_reduced


def _check_output_shapes(output):
@@ -33,16 +34,17 @@ class _BaseRegression(Metric):
# The `update` method checks the shapes and calls the internal overloaded
# method `_update`.

@reinit__is_reduced
def update(self, output):
_check_output_shapes(output)
_check_output_types(output)
y_pred, y = output
y_pred, y = output[0].detach(), output[1].detach()

if y_pred.ndimension() == 2 and y_pred.shape[1] == 1:
y_pred = y_pred.squeeze(dim=-1)
y_pred = y_pred.squeeze(dim=-1).to(self._device)

if y.ndimension() == 2 and y.shape[1] == 1:
y = y.squeeze(dim=-1)
y = y.squeeze(dim=-1).to(self._device)

self._update((y_pred, y))

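
For reference, a minimal sketch (not part of this diff) of how a concrete regression metric is expected to follow the same pattern: `reset` allocates its accumulator on `self._device`, `_update` receives the already detached (and, if needed, squeezed) tensors from `_BaseRegression.update`, and `compute` is wrapped with `sync_all_reduce`. The class and attribute names below are illustrative only.

import torch

from ignite.contrib.metrics.regression._base import _BaseRegression
from ignite.metrics import reinit__is_reduced, sync_all_reduce


class _ExampleAbsoluteError(_BaseRegression):
    # Hypothetical subclass used only to illustrate the accumulation pattern.

    @reinit__is_reduced
    def reset(self):
        self._sum_of_abs_errors = torch.tensor(0.0, device=self._device)

    def _update(self, output):
        # y_pred and y arrive detached from the graph via _BaseRegression.update.
        y_pred, y = output
        self._sum_of_abs_errors += torch.sum(torch.abs(y.view_as(y_pred) - y_pred)).to(self._device)

    @sync_all_reduce("_sum_of_abs_errors")
    def compute(self):
        return self._sum_of_abs_errors.item()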
17 changes: 14 additions & 3 deletions ignite/contrib/metrics/regression/canberra_metric.py
@@ -1,6 +1,9 @@
from typing import Callable, Union

import torch

from ignite.contrib.metrics.regression._base import _BaseRegression
from ignite.metrics import reinit__is_reduced, sync_all_reduce


class CanberraMetric(_BaseRegression):
@@ -22,13 +25,21 @@ class CanberraMetric(_BaseRegression):

"""

def __init__(
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
):
self._sum_of_errors = None
super(CanberraMetric, self).__init__(output_transform, device)

@reinit__is_reduced
def reset(self):
self._sum_of_errors = 0.0
self._sum_of_errors = torch.tensor(0.0, device=self._device)

def _update(self, output):
y_pred, y = output
errors = torch.abs(y.view_as(y_pred) - y_pred) / (torch.abs(y_pred) + torch.abs(y.view_as(y_pred)))
self._sum_of_errors += torch.sum(errors).item()
self._sum_of_errors += torch.sum(errors).to(self._device)

@sync_all_reduce("_sum_of_errors")
def compute(self):
return self._sum_of_errors
return self._sum_of_errors.item()
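
A short usage sketch of the updated metric in a single process, assuming this branch; the input values are made up to show the arithmetic.

import torch

from ignite.contrib.metrics.regression import CanberraMetric

m = CanberraMetric(device="cpu")
m.reset()

y_pred = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([1.0, 4.0, 3.0])
m.update((y_pred, y))

# Sum over elements of |y - y_pred| / (|y_pred| + |y|):
# 0/2 + 2/6 + 0/6 = 1/3
print(m.compute())  # ~0.3333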
12 changes: 11 additions & 1 deletion ignite/metrics/__init__.py
@@ -8,7 +8,15 @@
from ignite.metrics.mean_absolute_error import MeanAbsoluteError
from ignite.metrics.mean_pairwise_distance import MeanPairwiseDistance
from ignite.metrics.mean_squared_error import MeanSquaredError
from ignite.metrics.metric import BatchFiltered, BatchWise, EpochWise, Metric, MetricUsage
from ignite.metrics.metric import (
BatchFiltered,
BatchWise,
EpochWise,
Metric,
MetricUsage,
reinit__is_reduced,
sync_all_reduce,
)
from ignite.metrics.metrics_lambda import MetricsLambda
from ignite.metrics.precision import Precision
from ignite.metrics.recall import Recall
@@ -41,4 +49,6 @@
"VariableAccumulation",
"Frequency",
"SSIM",
"reinit__is_reduced",
"sync_all_reduce",
]
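
With the re-export in place, the decorators can be imported from the package root instead of `ignite.metrics.metric`; a trivial check, assuming this branch is installed:

from ignite.metrics import reinit__is_reduced, sync_all_reduce

# Both names now resolve from ignite.metrics as well as ignite.metrics.metric.
assert callable(reinit__is_reduced) and callable(sync_all_reduce)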
2 changes: 1 addition & 1 deletion ignite/metrics/metric.py
@@ -9,7 +9,7 @@
import ignite.distributed as idist
from ignite.engine import Engine, Events

__all__ = ["Metric", "MetricUsage", "EpochWise", "BatchWise", "BatchFiltered"]
__all__ = ["Metric", "MetricUsage", "EpochWise", "BatchWise", "BatchFiltered", "reinit__is_reduced", "sync_all_reduce"]


class MetricUsage:
97 changes: 97 additions & 0 deletions tests/ignite/contrib/metrics/regression/test_canberra_metric.py
@@ -1,8 +1,11 @@
import os

import numpy as np
import pytest
import torch
from sklearn.neighbors import DistanceMetric

import ignite.distributed as idist
from ignite.contrib.metrics.regression import CanberraMetric


@@ -58,3 +61,97 @@ def test_compute():
v1 = np.hstack([v1, d])
v2 = np.hstack([v2, ground_truth])
assert canberra.pairwise([v1, v2])[0][1] == pytest.approx(np_sum)


def _test_distrib_compute(device):
rank = idist.get_rank()

canberra = DistanceMetric.get_metric("canberra")

def _test(metric_device):
metric_device = torch.device(metric_device)
m = CanberraMetric(device=metric_device)
torch.manual_seed(10 + rank)

y_pred = torch.randint(0, 10, size=(10,), device=device).float()
y = torch.randint(0, 10, size=(10,), device=device).float()

m.update((y_pred, y))

# gather y_pred, y
y_pred = idist.all_gather(y_pred)
y = idist.all_gather(y)

np_y_pred = y_pred.cpu().numpy()
np_y = y.cpu().numpy()
res = m.compute()
assert canberra.pairwise([np_y_pred, np_y])[0][1] == pytest.approx(res)

for _ in range(3):
_test("cpu")
if device.type != "xla":
_test(idist.device())


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_distrib_gpu(distributed_context_single_node_nccl):
device = torch.device("cuda:{}".format(distributed_context_single_node_nccl["local_rank"]))
_test_distrib_compute(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
def test_distrib_cpu(distributed_context_single_node_gloo):

device = torch.device("cpu")
_test_distrib_compute(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
def test_distrib_hvd(gloo_hvd_executor):

device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()

gloo_hvd_executor(_test_distrib_compute, (device,), np=nproc, do_init=True)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_cpu(distributed_context_multi_node_gloo):
device = torch.device("cpu")
_test_distrib_compute(device)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_gpu(distributed_context_multi_node_nccl):
device = torch.device("cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]))
_test_distrib_compute(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_single_device_xla():
device = idist.device()
_test_distrib_compute(device)


def _test_distrib_xla_nprocs(index):
device = idist.device()
_test_distrib_compute(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_xla_nprocs(xmp_executor):
n = int(os.environ["NUM_TPU_WORKERS"])
xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)
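
Outside of the pytest fixtures, the behaviour these tests assert can also be exercised with `idist.spawn`; a rough sketch under the assumption that a gloo backend is available, with the exact spawn arguments possibly differing between ignite versions:

import torch

import ignite.distributed as idist
from ignite.contrib.metrics.regression import CanberraMetric


def check_canberra(local_rank):
    # Each process updates with its own slice; sync_all_reduce sums the
    # per-process accumulators when compute() is called.
    device = idist.device()
    m = CanberraMetric(device=device)
    torch.manual_seed(10 + idist.get_rank())
    y_pred = torch.rand(10, device=device)
    y = torch.rand(10, device=device)
    m.update((y_pred, y))
    print(idist.get_rank(), m.compute())  # every rank prints the same global value


if __name__ == "__main__":
    idist.spawn("gloo", check_canberra, args=(), nproc_per_node=2)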