diff --git a/CHANGELOG.md b/CHANGELOG.md index d194c345e5d..a20b8e622d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- `MetricCollection` should return metrics with prefix on `items()`, `keys()` ([#209](https://github.com/PyTorchLightning/metrics/pull/209)) + + - Calling `compute` before `update` will now give an warning ([#164](https://github.com/PyTorchLightning/metrics/pull/164)) diff --git a/tests/bases/test_collections.py b/tests/bases/test_collections.py index cb3d3d2c4b1..8ee8dae77f3 100644 --- a/tests/bases/test_collections.py +++ b/tests/bases/test_collections.py @@ -168,6 +168,21 @@ def test_metric_collection_prefix_postfix_args(prefix, postfix): for name in names: assert f"new_prefix_{name}" in out, 'prefix argument not working as intended with clone method' + for k, _ in new_metric_collection.items(): + assert 'new_prefix_' in k + + for k in new_metric_collection.keys(): + assert 'new_prefix_' in k + + for k, _ in new_metric_collection.items(keep_base=True): + assert 'new_prefix_' not in k + + for k in new_metric_collection.keys(keep_base=True): + assert 'new_prefix_' not in k + + assert type(new_metric_collection.keys(keep_base=True)) == type(new_metric_collection.keys(keep_base=False)) # noqa E721 + assert type(new_metric_collection.items(keep_base=True)) == type(new_metric_collection.items(keep_base=False)) # noqa E721 + new_metric_collection = new_metric_collection.clone(postfix='_new_postfix') out = new_metric_collection(5) names = [n[:-len(postfix)] if postfix is not None else n for n in names] # strip away old postfix @@ -175,6 +190,38 @@ def test_metric_collection_prefix_postfix_args(prefix, postfix): assert f"new_prefix_{name}_new_postfix" in out, 'postfix argument not working as intended with clone method' +def test_metric_collection_repr(): + """ + Test MetricCollection + """ + + class A(DummyMetricSum): + pass + + class B(DummyMetricDiff): + pass + + m1 = A() + m2 = B() + metric_collection = MetricCollection([m1, m2], prefix=None, postfix=None) + + expected = "MetricCollection(\n (A): A()\n (B): B()\n)" + assert metric_collection.__repr__() == expected + + metric_collection = MetricCollection([m1, m2], prefix="a", postfix=None) + + expected = 'MetricCollection(\n (A): A()\n (B): B(),\n prefix=a\n)' + assert metric_collection.__repr__() == expected + + metric_collection = MetricCollection([m1, m2], prefix=None, postfix="a") + expected = 'MetricCollection(\n (A): A()\n (B): B(),\n postfix=a\n)' + assert metric_collection.__repr__() == expected + + metric_collection = MetricCollection([m1, m2], prefix="a", postfix="b") + expected = 'MetricCollection(\n (A): A()\n (B): B(),\n prefix=a,\n postfix=b\n)' + assert metric_collection.__repr__() == expected + + def test_metric_collection_same_order(): m1 = DummyMetricSum() m2 = DummyMetricDiff() diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py index fd7ce8945b9..105b3d282ad 100644 --- a/torchmetrics/collections.py +++ b/torchmetrics/collections.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import OrderedDict from copy import deepcopy -from typing import Any, Dict, Optional, Sequence, Union +from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union from torch import nn @@ -144,7 +145,7 @@ def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202 be passed to every metric in the collection, while keyword arguments (kwargs) will be filtered based on the signature of the individual metric. """ - return {self._set_name(k): m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()} + return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()} def update(self, *args, **kwargs): # pylint: disable=E0202 """ @@ -152,16 +153,16 @@ def update(self, *args, **kwargs): # pylint: disable=E0202 be passed to every metric in the collection, while keyword arguments (kwargs) will be filtered based on the signature of the individual metric. """ - for _, m in self.items(): + for _, m in self.items(keep_base=True): m_kwargs = m._filter_kwargs(**kwargs) m.update(*args, **m_kwargs) def compute(self) -> Dict[str, Any]: - return {self._set_name(k): m.compute() for k, m in self.items()} + return {k: m.compute() for k, m in self.items()} def reset(self) -> None: """ Iteratively call reset for each metric """ - for _, m in self.items(): + for _, m in self.items(keep_base=True): m.reset() def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> 'MetricCollection': @@ -182,7 +183,7 @@ def persistent(self, mode: bool = True) -> None: """Method for post-init to change if metric states should be saved to its state_dict """ - for _, m in self.items(): + for _, m in self.items(keep_base=True): m.persistent(mode) def _set_name(self, base: str) -> str: @@ -190,8 +191,40 @@ def _set_name(self, base: str) -> str: name = name if self.postfix is None else name + self.postfix return name + def _to_renamed_ordered_dict(self) -> OrderedDict: + od = OrderedDict() + for k, v in self._modules.items(): + od[self._set_name(k)] = v + return od + + def keys(self, keep_base: bool = False): + r"""Return an iterable of the ModuleDict key. + Args: + keep_base: Whether to add prefix/postfix on the items collection. + """ + if keep_base: + return self._modules.keys() + return self._to_renamed_ordered_dict().keys() + + def items(self, keep_base: bool = False) -> Iterable[Tuple[str, nn.Module]]: + r"""Return an iterable of the ModuleDict key/value pairs. + Args: + keep_base: Whether to add prefix/postfix on the items collection. + """ + if keep_base: + return self._modules.items() + return self._to_renamed_ordered_dict().items() + @staticmethod def _check_arg(arg: Optional[str], name: str) -> Optional[str]: if arg is None or isinstance(arg, str): return arg raise ValueError(f'Expected input `{name}` to be a string, but got {type(arg)}') + + def __repr__(self) -> Optional[str]: + repr = super().__repr__()[:-2] + if self.prefix: + repr += f",\n prefix={self.prefix}{',' if self.postfix else ''}" + if self.postfix: + repr += f"{',' if not self.prefix else ''}\n postfix={self.postfix}" + return repr + "\n)"