diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7db6036c871..7da61d6c634 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -42,8 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 - Fixed mAP calculation for areas with 0 predictions ([#1080](https://github.com/PyTorchLightning/metrics/pull/1080))

-
--
+- Fixed bug where average precision state and AUROC state were not merged when using MetricCollection ([#1086](https://github.com/PyTorchLightning/metrics/pull/1086))


 ## [0.9.1] - 2022-06-08

diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py
index 93662e021ef..db256cd00d0 100644
--- a/src/torchmetrics/functional/classification/average_precision.py
+++ b/src/torchmetrics/functional/classification/average_precision.py
@@ -43,14 +43,8 @@ def _average_precision_update(
         average: reduction method for multi-class or multi-label problems
     """
     preds, target, num_classes, pos_label = _precision_recall_curve_update(preds, target, num_classes, pos_label)
-    if average == "micro":
-        if preds.ndim == target.ndim:
-            # Considering each element of the label indicator matrix as a label
-            preds = preds.flatten()
-            target = target.flatten()
-            num_classes = 1
-        else:
-            raise ValueError("Cannot use `micro` average with multi-class input")
+    if average == "micro" and preds.ndim != target.ndim:
+        raise ValueError("Cannot use `micro` average with multi-class input")
     return preds, target, num_classes, pos_label


@@ -97,6 +91,11 @@ def _average_precision_compute(
     """
     # todo: `sample_weights` is unused

+    if average == "micro" and preds.ndim == target.ndim:
+        preds = preds.flatten()
+        target = target.flatten()
+        num_classes = 1
+
     precision, recall, _ = _precision_recall_curve_compute(preds, target, num_classes, pos_label)
     if average == "weighted":
         if preds.ndim == target.ndim and target.ndim > 1:
diff --git a/test/unittests/bases/test_collections.py b/test/unittests/bases/test_collections.py
index 405ab9a3d08..34b9d1af1c1 100644
--- a/test/unittests/bases/test_collections.py
+++ b/test/unittests/bases/test_collections.py
@@ -19,7 +19,9 @@
 import torch

 from torchmetrics import (
+    AUROC,
     Accuracy,
+    AveragePrecision,
     CohenKappa,
     ConfusionMatrix,
     F1Score,
@@ -29,6 +31,7 @@
     Precision,
     Recall,
 )
+from torchmetrics.utilities.checks import _allclose_recursive

 from unittests.helpers import seed_all
 from unittests.helpers.testers import DummyMetricDiff, DummyMetricSum
@@ -267,6 +270,8 @@ def test_collection_filtering():
     """Test that collections work with the kwargs argument."""

     class DummyMetric(Metric):
+        full_state_update = True
+
         def __init__(self):
             super().__init__()

@@ -277,6 +282,8 @@ def compute(self):
             return

     class MyAccuracy(Metric):
+        full_state_update = True
+
         def __init__(self):
             super().__init__()

@@ -292,21 +299,30 @@ def compute(self):
     mc2(torch.tensor([0, 1]), torch.tensor([0, 1]), kwarg="kwarg", kwarg2="kwarg2")


+# fixed inputs for the parametrized compute-group tests below
+_mc_preds = torch.randn(10, 3).softmax(dim=-1)
+_mc_target = torch.randint(3, (10,))
+_ml_preds = torch.rand(10, 3)
+_ml_target = torch.randint(2, (10, 3))
+
+
 @pytest.mark.parametrize(
-    "metrics, expected",
+    "metrics, expected, preds, target",
     [
         # single metric forms its own compute group
-        (Accuracy(3), {0: ["Accuracy"]}),
+        (Accuracy(3), {0: ["Accuracy"]}, _mc_preds, _mc_target),
         # two metrics of the same class form a compute group
-        ({"acc0": Accuracy(3), "acc1": Accuracy(3)}, {0: ["acc0", "acc1"]}),
+        ({"acc0": Accuracy(3), "acc1": Accuracy(3)}, {0: ["acc0", "acc1"]}, _mc_preds, _mc_target),
         # two metrics from the registry form a compute group
-        ([Precision(3), Recall(3)], {0: ["Precision", "Recall"]}),
+        ([Precision(3), Recall(3)], {0: ["Precision", "Recall"]}, _mc_preds, _mc_target),
         # two metrics from different classes give two compute groups
-        ([ConfusionMatrix(3), Recall(3)], {0: ["ConfusionMatrix"], 1: ["Recall"]}),
+        ([ConfusionMatrix(3), Recall(3)], {0: ["ConfusionMatrix"], 1: ["Recall"]}, _mc_preds, _mc_target),
         # multi group multi metric
         (
             [ConfusionMatrix(3), CohenKappa(3), Recall(3), Precision(3)],
             {0: ["ConfusionMatrix", "CohenKappa"], 1: ["Recall", "Precision"]},
+            _mc_preds,
+            _mc_target,
         ),
         # Complex example
         (
@@ -319,6 +335,33 @@ def compute(self):
                 "confmat": ConfusionMatrix(3),
             },
             {0: ["acc", "acc2", "f1", "recall"], 1: ["acc3"], 2: ["confmat"]},
+            _mc_preds,
+            _mc_target,
+        ),
+        # With list states
+        (
+            [AUROC(average="macro", num_classes=3), AveragePrecision(average="macro", num_classes=3)],
+            {0: ["AUROC", "AveragePrecision"]},
+            _mc_preds,
+            _mc_target,
+        ),
+        # Nested collections
+        (
+            [
+                MetricCollection(
+                    AUROC(average="micro", num_classes=3),
+                    AveragePrecision(average="micro", num_classes=3),
+                    postfix="_micro",
+                ),
+                MetricCollection(
+                    AUROC(average="macro", num_classes=3),
+                    AveragePrecision(average="macro", num_classes=3),
+                    postfix="_macro",
+                ),
+            ],
+            {0: ["AUROC_micro", "AveragePrecision_micro", "AUROC_macro", "AveragePrecision_macro"]},
+            _ml_preds,
+            _ml_target,
+        ),
     ],
 )
@@ -332,8 +375,10 @@ class TestComputeGroups:
             ["prefix_", "_postfix"],
         ],
     )
-    def test_check_compute_groups_correctness(self, metrics, expected, prefix, postfix):
+    def test_check_compute_groups_correctness(self, metrics, expected, preds, target, prefix, postfix):
         """Check that compute groups are formed after initialization and that metrics are correctly computed."""
+        if isinstance(metrics, MetricCollection):
+            prefix, postfix = None, None  # disable for nested collections
         m = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=True)
         # Construct without compute groups for comparison
         m2 = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=False)
@@ -342,8 +387,6 @@ def test_check_compute_groups_correctness(self, metrics, expected, prefix, postf
         assert m2.compute_groups == {}

         for _ in range(2):  # repeat to emulate effect of multiple epochs
-            preds = torch.randn(10, 3).softmax(dim=-1)
-            target = torch.randint(3, (10,))
             m.update(preds, target)
             m2.update(preds, target)

@@ -353,8 +396,6 @@ def test_check_compute_groups_correctness(self, metrics, expected, prefix, postf
         assert m.compute_groups == expected
         assert m2.compute_groups == {}

-        preds = torch.randn(10, 3).softmax(dim=-1)
-        target = torch.randint(3, (10,))
         # compute groups should kick in here
         m.update(preds, target)
         m2.update(preds, target)
@@ -372,7 +413,7 @@ def test_check_compute_groups_correctness(self, metrics, expected, prefix, postf
             m2.reset()

     @pytest.mark.parametrize("method", ["items", "values", "keys"])
-    def test_check_compute_groups_items_and_values(self, metrics, expected, method):
+    def test_check_compute_groups_items_and_values(self, metrics, expected, preds, target, method):
         """Check that whenever the user calls a method that gives access to the individual metrics, their states
         are copied instead of just passed by reference."""
         m = MetricCollection(deepcopy(metrics), compute_groups=True)
@@ -380,14 +421,12 @@ def test_check_compute_groups_items_and_values(self, metrics, expected, method):
         m2 = MetricCollection(deepcopy(metrics), compute_groups=False)

         for _ in range(2):  # repeat to emulate effect of multiple epochs
             for _ in range(2):  # repeat to emulate effect of multiple batches
-                preds = torch.randn(10, 3).softmax(dim=-1)
-                target = torch.randint(3, (10,))
                 m.update(preds, target)
                 m2.update(preds, target)

         def _compare(m1, m2):
             for state in m1._defaults:
-                assert torch.allclose(getattr(m1, state), getattr(m2, state))
+                assert _allclose_recursive(getattr(m1, state), getattr(m2, state))
             # if states are still by reference the reset will make following metrics fail
             m1.reset()
             m2.reset()
diff --git a/test/unittests/bases/test_composition.py b/test/unittests/bases/test_composition.py
index 436b01d026b..9d5fc7f4605 100644
--- a/test/unittests/bases/test_composition.py
+++ b/test/unittests/bases/test_composition.py
@@ -22,6 +22,8 @@


 class DummyMetric(Metric):
+    full_state_update = True
+
     def __init__(self, val_to_return):
         super().__init__()
         self.add_state("_num_updates", tensor(0), dist_reduce_fx="sum")
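For context on how the pieces fit together: the functional change defers the `micro`-average flattening from `_average_precision_update` to `_average_precision_compute`, so `AveragePrecision` accumulates the same unflattened list states as `AUROC`. That is what lets `MetricCollection` merge the two metrics' states in a shared compute group, and it is also why the tests switch from `torch.allclose` to `_allclose_recursive`, which can compare list-typed states. The scenario being fixed can be sketched roughly as follows (a minimal illustration against the 0.9.x API used in this diff, mirroring the `_mc_preds`/`_mc_target` fixtures above; it is not part of the patch itself):

```python
import torch

from torchmetrics import AUROC, AveragePrecision, MetricCollection

# Both metrics accumulate raw `preds`/`target` as list states, so
# MetricCollection places them in a single compute group and shares
# those states between the two metrics.
collection = MetricCollection(
    AUROC(average="macro", num_classes=3),
    AveragePrecision(average="macro", num_classes=3),
    compute_groups=True,
)

preds = torch.randn(10, 3).softmax(dim=-1)  # multi-class probabilities
target = torch.randint(3, (10,))

# Compute groups kick in from the second update onwards; before this patch
# the list states in the group were not merged, so one of the two metrics
# could end up computing on incomplete state.
for _ in range(2):
    collection.update(preds, target)

print(collection.compute())  # e.g. {"AUROC": tensor(...), "AveragePrecision": tensor(...)}
```

The `full_state_update = True` lines added to the dummy metrics in the tests are incidental to the fix; they appear to opt the test helpers into the explicit full-state update path introduced in the 0.9.x series so the tests run without the associated default-change warning.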