From de6697a2130e0ba4742e2aa1dd268ff89f320d0d Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 5 Mar 2024 16:54:06 +0100 Subject: [PATCH 1/5] fix implementation --- src/torchmetrics/collections.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index 06b9b7b4c4e..33165fe97a0 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -547,6 +547,10 @@ def __getitem__(self, key: str, copy_state: bool = True) -> Metric: """ self._compute_groups_create_state_ref(copy_state) + if self.prefix is not None: + key = key.removeprefix(self.prefix) + if self.postfix is not None: + key = key.removesuffix(self.postfix) return self._modules[key] @staticmethod From 6ec207aa286047eeba9a5de30cc9477963db14c8 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 5 Mar 2024 16:54:54 +0100 Subject: [PATCH 2/5] tests --- tests/unittests/bases/test_collections.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index 16c95fc879a..404422fca18 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -204,6 +204,10 @@ def test_metric_collection_prefix_postfix_args(prefix, postfix): for name in names: assert f"new_prefix_{name}_new_postfix" in out, "postfix argument not working as intended with clone method" + keys = list(new_metric_collection.keys()) + for k in keys: + assert new_metric_collection[k] # check that the keys are valid even with prefix and postfix + def test_metric_collection_repr(): """Test MetricCollection.""" From 25e7b8f07fb3215ec8cd3936bda9ae848143ffd3 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 5 Mar 2024 16:56:48 +0100 Subject: [PATCH 3/5] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f2aa69ba29..63c8cafbac1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed dtype being changed by deepspeed for certain regression metrics ([#2379](https://github.com/Lightning-AI/torchmetrics/pull/2379)) +- Fix getitem for metric collection when prefix/postfix is set ([#2430](https://github.com/Lightning-AI/torchmetrics/pull/2430)) + ## [1.3.1] - 2024-02-12 ### Fixed From 3418529c7e91a0fc88354003855b6d6a39ddf60d Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Wed, 6 Mar 2024 19:12:15 +0100 Subject: [PATCH 4/5] Update src/torchmetrics/collections.py Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com> --- src/torchmetrics/collections.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index 33165fe97a0..f303845999c 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -547,9 +547,9 @@ def __getitem__(self, key: str, copy_state: bool = True) -> Metric: """ self._compute_groups_create_state_ref(copy_state) - if self.prefix is not None: + if self.prefix: key = key.removeprefix(self.prefix) - if self.postfix is not None: + if self.postfix: key = key.removesuffix(self.postfix) return self._modules[key] From f6e5aa138e9779baf40ed9877c5bf854bb78f750 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sun, 24 Mar 2024 15:21:34 +0100 Subject: [PATCH 5/5] skip test on older versions --- tests/unittests/bases/test_collections.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index 3beea475f0a..a677c92ddb1 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -33,6 +33,7 @@ MultilabelAveragePrecision, ) from torchmetrics.utilities.checks import _allclose_recursive +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_0 from unittests._helpers import seed_all from unittests._helpers.testers import DummyMetricDiff, DummyMetricMultiOutputDict, DummyMetricSum @@ -150,6 +151,7 @@ def test_metric_collection_args_kwargs(tmpdir): assert metric_collection["DummyMetricDiff"].x == -20 +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_0, reason="Test requires torch 2.0 or higher") @pytest.mark.parametrize( ("prefix", "postfix"), [