
fixes metric hashing #478

Merged
merged 11 commits on Aug 24, 2021
23 changes: 23 additions & 0 deletions tests/bases/test_hashing.py
@@ -0,0 +1,23 @@
import pytest

from tests.helpers.testers import DummyMetric, DummyNonTensorMetric

@pytest.mark.parametrize(
    "metric_cls",
    [
        DummyMetric,
        DummyNonTensorMetric,
    ],
)
def test_metric_hashing(metric_cls):
    """Test that hashes of two separate instances are different.

    See the Metric's __hash__ function for details on why this is required.
    """
    instance_1 = metric_cls()
    instance_2 = metric_cls()

    assert hash(instance_1) != hash(instance_2)
    assert id(instance_1) != id(instance_2)
14 changes: 14 additions & 0 deletions tests/helpers/testers.py
@@ -511,6 +511,20 @@ def compute(self):
        pass


class DummyNonTensorMetric(Metric):
    name = "DummyNonTensor"

    def __init__(self):
        super().__init__()
        self.add_state("x", [], dist_reduce_fx=None)

    def update(self):
        pass

    def compute(self):
        pass


class DummyListMetric(Metric):
    name = "DummyList"

7 changes: 6 additions & 1 deletion torchmetrics/metric.py
@@ -502,7 +502,12 @@ def _filter_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        return filtered_kwargs

    def __hash__(self) -> int:
        hash_vals = [self.__class__.__name__]
        # We need to add the id here, since PyTorch requires the hash of a module to be unique;
        # it relies on that for child-module discovery
        # (see https://github.com/pytorch/pytorch/blob/v1.9.0/torch/nn/modules/module.py#L1544).
        # For metrics whose states include tensors this is not a problem, since their hash is
        # already unique based on the tensors' memory locations, but we cannot rely on that for
        # every metric.
        hash_vals = [self.__class__.__name__, id(self)]

        for key in self._defaults:
            val = getattr(self, key)
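
For context, here is a minimal, hypothetical sketch of why id(self) is added (it is not the library's actual implementation): a metric whose only state is a plain Python list has nothing memory-location-specific contributing to its hash, so two fresh instances would collide, whereas including id() makes every instance hash unique. The ListStateMetric class and hash_without_id helper below are made up purely for illustration.

# Hypothetical illustration only -- not torchmetrics code.
class ListStateMetric:
    """A metric-like object whose only state is a plain list (no tensors)."""

    def __init__(self):
        self.x = []  # non-tensor state

    def hash_without_id(self):
        # pre-fix behaviour: hash built only from the class name and state values
        return hash(tuple([self.__class__.__name__, *self.x]))

    def __hash__(self):
        # post-fix behaviour: id(self) makes every instance hash unique
        return hash(tuple([self.__class__.__name__, id(self), *self.x]))


m1, m2 = ListStateMetric(), ListStateMetric()
assert m1.hash_without_id() == m2.hash_without_id()  # identical state -> collision
assert hash(m1) != hash(m2)  # unique per instance once id(self) is included

This mirrors what tests/bases/test_hashing.py above asserts for the real Metric subclasses.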