From 08439b37ac402a7fdf01183663cba6e1e2144c2d Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sun, 9 Jul 2023 15:44:52 +0200 Subject: [PATCH 1/2] fix --- .../functional/classification/auroc.py | 2 +- tests/unittests/classification/test_auroc.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/classification/auroc.py b/src/torchmetrics/functional/classification/auroc.py index 94ca84877f0..8d0157fe7aa 100644 --- a/src/torchmetrics/functional/classification/auroc.py +++ b/src/torchmetrics/functional/classification/auroc.py @@ -86,7 +86,7 @@ def _binary_auroc_compute( pos_label: int = 1, ) -> Tensor: fpr, tpr, _ = _binary_roc_compute(state, thresholds, pos_label) - if max_fpr is None or max_fpr == 1: + if max_fpr is None or max_fpr == 1 or fpr.sum() == 0 or tpr.sum() == 0: return _auc_compute_without_check(fpr, tpr, 1.0) _device = fpr.device if isinstance(fpr, Tensor) else fpr[0].device diff --git a/tests/unittests/classification/test_auroc.py b/tests/unittests/classification/test_auroc.py index 0fdf7834da0..e51f564b02f 100644 --- a/tests/unittests/classification/test_auroc.py +++ b/tests/unittests/classification/test_auroc.py @@ -393,3 +393,17 @@ def test_valid_input_thresholds(metric, thresholds): with pytest.warns(None) as record: metric(thresholds=thresholds) assert len(record) == 0 + + +@pytest.mark.parametrize("max_fpr", [None, 0.8, 0.5]) +def test_corner_case_max_fpr(max_fpr): + """Check that metric returns 0 when one class is missing and `max_fpr` is set.""" + preds = torch.tensor([0.1, 0.2, 0.3, 0.4]) + target = torch.tensor([0, 0, 0, 0]) + metric = BinaryAUROC(max_fpr=max_fpr) + assert metric(preds, target) == 0.0 + + preds = torch.tensor([0.5, 0.6, 0.7, 0.8]) + target = torch.tensor([1, 1, 1, 1]) + metric = BinaryAUROC(max_fpr=max_fpr) + assert metric(preds, target) == 0.0 From 7cd50b7493dd47398c2b340aeb933ab9e788fd69 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Sun, 9 Jul 2023 15:46:47 +0200 Subject: [PATCH 2/2] changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index eedd710f6de..990b7901a79 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,7 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fixed the use of `max_fpr` in `AUROC` metric when only one class is present ([#1895](https://github.com/Lightning-AI/torchmetrics/pull/1895)) ## [1.0.0] - 2022-07-04