Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bugfix] MetricCollection should return metrics with prefix on items(), keys() #209

Merged
merged 30 commits into from
May 4, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
b5d69b7
update
tchaton Apr 28, 2021
137da8e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 28, 2021
c7a4008
update print
tchaton Apr 28, 2021
5ae8d8e
resolve flake8
tchaton Apr 28, 2021
8a5156f
Merge branch 'metrics_collection' of https://github.com/PyTorchLightn…
tchaton Apr 28, 2021
6af03f3
update
tchaton Apr 28, 2021
953e8ab
Merge branch 'master' into metrics_collection
SkafteNicki Apr 28, 2021
5deb58f
Merge branch 'master' into metrics_collection
Borda Apr 28, 2021
452ccd8
CI: update pre-commit (#207)
Borda Apr 28, 2021
77614f6
fix setup imports (#208)
Borda Apr 28, 2021
9ed3315
chlog
Borda Apr 28, 2021
fcadfc0
Merge branch 'metrics_collection' of https://github.com/PyTorchLightn…
Borda Apr 28, 2021
10e1c84
format
Borda Apr 28, 2021
76c5f49
update
tchaton Apr 28, 2021
599ee65
Merge branch 'metrics_collection' of https://github.com/PyTorchLightn…
tchaton Apr 28, 2021
59e6ab8
Merge branch 'master' into metrics_collection
Borda Apr 28, 2021
d1c0fdd
update
tchaton Apr 29, 2021
dbb565a
Merge branch 'metrics_collection' of https://github.com/PyTorchLightn…
tchaton Apr 29, 2021
dc9d12a
check type
tchaton Apr 29, 2021
ab19250
update on comments
tchaton Apr 29, 2021
0076b69
update on comments
tchaton Apr 29, 2021
685061b
Merge branch 'master' into metrics_collection
tchaton Apr 30, 2021
0a88348
Merge branch 'master' into metrics_collection
tchaton Apr 30, 2021
ebe9d98
Merge branch 'master' into metrics_collection
Borda Apr 30, 2021
281d22d
Merge branch 'master' into metrics_collection
mergify[bot] Apr 30, 2021
79fcf67
Merge branch 'master' into metrics_collection
mergify[bot] Apr 30, 2021
12e90fc
Merge branch 'master' into metrics_collection
mergify[bot] May 3, 2021
b349ae6
Merge branch 'master' into metrics_collection
Borda May 3, 2021
2f2dfae
Merge branch 'master' into metrics_collection
mergify[bot] May 3, 2021
a1596a6
Merge branch 'master' into metrics_collection
mergify[bot] May 3, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions tests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,21 @@ def test_metric_collection_prefix_postfix_args(prefix, postfix):
for name in names:
assert f"new_prefix_{name}" in out, 'prefix argument not working as intended with clone method'

for k, _ in new_metric_collection.items():
assert 'new_prefix_' in k

for k in new_metric_collection.keys():
assert 'new_prefix_' in k

for k, _ in new_metric_collection.items(keep_base=True):
assert 'new_prefix_' not in k

for k in new_metric_collection.keys(keep_base=True):
assert 'new_prefix_' not in k

repr = f'MetricCollection(\n prefix=new_prefix_,\n postfix={postfix},\n (DummyMetricSum): DummyMetricSum()\n (DummyMetricDiff): DummyMetricDiff()\n)'
assert new_metric_collection.__repr__() == repr

new_metric_collection = new_metric_collection.clone(postfix='_new_postfix')
out = new_metric_collection(5)
names = [n[:-len(postfix)] if postfix is not None else n for n in names] # strip away old postfix
Expand Down
37 changes: 31 additions & 6 deletions torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# limitations under the License.

from copy import deepcopy
from typing import Any, Dict, Optional, Sequence, Union
from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union

from torch import nn
from torch._jit_internal import _copy_to_script_wrapper

from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
Expand Down Expand Up @@ -144,24 +145,24 @@ def forward(self, *args, **kwargs) -> Dict[str, Any]: # pylint: disable=E0202
be passed to every metric in the collection, while keyword arguments (kwargs)
will be filtered based on the signature of the individual metric.
"""
return {self._set_name(k): m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}
return {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items()}

def update(self, *args, **kwargs): # pylint: disable=E0202
"""
Iteratively call update for each metric. Positional arguments (args) will
be passed to every metric in the collection, while keyword arguments (kwargs)
will be filtered based on the signature of the individual metric.
"""
for _, m in self.items():
for _, m in self.items(keep_base=True):
m_kwargs = m._filter_kwargs(**kwargs)
m.update(*args, **m_kwargs)

def compute(self) -> Dict[str, Any]:
return {self._set_name(k): m.compute() for k, m in self.items()}
return {k: m.compute() for k, m in self.items()}

def reset(self) -> None:
""" Iteratively call reset for each metric """
for _, m in self.items():
for _, m in self.items(keep_base=True):
m.reset()

def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> 'MetricCollection':
Expand All @@ -182,16 +183,40 @@ def persistent(self, mode: bool = True) -> None:
"""Method for post-init to change if metric states should be saved to
its state_dict
"""
for _, m in self.items():
for _, m in self.items(keep_base=True):
m.persistent(mode)

def _set_name(self, base: str) -> str:
name = base if self.prefix is None else self.prefix + base
name = name if self.postfix is None else name + self.postfix
return name

@_copy_to_script_wrapper
def keys(self, keep_base: bool = False):
r"""Return an iterable of the ModuleDict key.
Args:
keep_base: Whether to add prefix/postfix on the items collection.
"""
if keep_base:
return self._modules.keys()
return [self._set_name(k) for k in self._modules.keys()]
carmocca marked this conversation as resolved.
Show resolved Hide resolved

@_copy_to_script_wrapper
def items(self, keep_base: bool = False) -> Iterable[Tuple[str, nn.Module]]:
r"""Return an iterable of the ModuleDict key/value pairs.
Args:
keep_base: Whether to add prefix/postfix on the items collection.
"""
if keep_base:
return self._modules.items()
return [(self._set_name(k), v) for k, v in self._modules.items()]

@staticmethod
def _check_arg(arg: Optional[str], name: str) -> Optional[str]:
if arg is None or isinstance(arg, str):
return arg
raise ValueError(f'Expected input `{name}` to be a string, but got {type(arg)}')

def extra_repr(self) -> Optional[str]:
if self.prefix or self.postfix:
return f"prefix={self.prefix},\npostfix={self.postfix},"