Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed compatibility between compute groups in MetricCollection and prefix/postfix arg #1008

Merged
merged 12 commits into from
May 6, 2022
Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed MAP metric when using custom list of thresholds ([#995](https://github.com/PyTorchLightning/metrics/issues/995))


- Fixed compatibility between compute groups in `MetricCollection` and prefix/postfix arg ([#1007](https://github.com/PyTorchLightning/metrics/pull/1008))


## [0.8.1] - 2022-04-27

### Changed
Expand Down
17 changes: 14 additions & 3 deletions tests/bases/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ def test_collection_check_arg():


def test_collection_filtering():
"""Test that collections works with the kwargs argument."""

class DummyMetric(Metric):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -310,11 +312,20 @@ def compute(self):
),
],
)
def test_check_compute_groups(metrics, expected):
@pytest.mark.parametrize(
"prefix, postfix",
[
[None, None],
["prefix_", None],
[None, "_postfix"],
["prefix_", "_postfix"],
],
)
def test_check_compute_groups(metrics, expected, prefix, postfix):
"""Check that compute groups are formed after initialization."""
m = MetricCollection(deepcopy(metrics), compute_groups=True)
m = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=True)
# Construct without for comparison
m2 = MetricCollection(deepcopy(metrics), compute_groups=False)
m2 = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=False)

assert len(m.compute_groups) == len(m)
assert m2.compute_groups == {}
Expand Down
2 changes: 1 addition & 1 deletion torchmetrics/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def _init_compute_groups(self) -> None:
self._groups_checked = True
else:
# Initialize all metrics as their own compute group
self._groups = {i: [str(k)] for i, k in enumerate(self.keys(keep_base=False))}
self._groups = {i: [str(k)] for i, k in enumerate(self.keys(keep_base=True))}

@property
def compute_groups(self) -> Dict[int, List[str]]:
Expand Down