10 changes: 8 additions & 2 deletions ignite/contrib/metrics/average_precision.py
@@ -22,6 +22,10 @@ class AveragePrecision(EpochMetric):
:class:`~ignite.engine.Engine`'s `process_function`'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
check_compute_fn (bool): Optional, default False. If True, `sklearn.metrics.average_precision_score
<http://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html
#sklearn.metrics.average_precision_score>`_ is run on the first batch of data to ensure there are
no issues. If an issue is found, the user is warned.

AveragePrecision expects y to be comprised of 0's and 1's. y_pred must either be probability estimates or
confidence values. To apply an activation to y_pred, use output_transform as shown below:
@@ -37,5 +41,7 @@ def activated_output_transform(output):

"""

-    def __init__(self, output_transform=lambda x: x):
-        super(AveragePrecision, self).__init__(average_precision_compute_fn, output_transform=output_transform)
+    def __init__(self, output_transform=lambda x: x, check_compute_fn: bool = False):
+        super(AveragePrecision, self).__init__(
+            average_precision_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn
+        )
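A minimal usage sketch of the new flag (synthetic tensors; assumes scikit-learn is installed and that `AveragePrecision` is exported from `ignite.contrib.metrics`, as the tests below assume for `ROC_AUC`):

```python
import torch

from ignite.contrib.metrics import AveragePrecision

# check_compute_fn=True tries sklearn's average_precision_score on the first
# batch and warns early if it raises, instead of failing only at epoch end.
ap = AveragePrecision(check_compute_fn=True)
ap.reset()
ap.update((torch.rand(16), torch.randint(0, 2, size=(16,), dtype=torch.float32)))
print(ap.compute())
```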
10 changes: 8 additions & 2 deletions ignite/contrib/metrics/precision_recall_curve.py
@@ -23,6 +23,10 @@ class PrecisionRecallCurve(EpochMetric):
:class:`~ignite.engine.Engine`'s `process_function`'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
check_compute_fn (bool): Optional, default False. If True, `sklearn.metrics.precision_recall_curve
<http://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_curve.html
#sklearn.metrics.precision_recall_curve>`_ is run on the first batch of data to ensure there are
no issues. If an issue is found, the user is warned.

PrecisionRecallCurve expects y to be comprised of 0's and 1's. y_pred must either be probability estimates
or confidence values. To apply an activation to y_pred, use output_transform as shown below:
@@ -38,5 +42,7 @@ def activated_output_transform(output):

"""

-    def __init__(self, output_transform=lambda x: x):
-        super(PrecisionRecallCurve, self).__init__(precision_recall_curve_compute_fn, output_transform=output_transform)
+    def __init__(self, output_transform=lambda x: x, check_compute_fn: bool = False):
+        super(PrecisionRecallCurve, self).__init__(
+            precision_recall_curve_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn
+        )
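The flag only adds an up-front check; the returned value is unchanged. A hedged sketch (synthetic tensors; the import path matches the test file below):

```python
import torch

from ignite.contrib.metrics.precision_recall_curve import PrecisionRecallCurve

prc = PrecisionRecallCurve(check_compute_fn=True)
prc.reset()
prc.update((torch.rand(16), torch.randint(0, 2, size=(16,), dtype=torch.float32)))
# compute() forwards sklearn's triple: precision, recall, thresholds.
precision, recall, thresholds = prc.compute()
```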
6 changes: 4 additions & 2 deletions ignite/contrib/metrics/regression/_base.py
@@ -56,8 +56,10 @@ class _BaseRegressionEpoch(EpochMetric):
# `update` method check the shapes and call internal overloaded method `_update`.
# Class internally stores complete history of predictions and targets of type float32.

-    def __init__(self, compute_fn, output_transform=lambda x: x):
-        super(_BaseRegressionEpoch, self).__init__(compute_fn=compute_fn, output_transform=output_transform)
+    def __init__(self, compute_fn, output_transform=lambda x: x, check_compute_fn: bool = True):
+        super(_BaseRegressionEpoch, self).__init__(
+            compute_fn=compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn
+        )

    def _check_type(self, output):
        _check_output_types(output)
24 changes: 18 additions & 6 deletions ignite/contrib/metrics/roc_auc.py
@@ -31,9 +31,13 @@ class ROC_AUC(EpochMetric):

Args:
output_transform (callable, optional): a callable that is used to transform the
- :class:`~ignite.engine.Engine`'s `process_function`'s output into the
+ :class:`~ignite.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
check_compute_fn (bool): Optional, default False. If True, `sklearn.metrics.roc_auc_score
<http://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#
sklearn.metrics.roc_auc_score>`_ is run on the first batch of data to ensure there are
no issues. If an issue is found, the user is warned.

ROC_AUC expects y to be comprised of 0's and 1's. y_pred must either be probability estimates or confidence
values. To apply an activation to y_pred, use output_transform as shown below:
@@ -49,8 +53,10 @@ def activated_output_transform(output):

"""

-    def __init__(self, output_transform=lambda x: x):
-        super(ROC_AUC, self).__init__(roc_auc_compute_fn, output_transform=output_transform)
+    def __init__(self, output_transform=lambda x: x, check_compute_fn: bool = False):
+        super(ROC_AUC, self).__init__(
+            roc_auc_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn
+        )
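Attached to an evaluator, the early check surfaces output-format problems on the first batch rather than at epoch end. A sketch under assumptions (toy model and data; scikit-learn installed):

```python
import torch

from ignite.contrib.metrics import ROC_AUC
from ignite.engine import Engine

model = torch.nn.Linear(10, 1)  # toy binary scorer, for illustration only

def eval_step(engine, batch):
    x, y = batch
    with torch.no_grad():
        y_pred = torch.sigmoid(model(x)).squeeze(-1)
    return y_pred, y

evaluator = Engine(eval_step)
# With check_compute_fn=True, sklearn's roc_auc_score is tried on the first
# batch, so a malformed (y_pred, y) pair is reported early as a warning.
ROC_AUC(check_compute_fn=True).attach(evaluator, "roc_auc")

data = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(4)]
print(evaluator.run(data, max_epochs=1).metrics["roc_auc"])
```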


class RocCurve(EpochMetric):
@@ -61,9 +67,13 @@ class RocCurve(EpochMetric):

Args:
output_transform (callable, optional): a callable that is used to transform the
- :class:`~ignite.engine.Engine`'s `process_function`'s output into the
+ :class:`~ignite.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
check_compute_fn (bool): Optional, default False. If True, `sklearn.metrics.roc_curve
<http://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html#
sklearn.metrics.roc_curve>`_ is run on the first batch of data to ensure there are
no issues. If an issue is found, the user is warned.

RocCurve expects y to be comprised of 0's and 1's. y_pred must either be probability estimates or confidence
values. To apply an activation to y_pred, use output_transform as shown below:
@@ -79,5 +89,7 @@ def activated_output_transform(output):

"""

-    def __init__(self, output_transform=lambda x: x):
-        super(RocCurve, self).__init__(roc_auc_curve_compute_fn, output_transform=output_transform)
+    def __init__(self, output_transform=lambda x: x, check_compute_fn: bool = False):
+        super(RocCurve, self).__init__(
+            roc_auc_curve_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn
+        )
9 changes: 7 additions & 2 deletions ignite/metrics/epoch_metric.py
@@ -34,16 +34,21 @@ class EpochMetric(Metric):
:class:`~ignite.engine.Engine`'s `process_function`'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
check_compute_fn (bool): if True, ``compute_fn`` is run on the first batch of data to ensure there are no
issues. If an issue is found, the user is warned that there might be a problem with ``compute_fn``.

Warnings:
EpochMetricWarning: emitted when ``compute_fn`` raises an exception on the first batch of data processed.
"""

-    def __init__(self, compute_fn: Callable, output_transform: Callable = lambda x: x):
+    def __init__(self, compute_fn: Callable, output_transform: Callable = lambda x: x, check_compute_fn: bool = True):

        if not callable(compute_fn):
            raise TypeError("Argument compute_fn should be callable.")

        super(EpochMetric, self).__init__(output_transform=output_transform, device="cpu")
        self.compute_fn = compute_fn
        self._check_compute_fn = check_compute_fn

    def reset(self) -> None:
        self._predictions = []
@@ -95,7 +100,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
        self._targets.append(y)

        # Check once the signature and execution of compute_fn
-        if len(self._predictions) == 1:
+        if len(self._predictions) == 1 and self._check_compute_fn:
            try:
                self.compute_fn(self._predictions[0], self._targets[0])
            except Exception as e:
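The hunk is truncated at the ``except`` clause; in outline, the guarded check reads as below (a sketch: the warning text is inferred from the tests that follow, not copied from the file):

```python
# Continuation sketch of EpochMetric.update: compute_fn is exercised once on
# the first stored (prediction, target) pair, so signature or format problems
# surface immediately as a warning instead of an error at epoch end.
if len(self._predictions) == 1 and self._check_compute_fn:
    try:
        self.compute_fn(self._predictions[0], self._targets[0])
    except Exception as e:
        warnings.warn(
            "Probably, there can be a problem with `compute_fn`:\n {}".format(e),
            EpochMetricWarning,
        )
```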
16 changes: 16 additions & 0 deletions tests/ignite/contrib/metrics/regression/test__base.py
@@ -3,6 +3,7 @@
import torch

from ignite.contrib.metrics.regression._base import _BaseRegression, _BaseRegressionEpoch
from ignite.metrics.epoch_metric import EpochMetricWarning


def test_base_regression_shapes():
@@ -84,3 +85,18 @@ def test_base_regression_compute_fn():
    # Wrong compute function
    with pytest.raises(TypeError):
        _BaseRegressionEpoch(12345)


def test_check_compute_fn():
    def compute_fn(y_preds, y_targets):
        raise Exception

    em = _BaseRegressionEpoch(compute_fn, check_compute_fn=True)

    em.reset()
    output1 = (torch.rand(4, 1).float(), torch.randint(0, 2, size=(4, 1), dtype=torch.float32))
    with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
        em.update(output1)

    em = _BaseRegressionEpoch(compute_fn, check_compute_fn=False)
    em.update(output1)
18 changes: 18 additions & 0 deletions tests/ignite/contrib/metrics/test_precision_recall_curve.py
@@ -1,9 +1,11 @@
import numpy as np
import pytest
import torch
from sklearn.metrics import precision_recall_curve

from ignite.contrib.metrics.precision_recall_curve import PrecisionRecallCurve
from ignite.engine import Engine
from ignite.metrics.epoch_metric import EpochMetricWarning


def test_precision_recall_curve():
@@ -89,3 +91,19 @@ def update_fn(engine, batch):
    assert np.array_equal(recall, sk_recall)
    # assert thresholds almost equal, due to numpy->torch->numpy conversion
    np.testing.assert_array_almost_equal(thresholds, sk_thresholds)


def test_check_compute_fn():
    y_pred = torch.zeros((8, 13))
    y_pred[:, 1] = 1
    y_true = torch.zeros_like(y_pred)
    output = (y_pred, y_true)

    em = PrecisionRecallCurve(check_compute_fn=True)

    em.reset()
    with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
        em.update(output)

    em = PrecisionRecallCurve(check_compute_fn=False)
    em.update(output)
18 changes: 18 additions & 0 deletions tests/ignite/contrib/metrics/test_roc_auc.py
@@ -1,9 +1,11 @@
import numpy as np
import pytest
import torch
from sklearn.metrics import roc_auc_score

from ignite.contrib.metrics import ROC_AUC
from ignite.engine import Engine
from ignite.metrics.epoch_metric import EpochMetricWarning


def test_roc_auc_score():
@@ -110,3 +112,19 @@ def update_fn(engine, batch):
    roc_auc = engine.run(data, max_epochs=1).metrics["roc_auc"]

    assert roc_auc == np_roc_auc


def test_check_compute_fn():
    y_pred = torch.zeros((8, 13))
    y_pred[:, 1] = 1
    y_true = torch.zeros_like(y_pred)
    output = (y_pred, y_true)

    em = ROC_AUC(check_compute_fn=True)

    em.reset()
    with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
        em.update(output)

    em = ROC_AUC(check_compute_fn=False)
    em.update(output)
18 changes: 18 additions & 0 deletions tests/ignite/contrib/metrics/test_roc_curve.py
@@ -1,9 +1,11 @@
import numpy as np
import pytest
import torch
from sklearn.metrics import roc_curve

from ignite.contrib.metrics.roc_auc import RocCurve
from ignite.engine import Engine
from ignite.metrics.epoch_metric import EpochMetricWarning


def test_roc_curve():
@@ -89,3 +91,19 @@ def update_fn(engine, batch):
    assert np.array_equal(tpr, sk_tpr)
    # assert thresholds almost equal, due to numpy->torch->numpy conversion
    np.testing.assert_array_almost_equal(thresholds, sk_thresholds)


def test_check_compute_fn():
    y_pred = torch.zeros((8, 13))
    y_pred[:, 1] = 1
    y_true = torch.zeros_like(y_pred)
    output = (y_pred, y_true)

    em = RocCurve(check_compute_fn=True)

    em.reset()
    with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
        em.update(output)

    em = RocCurve(check_compute_fn=False)
    em.update(output)
15 changes: 15 additions & 0 deletions tests/ignite/metrics/test_epoch_metric.py
@@ -144,6 +144,21 @@ def compute_fn(y_preds, y_targets):
EpochMetric(compute_fn)


def test_check_compute_fn():
    def compute_fn(y_preds, y_targets):
        raise Exception

    em = EpochMetric(compute_fn, check_compute_fn=True)

    em.reset()
    output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
        em.update(output1)

    em = EpochMetric(compute_fn, check_compute_fn=False)
    em.update(output1)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")