Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add distributed test to RocCurve metric #2802

Merged
merged 10 commits into from
Feb 16, 2023
51 changes: 48 additions & 3 deletions ignite/contrib/metrics/roc_auc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Any, Callable, Tuple, Union
from typing import Any, Callable, cast, Tuple, Union

import torch

from ignite import distributed as idist
from ignite.exceptions import NotComputableError
from ignite.metrics import EpochMetric


Expand Down Expand Up @@ -103,6 +105,8 @@ class RocCurve(EpochMetric):
<https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html#
sklearn.metrics.roc_curve>`_ is run on the first batch of data to ensure there are
no issues. User will be warned in case there are any issues computing the function.
device: optional device specification for internal storage.

Note:
RocCurve expects y to be comprised of 0's and 1's. y_pred must either be probability estimates or confidence
values. To apply an activation to y_pred, use output_transform as shown below:
Expand Down Expand Up @@ -137,15 +141,56 @@ def sigmoid_output_transform(output):
FPR [0.0, 0.333, 0.333, 1.0]
TPR [0.0, 0.0, 1.0, 1.0]
Thresholds [2.0, 1.0, 0.711, 0.047]

.. versionchanged:: 0.4.11
added `device` argument
"""

def __init__(self, output_transform: Callable = lambda x: x, check_compute_fn: bool = False) -> None:
def __init__(
self,
output_transform: Callable = lambda x: x,
check_compute_fn: bool = False,
device: Union[str, torch.device] = torch.device("cpu"),
) -> None:

try:
from sklearn.metrics import roc_curve # noqa: F401
except ImportError:
raise ModuleNotFoundError("This contrib module requires scikit-learn to be installed.")

super(RocCurve, self).__init__(
roc_auc_curve_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn
roc_auc_curve_compute_fn,
output_transform=output_transform,
check_compute_fn=check_compute_fn,
device=device,
)

def compute(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if len(self._predictions) < 1 or len(self._targets) < 1:
raise NotComputableError("RocCurve must have at least one example before it can be computed.")

_prediction_tensor = torch.cat(self._predictions, dim=0)
_target_tensor = torch.cat(self._targets, dim=0)

ws = idist.get_world_size()
if ws > 1:
# All gather across all processes
_prediction_tensor = cast(torch.Tensor, idist.all_gather(_prediction_tensor))
_target_tensor = cast(torch.Tensor, idist.all_gather(_target_tensor))

if idist.get_rank() == 0:
# Run compute_fn on zero rank only
fpr, tpr, thresholds = self.compute_fn(_prediction_tensor, _target_tensor)
fpr = torch.tensor(fpr)
tpr = torch.tensor(tpr)
thresholds = torch.tensor(thresholds)
else:
fpr, tpr, thresholds = None, None, None

if ws > 1:
# broadcast result to all processes
fpr = idist.broadcast(fpr, src=0, safe_mode=True)
tpr = idist.broadcast(tpr, src=0, safe_mode=True)
thresholds = idist.broadcast(thresholds, src=0, safe_mode=True)

return fpr, tpr, thresholds
45 changes: 45 additions & 0 deletions tests/ignite/contrib/metrics/test_roc_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,22 @@
import torch
from sklearn.metrics import roc_curve

from ignite import distributed as idist
from ignite.contrib.metrics.roc_auc import RocCurve
from ignite.engine import Engine
from ignite.exceptions import NotComputableError
from ignite.metrics.epoch_metric import EpochMetricWarning


def test_wrong_setup():
def compute_fn(y_preds, y_targets):
return 0.0

with pytest.raises(NotComputableError, match="RocCurve must have at least one example before it can be computed"):
metric = RocCurve(compute_fn)
metric.compute()


@pytest.fixture()
def mock_no_sklearn():
with patch.dict("sys.modules", {"sklearn.metrics": None}):
Expand Down Expand Up @@ -121,3 +132,37 @@ def test_check_compute_fn():

em = RocCurve(check_compute_fn=False)
em.update(output)


def test_distrib_integration(distributed):
rank = idist.get_rank()
torch.manual_seed(41 + rank)
n_batches, batch_size = 5, 10
y = torch.randint(0, 2, size=(n_batches * batch_size,))
y_pred = torch.rand((n_batches * batch_size,))

def update(engine, i):
return (
y_pred[i * batch_size : (i + 1) * batch_size],
y[i * batch_size : (i + 1) * batch_size],
)

engine = Engine(update)

device = "cpu" if idist.device().type == "xla" else idist.device()
metric = RocCurve(device=device)
metric.attach(engine, "roc_curve")

data = list(range(n_batches))

engine.run(data=data, max_epochs=1)

fpr, tpr, thresholds = engine.state.metrics["roc_curve"]

y = idist.all_gather(y)
y_pred = idist.all_gather(y_pred)
sk_fpr, sk_tpr, sk_thresholds = roc_curve(y, y_pred)

assert np.array_equal(fpr, sk_fpr)
assert np.array_equal(tpr, sk_tpr)
np.testing.assert_array_almost_equal(thresholds, sk_thresholds)