From 9fe12e278fe74cd96f0fa3393c1045f837d924f8 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 4 Nov 2021 10:18:44 +0100 Subject: [PATCH 1/3] fix --- tests/classification/test_auroc.py | 10 +++++++--- torchmetrics/functional/classification/auroc.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/classification/test_auroc.py b/tests/classification/test_auroc.py index aa8e77f4cf1..f1abe8b23cc 100644 --- a/tests/classification/test_auroc.py +++ b/tests/classification/test_auroc.py @@ -185,11 +185,15 @@ def test_error_multiclass_no_num_classes(): _ = auroc(torch.randn(20, 3).softmax(dim=-1), torch.randint(3, (20,))) -def test_weighted_with_empty_classes(): +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_weighted_with_empty_classes(device): """Tests that weighted multiclass AUROC calculation yields the same results if a new but empty class exists. Tests that the proper warnings and errors are raised """ + if not torch.cuda.is_available() and device=='cuda': + pytest.skip('Test requires gpu to run') + preds = torch.tensor( [ [0.90, 0.05, 0.05], @@ -198,8 +202,8 @@ def test_weighted_with_empty_classes(): [0.85, 0.05, 0.10], [0.10, 0.10, 0.80], ] - ) - target = torch.tensor([0, 1, 1, 2, 2]) + ).to(device) + target = torch.tensor([0, 1, 1, 2, 2]).to(device) num_classes = 3 _auroc = auroc(preds, target, average="weighted", num_classes=num_classes) diff --git a/torchmetrics/functional/classification/auroc.py b/torchmetrics/functional/classification/auroc.py index 8bca3bdf099..259b9cc5ce0 100644 --- a/torchmetrics/functional/classification/auroc.py +++ b/torchmetrics/functional/classification/auroc.py @@ -135,7 +135,7 @@ def _auroc_compute( raise ValueError("Detected input to `multiclass` but you did not provide `num_classes` argument") if average == AverageMethod.WEIGHTED and len(torch.unique(target)) < num_classes: # If one or more classes has 0 observations, we should exclude them, as its weight will be 0 - target_bool_mat = torch.zeros((len(target), num_classes), dtype=bool) + target_bool_mat = torch.zeros((len(target), num_classes), dtype=bool, device=target.device) target_bool_mat[torch.arange(len(target)), target.long()] = 1 class_observed = target_bool_mat.sum(axis=0) > 0 for c in range(num_classes): From d07a1929683cdd6877c9018d89fa4287ecd2cc5f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 4 Nov 2021 09:22:24 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/classification/test_auroc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/classification/test_auroc.py b/tests/classification/test_auroc.py index f1abe8b23cc..f616c8c341d 100644 --- a/tests/classification/test_auroc.py +++ b/tests/classification/test_auroc.py @@ -191,8 +191,8 @@ def test_weighted_with_empty_classes(device): Tests that the proper warnings and errors are raised """ - if not torch.cuda.is_available() and device=='cuda': - pytest.skip('Test requires gpu to run') + if not torch.cuda.is_available() and device == "cuda": + pytest.skip("Test requires gpu to run") preds = torch.tensor( [ From b7084d7a7b2e7733d70ba3aacaa0c097cd9204d5 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 4 Nov 2021 10:23:14 +0100 Subject: [PATCH 3/3] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5707c02edb0..b4f30533f7d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fix empty predictions in MAP metric ([#594](https://github.com/PyTorchLightning/metrics/pull/594)) +- Fix edge case of AUROC with `average=weighted` on GPU ([#606](https://github.com/PyTorchLightning/metrics/pull/606)) + + ## [0.6.0] - 2021-10-28 ### Added