Skip to content

Commit

Permalink
Fixed compatibility between compute groups in MetricCollection and …
Browse files Browse the repository at this point in the history
…prefix/postfix arg (#1008)
  • Loading branch information
SkafteNicki authored May 6, 2022
1 parent 91ab307 commit 06fdb04
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed MAP metric when using custom list of thresholds ([#995](https://github.com/PyTorchLightning/metrics/issues/995))


- Fixed compatibility between compute groups in `MetricCollection` and prefix/postfix arg ([#1007](https://github.com/PyTorchLightning/metrics/pull/1008))


## [0.8.1] - 2022-04-27

### Changed
Expand Down
17 changes: 14 additions & 3 deletions tests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ def test_collection_check_arg():


def test_collection_filtering():
"""Test that collections works with the kwargs argument."""

class DummyMetric(Metric):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -310,11 +312,20 @@ def compute(self):
),
],
)
def test_check_compute_groups(metrics, expected):
@pytest.mark.parametrize(
"prefix, postfix",
[
[None, None],
["prefix_", None],
[None, "_postfix"],
["prefix_", "_postfix"],
],
)
def test_check_compute_groups(metrics, expected, prefix, postfix):
"""Check that compute groups are formed after initialization."""
m = MetricCollection(deepcopy(metrics), compute_groups=True)
m = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=True)
# Construct without for comparison
m2 = MetricCollection(deepcopy(metrics), compute_groups=False)
m2 = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=False)

assert len(m.compute_groups) == len(m)
assert m2.compute_groups == {}
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def _init_compute_groups(self) -> None:
self._groups_checked = True
else:
# Initialize all metrics as their own compute group
self._groups = {i: [str(k)] for i, k in enumerate(self.keys(keep_base=False))}
self._groups = {i: [str(k)] for i, k in enumerate(self.keys(keep_base=True))}

@property
def compute_groups(self) -> Dict[int, List[str]]:
Expand Down

0 comments on commit 06fdb04

Please sign in to comment.