Commit 454e1a3

anmolsjoshi, sdesrozis, and vfdev-5 authored
Added check_compute_fn argument to EpochMetric and related metrics (#1140)
* Added check_compute_fn argument to EpochMetric and related functions.
* Updated docstrings
* Added check_compute_fn to _BaseRegressionEpoch
* Adding typing hints for check_compute_fn
* Update roc_auc.py

Co-authored-by: Sylvain Desroziers <sylvain.desroziers@gmail.com>
Co-authored-by: vfdev <vfdev.5@gmail.com>
1 parent a10bf6e commit 454e1a3

File tree

10 files changed, +130 -14 lines

ignite/contrib/metrics/average_precision.py

Lines changed: 8 additions & 2 deletions
@@ -22,6 +22,10 @@ class AveragePrecision(EpochMetric):
             :class:`~ignite.engine.Engine`'s `process_function`'s output into the
             form expected by the metric. This can be useful if, for example, you have a multi-output model and
             you want to compute the metric with respect to one of the outputs.
+        check_compute_fn (bool): Optional default False. If True, `sklearn.metrics.average_precision_score
+            <http://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score.html
+            #sklearn.metrics.average_precision_score>`_ is run on the first batch of data to ensure there are
+            no issues. User will be warned in case there are any issues computing the function.

     AveragePrecision expects y to be comprised of 0's and 1's. y_pred must either be probability estimates or
     confidence values. To apply an activation to y_pred, use output_transform as shown below:
@@ -37,5 +41,7 @@ def activated_output_transform(output):

     """

-    def __init__(self, output_transform=lambda x: x):
-        super(AveragePrecision, self).__init__(average_precision_compute_fn, output_transform=output_transform)
+    def __init__(self, output_transform=lambda x: x, check_compute_fn: bool = False):
+        super(AveragePrecision, self).__init__(
+            average_precision_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn
+        )
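
For context, a minimal usage sketch of the new argument (the standalone reset/update/compute calls and the random tensors below are illustrative, not part of this diff): with check_compute_fn=True the metric runs sklearn.metrics.average_precision_score on the first batch it receives and warns if that call fails.

import torch

from ignite.contrib.metrics import AveragePrecision

# check_compute_fn=True runs sklearn.metrics.average_precision_score on the first
# accumulated batch and emits EpochMetricWarning if that call fails, instead of the
# problem only surfacing when compute() is called at the end of the epoch.
avg_precision = AveragePrecision(check_compute_fn=True)

avg_precision.reset()
y_pred = torch.rand(16, 1)                    # probability estimates in [0, 1]
y = torch.randint(0, 2, size=(16, 1)).long()  # binary targets
avg_precision.update((y_pred, y))
print(avg_precision.compute())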

ignite/contrib/metrics/precision_recall_curve.py

Lines changed: 8 additions & 2 deletions
@@ -23,6 +23,10 @@ class PrecisionRecallCurve(EpochMetric):
             :class:`~ignite.engine.Engine`'s `process_function`'s output into the
             form expected by the metric. This can be useful if, for example, you have a multi-output model and
             you want to compute the metric with respect to one of the outputs.
+        check_compute_fn (bool): Optional default False. If True, `sklearn.metrics.precision_recall_curve
+            <http://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_curve.html
+            #sklearn.metrics.precision_recall_curve>`_ is run on the first batch of data to ensure there are
+            no issues. User will be warned in case there are any issues computing the function.

     PrecisionRecallCurve expects y to be comprised of 0's and 1's. y_pred must either be probability estimates
     or confidence values. To apply an activation to y_pred, use output_transform as shown below:
@@ -38,5 +42,7 @@ def activated_output_transform(output):

     """

-    def __init__(self, output_transform=lambda x: x):
-        super(PrecisionRecallCurve, self).__init__(precision_recall_curve_compute_fn, output_transform=output_transform)
+    def __init__(self, output_transform=lambda x: x, check_compute_fn: bool = False):
+        super(PrecisionRecallCurve, self).__init__(
+            precision_recall_curve_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn
+        )
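
The same flag on PrecisionRecallCurve; below is a sketch of attaching the metric to an evaluator Engine so that the check fires on the first processed batch (the pass-through process_function and the random data are illustrative assumptions, not part of this diff).

import torch

from ignite.contrib.metrics.precision_recall_curve import PrecisionRecallCurve
from ignite.engine import Engine

def process_function(engine, batch):
    # A real evaluator would run the model here; this batch already holds (y_pred, y).
    return batch

evaluator = Engine(process_function)

# The check costs one extra sklearn.metrics.precision_recall_curve call on the first
# batch of the run; keep the default False once the pipeline is known to be sound.
PrecisionRecallCurve(check_compute_fn=True).attach(evaluator, "pr_curve")

data = [(torch.rand(8), torch.randint(0, 2, size=(8,)).long()) for _ in range(4)]
state = evaluator.run(data, max_epochs=1)
precision, recall, thresholds = state.metrics["pr_curve"]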

ignite/contrib/metrics/regression/_base.py

Lines changed: 4 additions & 2 deletions
@@ -56,8 +56,10 @@ class _BaseRegressionEpoch(EpochMetric):
     # `update` method check the shapes and call internal overloaded method `_update`.
     # Class internally stores complete history of predictions and targets of type float32.

-    def __init__(self, compute_fn, output_transform=lambda x: x):
-        super(_BaseRegressionEpoch, self).__init__(compute_fn=compute_fn, output_transform=output_transform)
+    def __init__(self, compute_fn, output_transform=lambda x: x, check_compute_fn: bool = True):
+        super(_BaseRegressionEpoch, self).__init__(
+            compute_fn=compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn
+        )

     def _check_type(self, output):
         _check_output_types(output)
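
Note the differing default here: _BaseRegressionEpoch keeps check_compute_fn=True, while the public contrib metrics above pass False. A hypothetical subclass (CustomR2 and its compute function are illustrative only, not part of ignite) would forward the flag like this:

import torch

from ignite.contrib.metrics.regression._base import _BaseRegressionEpoch

def r2_compute_fn(y_pred, y):
    # Runs once on the full epoch history of predictions and targets (float32 tensors).
    ss_res = torch.sum((y - y_pred) ** 2)
    ss_tot = torch.sum((y - torch.mean(y)) ** 2)
    return (1.0 - ss_res / ss_tot).item()

class CustomR2(_BaseRegressionEpoch):
    def __init__(self, output_transform=lambda x: x, check_compute_fn: bool = True):
        # Forward the flag so callers of the subclass can still opt out of the first-batch check.
        super(CustomR2, self).__init__(r2_compute_fn, output_transform, check_compute_fn)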

ignite/contrib/metrics/roc_auc.py

Lines changed: 18 additions & 6 deletions
@@ -31,9 +31,13 @@ class ROC_AUC(EpochMetric):

     Args:
         output_transform (callable, optional): a callable that is used to transform the
-            :class:`~ignite.engine.Engine`'s `process_function`'s output into the
+            :class:`~ignite.engine.Engine`'s ``process_function``'s output into the
             form expected by the metric. This can be useful if, for example, you have a multi-output model and
             you want to compute the metric with respect to one of the outputs.
+        check_compute_fn (bool): Optional default False. If True, `sklearn.metrics.roc_curve
+            <http://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html#
+            sklearn.metrics.roc_auc_score>`_ is run on the first batch of data to ensure there are
+            no issues. User will be warned in case there are any issues computing the function.

     ROC_AUC expects y to be comprised of 0's and 1's. y_pred must either be probability estimates or confidence
     values. To apply an activation to y_pred, use output_transform as shown below:
@@ -49,8 +53,10 @@ def activated_output_transform(output):

     """

-    def __init__(self, output_transform=lambda x: x):
-        super(ROC_AUC, self).__init__(roc_auc_compute_fn, output_transform=output_transform)
+    def __init__(self, output_transform=lambda x: x, check_compute_fn: bool = False):
+        super(ROC_AUC, self).__init__(
+            roc_auc_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn
+        )


 class RocCurve(EpochMetric):
@@ -61,9 +67,13 @@ class RocCurve(EpochMetric):

     Args:
         output_transform (callable, optional): a callable that is used to transform the
-            :class:`~ignite.engine.Engine`'s `process_function`'s output into the
+            :class:`~ignite.engine.Engine`'s ``process_function``'s output into the
             form expected by the metric. This can be useful if, for example, you have a multi-output model and
             you want to compute the metric with respect to one of the outputs.
+        check_compute_fn (bool): Optional default False. If True, `sklearn.metrics.roc_curve
+            <http://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_curve.html#
+            sklearn.metrics.roc_curve>`_ is run on the first batch of data to ensure there are
+            no issues. User will be warned in case there are any issues computing the function.

     RocCurve expects y to be comprised of 0's and 1's. y_pred must either be probability estimates or confidence
     values. To apply an activation to y_pred, use output_transform as shown below:
@@ -79,5 +89,7 @@ def activated_output_transform(output):

     """

-    def __init__(self, output_transform=lambda x: x):
-        super(RocCurve, self).__init__(roc_auc_curve_compute_fn, output_transform=output_transform)
+    def __init__(self, output_transform=lambda x: x, check_compute_fn: bool = False):
+        super(RocCurve, self).__init__(
+            roc_auc_curve_compute_fn, output_transform=output_transform, check_compute_fn=check_compute_fn
+        )
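
A short sketch of what the flag catches in practice (the tensors below are illustrative): a first batch containing a single class, on which sklearn.metrics.roc_auc_score cannot be computed, now produces an immediate EpochMetricWarning instead of an error at the end of the epoch.

import torch

from ignite.contrib.metrics import ROC_AUC

roc_auc = ROC_AUC(check_compute_fn=True)
roc_auc.reset()

# Every target is 0, so sklearn.metrics.roc_auc_score raises on this batch;
# with check_compute_fn=True the metric warns right here at update() time.
y_pred = torch.rand(8)
y = torch.zeros(8, dtype=torch.long)
roc_auc.update((y_pred, y))

# With the default check_compute_fn=False the same update() is silent and the
# failure would only appear when compute() runs over the whole epoch.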

ignite/metrics/epoch_metric.py

Lines changed: 7 additions & 2 deletions
@@ -34,16 +34,21 @@ class EpochMetric(Metric):
             :class:`~ignite.engine.Engine`'s `process_function`'s output into the
             form expected by the metric. This can be useful if, for example, you have a multi-output model and
             you want to compute the metric with respect to one of the outputs.
+        check_compute_fn (bool): if True, compute_fn is run on the first batch of data to ensure there are no
+            issues. If issues exist, user is warned that there might be an issue with the ``compute_fn``.

+    Warnings:
+        EpochMetricWarning: User is warned that there are issues with compute_fn on a batch of data processed.
     """

-    def __init__(self, compute_fn: Callable, output_transform: Callable = lambda x: x):
+    def __init__(self, compute_fn: Callable, output_transform: Callable = lambda x: x, check_compute_fn: bool = True):

         if not callable(compute_fn):
             raise TypeError("Argument compute_fn should be callable.")

         super(EpochMetric, self).__init__(output_transform=output_transform, device="cpu")
         self.compute_fn = compute_fn
+        self._check_compute_fn = check_compute_fn

     def reset(self) -> None:
         self._predictions = []
@@ -95,7 +100,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
         self._targets.append(y)

         # Check once the signature and execution of compute_fn
-        if len(self._predictions) == 1:
+        if len(self._predictions) == 1 and self._check_compute_fn:
             try:
                 self.compute_fn(self._predictions[0], self._targets[0])
             except Exception as e:
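
The core of the change is the guarded check in update(): it still runs only on the first batch, but now only when check_compute_fn is set, and EpochMetric itself defaults to True while the derived contrib metrics pass False. A minimal sketch with a user-defined compute_fn (the accuracy function below is illustrative, not part of ignite):

import torch

from ignite.metrics import EpochMetric

def accuracy_compute_fn(y_preds, y_targets):
    # Called once over the accumulated history; also called on the first batch when the check is on.
    return (y_preds.round().long() == y_targets).float().mean().item()

# check_compute_fn defaults to True for EpochMetric itself, so the first update()
# already exercises accuracy_compute_fn and warns if it raises.
metric = EpochMetric(accuracy_compute_fn)

metric.reset()
metric.update((torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long)))
print(metric.compute())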

tests/ignite/contrib/metrics/regression/test__base.py

Lines changed: 16 additions & 0 deletions
@@ -3,6 +3,7 @@
 import torch

 from ignite.contrib.metrics.regression._base import _BaseRegression, _BaseRegressionEpoch
+from ignite.metrics.epoch_metric import EpochMetricWarning


 def test_base_regression_shapes():
@@ -84,3 +85,18 @@ def test_base_regression_compute_fn():
     # Wrong compute function
     with pytest.raises(TypeError):
         _BaseRegressionEpoch(12345)
+
+
+def test_check_compute_fn():
+    def compute_fn(y_preds, y_targets):
+        raise Exception
+
+    em = _BaseRegressionEpoch(compute_fn, check_compute_fn=True)
+
+    em.reset()
+    output1 = (torch.rand(4, 1).float(), torch.randint(0, 2, size=(4, 1), dtype=torch.float32))
+    with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
+        em.update(output1)
+
+    em = _BaseRegressionEpoch(compute_fn, check_compute_fn=False)
+    em.update(output1)

tests/ignite/contrib/metrics/test_precision_recall_curve.py

Lines changed: 18 additions & 0 deletions
@@ -1,9 +1,11 @@
 import numpy as np
+import pytest
 import torch
 from sklearn.metrics import precision_recall_curve

 from ignite.contrib.metrics.precision_recall_curve import PrecisionRecallCurve
 from ignite.engine import Engine
+from ignite.metrics.epoch_metric import EpochMetricWarning


 def test_precision_recall_curve():
@@ -89,3 +91,19 @@ def update_fn(engine, batch):
     assert np.array_equal(recall, sk_recall)
     # assert thresholds almost equal, due to numpy->torch->numpy conversion
     np.testing.assert_array_almost_equal(thresholds, sk_thresholds)
+
+
+def test_check_compute_fn():
+    y_pred = torch.zeros((8, 13))
+    y_pred[:, 1] = 1
+    y_true = torch.zeros_like(y_pred)
+    output = (y_pred, y_true)
+
+    em = PrecisionRecallCurve(check_compute_fn=True)
+
+    em.reset()
+    with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
+        em.update(output)
+
+    em = PrecisionRecallCurve(check_compute_fn=False)
+    em.update(output)

tests/ignite/contrib/metrics/test_roc_auc.py

Lines changed: 18 additions & 0 deletions
@@ -1,9 +1,11 @@
 import numpy as np
+import pytest
 import torch
 from sklearn.metrics import roc_auc_score

 from ignite.contrib.metrics import ROC_AUC
 from ignite.engine import Engine
+from ignite.metrics.epoch_metric import EpochMetricWarning


 def test_roc_auc_score():
@@ -110,3 +112,19 @@ def update_fn(engine, batch):
     roc_auc = engine.run(data, max_epochs=1).metrics["roc_auc"]

     assert roc_auc == np_roc_auc
+
+
+def test_check_compute_fn():
+    y_pred = torch.zeros((8, 13))
+    y_pred[:, 1] = 1
+    y_true = torch.zeros_like(y_pred)
+    output = (y_pred, y_true)
+
+    em = ROC_AUC(check_compute_fn=True)
+
+    em.reset()
+    with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
+        em.update(output)
+
+    em = ROC_AUC(check_compute_fn=False)
+    em.update(output)

tests/ignite/contrib/metrics/test_roc_curve.py

Lines changed: 18 additions & 0 deletions
@@ -1,9 +1,11 @@
 import numpy as np
+import pytest
 import torch
 from sklearn.metrics import roc_curve

 from ignite.contrib.metrics.roc_auc import RocCurve
 from ignite.engine import Engine
+from ignite.metrics.epoch_metric import EpochMetricWarning


 def test_roc_curve():
@@ -89,3 +91,19 @@ def update_fn(engine, batch):
     assert np.array_equal(tpr, sk_tpr)
     # assert thresholds almost equal, due to numpy->torch->numpy conversion
     np.testing.assert_array_almost_equal(thresholds, sk_thresholds)
+
+
+def test_check_compute_fn():
+    y_pred = torch.zeros((8, 13))
+    y_pred[:, 1] = 1
+    y_true = torch.zeros_like(y_pred)
+    output = (y_pred, y_true)
+
+    em = RocCurve(check_compute_fn=True)
+
+    em.reset()
+    with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
+        em.update(output)
+
+    em = RocCurve(check_compute_fn=False)
+    em.update(output)

tests/ignite/metrics/test_epoch_metric.py

Lines changed: 15 additions & 0 deletions
@@ -144,6 +144,21 @@ def compute_fn(y_preds, y_targets):
     EpochMetric(compute_fn)


+def test_check_compute_fn():
+    def compute_fn(y_preds, y_targets):
+        raise Exception
+
+    em = EpochMetric(compute_fn, check_compute_fn=True)
+
+    em.reset()
+    output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
+    with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
+        em.update(output1)
+
+    em = EpochMetric(compute_fn, check_compute_fn=False)
+    em.update(output1)
+
+
 @pytest.mark.distributed
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
