fixes metric hashing #478

Merged (11 commits), Aug 24, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -36,6 +36,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Fixed bug where compositional metrics were unable to sync because of type mismatch ([#454](https://github.com/PyTorchLightning/metrics/pull/454))


- Fixed metric hashing ([#478](https://github.com/PyTorchLightning/metrics/pull/478))


## [0.5.0] - 2021-08-09

### Added
22 changes: 22 additions & 0 deletions tests/bases/test_hashing.py
@@ -0,0 +1,22 @@
import pytest

from tests.helpers.testers import DummyListMetric, DummyMetric


@pytest.mark.parametrize(
"metric_cls",
[
DummyMetric,
DummyListMetric,
],
)
def test_metric_hashing(metric_cls):
"""Tests that hases are different.

See the Metric's hash function for details on why this is required.
"""
instance_1 = metric_cls()
instance_2 = metric_cls()

assert hash(instance_1) != hash(instance_2)
assert id(instance_1) != id(instance_2)
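The same property can be exercised against any public metric class, not just the Dummy* test helpers used above. A minimal sketch, assuming a torchmetrics version where Accuracy can be constructed with its default arguments (the class is used here purely for illustration):

```python
from torchmetrics import Accuracy

# Two fresh instances of the same class with identical (empty) state.
a1, a2 = Accuracy(), Accuracy()

# Because Metric.__hash__ now mixes in id(self), the instances no longer
# collide even though their class and state are the same.
assert hash(a1) != hash(a2)
assert id(a1) != id(a2)
```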
2 changes: 1 addition & 1 deletion tests/bases/test_metric.py
@@ -160,7 +160,7 @@ class B(DummyListMetric):

b1 = B()
b2 = B()
assert hash(b1) == hash(b2)
assert hash(b1) != hash(b2) # different ids
assert isinstance(b1.x, list) and len(b1.x) == 0
b1.x.append(tensor(5))
assert isinstance(hash(b1), int) # <- check that nothing crashes
7 changes: 6 additions & 1 deletion torchmetrics/metric.py
@@ -502,7 +502,12 @@ def _filter_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
return filtered_kwargs

def __hash__(self) -> int:
hash_vals = [self.__class__.__name__]
# we need to add the id here, since PyTorch requires a module hash to be unique.
# Internally, PyTorch nn.Module relies on that for children discovery
# (see https://github.com/pytorch/pytorch/blob/v1.9.0/torch/nn/modules/module.py#L1544)
# For metrics whose state includes tensors this is not a problem, since their hash
# is already unique based on the tensors' memory location, but we cannot rely on that for every metric.
hash_vals = [self.__class__.__name__, id(self)]

for key in self._defaults:
val = getattr(self, key)
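To make the reasoning in the comment above concrete, here is a self-contained sketch of the idea outside the library; TinyMetric and its state names are illustrative only and not part of torchmetrics:

```python
class TinyMetric:
    """Toy stand-in for a metric whose state is plain Python (no tensors)."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def __hash__(self) -> int:
        # Hashing only the class name and state (the old behaviour) makes two
        # freshly constructed instances collide, since their state is equal.
        # Mixing in id(self) (the new behaviour) gives every live instance a
        # distinct hash, which is what nn.Module's bookkeeping expects.
        return hash((self.__class__.__name__, id(self), self.correct, self.total))


m1, m2 = TinyMetric(), TinyMetric()
assert hash(m1) != hash(m2)  # same class, same state, still distinct hashes
```

The real implementation additionally folds the registered state values into hash_vals, as the loop over self._defaults above shows, so tensor-backed states keep contributing to the hash as before.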