From 99cbfa465261cf2562cbf2aa1676377aa28dbc48 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 13 Jul 2023 10:28:37 +0200 Subject: [PATCH 1/4] fix --- src/torchmetrics/wrappers/classwise.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/torchmetrics/wrappers/classwise.py b/src/torchmetrics/wrappers/classwise.py index 3e2d8eb7fc2..4d9530b1361 100644 --- a/src/torchmetrics/wrappers/classwise.py +++ b/src/torchmetrics/wrappers/classwise.py @@ -128,22 +128,22 @@ def __init__( if prefix is not None and not isinstance(prefix, str): raise ValueError(f"Expected argument `prefix` to either be `None` or a string but got {prefix}") - self.prefix = prefix + self._prefix = prefix if postfix is not None and not isinstance(postfix, str): raise ValueError(f"Expected argument `postfix` to either be `None` or a string but got {postfix}") - self.postfix = postfix + self._postfix = postfix self._update_count = 1 def _convert(self, x: Tensor) -> Dict[str, Any]: # Will set the class name as prefix if neither prefix nor postfix is given - if not self.prefix and not self.postfix: + if not self._prefix and not self._postfix: prefix = f"{self.metric.__class__.__name__.lower()}_" postfix = "" else: - prefix = self.prefix or "" - postfix = self.postfix or "" + prefix = self._prefix or "" + postfix = self._postfix or "" if self.labels is None: return {f"{prefix}{i}{postfix}": val for i, val in enumerate(x)} return {f"{prefix}{lab}{postfix}": val for lab, val in zip(self.labels, x)} From fb4b30aeedebe54c24416f11639a58fe67dcf0b1 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 13 Jul 2023 10:29:14 +0200 Subject: [PATCH 2/4] tests --- tests/unittests/wrappers/test_classwise.py | 31 +++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/tests/unittests/wrappers/test_classwise.py b/tests/unittests/wrappers/test_classwise.py index 88e71c411f7..b18efe7e1f5 100644 --- a/tests/unittests/wrappers/test_classwise.py +++ b/tests/unittests/wrappers/test_classwise.py @@ -1,7 +1,7 @@ import pytest import torch from torchmetrics import MetricCollection -from torchmetrics.classification import MulticlassAccuracy, MulticlassRecall +from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score, MulticlassRecall from torchmetrics.wrappers import ClasswiseWrapper @@ -120,3 +120,32 @@ def _get_correct_name(base): assert name in val name = _get_correct_name(f"multiclassrecall_{lab}") assert name in val + + +def test_double_use_of_prefix_with_metriccollection(): + """Test that the expected output is produced when using prefix/postfix with metric collection. + + See issue: https://github.com/Lightning-AI/torchmetrics/issues/1915 + """ + category_names = ["Tree", "Bush"] + num_classes = len(category_names) + + input_ = torch.rand((5, 2, 3, 3)) + target = torch.ones((5, 2, 3, 3)).long() + + val_metrics = MetricCollection( + { + "accuracy": MulticlassAccuracy(num_classes=num_classes), + "f1": ClasswiseWrapper( + MulticlassF1Score(num_classes=num_classes, average="none"), + category_names, + prefix="f_score_", + ), + }, + prefix="val/", + ) + + res = val_metrics(input_, target) + assert "val/accuracy" in res + assert "val/f_score_Tree" in res + assert "val/f_score_Bush" in res From 4cb768f93aa9af9dc23369d3199dc9eaeef84e26 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Thu, 13 Jul 2023 10:32:47 +0200 Subject: [PATCH 3/4] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index a02577387dd..87192fbf8ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed bug related to expected input format of pycoco in `MeanAveragePrecision` ([#1913](https://github.com/Lightning-AI/torchmetrics/pull/1913)) +- Fixed bug related to the `prefix/postfix` arguments in `MetricCollection` and `ClasswiseWrapper` being duplicated ([#1918](https://github.com/Lightning-AI/torchmetrics/pull/1918)) + ## [1.0.0] - 2022-07-04 ### Added From 85888e20ecc2cd45a7874f4b83eb750f3df94405 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 25 Jul 2023 09:51:54 +0200 Subject: [PATCH 4/4] docformatter --- tests/unittests/wrappers/test_classwise.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unittests/wrappers/test_classwise.py b/tests/unittests/wrappers/test_classwise.py index b18efe7e1f5..5c72362b545 100644 --- a/tests/unittests/wrappers/test_classwise.py +++ b/tests/unittests/wrappers/test_classwise.py @@ -126,6 +126,7 @@ def test_double_use_of_prefix_with_metriccollection(): """Test that the expected output is produced when using prefix/postfix with metric collection. See issue: https://github.com/Lightning-AI/torchmetrics/issues/1915 + """ category_names = ["Tree", "Bush"] num_classes = len(category_names)