diff --git a/CHANGELOG.md b/CHANGELOG.md index 537d5cf61f7..dd66310e14a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Fix getitem for metric collection when prefix/postfix is set ([#2430](https://github.com/Lightning-AI/torchmetrics/pull/2430)) + + - Fixed axis names with Precision-Recall curve ([#2462](https://github.com/Lightning-AI/torchmetrics/pull/2462)) diff --git a/src/torchmetrics/collections.py b/src/torchmetrics/collections.py index d6ad1287c58..e4b0dbafd2a 100644 --- a/src/torchmetrics/collections.py +++ b/src/torchmetrics/collections.py @@ -547,6 +547,10 @@ def __getitem__(self, key: str, copy_state: bool = True) -> Metric: """ self._compute_groups_create_state_ref(copy_state) + if self.prefix: + key = key.removeprefix(self.prefix) + if self.postfix: + key = key.removesuffix(self.postfix) return self._modules[key] @staticmethod diff --git a/tests/unittests/bases/test_collections.py b/tests/unittests/bases/test_collections.py index 0e124125509..a677c92ddb1 100644 --- a/tests/unittests/bases/test_collections.py +++ b/tests/unittests/bases/test_collections.py @@ -33,6 +33,7 @@ MultilabelAveragePrecision, ) from torchmetrics.utilities.checks import _allclose_recursive +from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_0 from unittests._helpers import seed_all from unittests._helpers.testers import DummyMetricDiff, DummyMetricMultiOutputDict, DummyMetricSum @@ -150,6 +151,7 @@ def test_metric_collection_args_kwargs(tmpdir): assert metric_collection["DummyMetricDiff"].x == -20 +@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_0, reason="Test requires torch 2.0 or higher") @pytest.mark.parametrize( ("prefix", "postfix"), [ @@ -204,6 +206,10 @@ def test_metric_collection_prefix_postfix_args(prefix, postfix): for name in names: assert f"new_prefix_{name}_new_postfix" in out, "postfix argument not working as intended with clone method" + keys = list(new_metric_collection.keys()) + for k in keys: + assert new_metric_collection[k] # check that the keys are valid even with prefix and postfix + def test_metric_collection_repr(): """Test MetricCollection."""