Fix incorrect caching of MetricCollection #2571

Merged
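For context, a minimal reproduction sketch of the bug (the chosen metrics and tensors are illustrative assumptions, not taken from the PR): when metrics in a `MetricCollection` are merged into compute groups, only the first member of each group receives the real `update()` call, so the other members' cached `_computed` values were never invalidated and a later `compute()` could serve stale results.

```python
# Minimal sketch of the stale-cache scenario this PR fixes; the metrics
# and values below are illustrative, not taken from the PR itself.
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision

# Both metrics share the same underlying state, so they are merged
# into a single compute group inside the collection.
collection = MetricCollection([BinaryAccuracy(), BinaryPrecision()])
preds, target = torch.tensor([1, 0, 1, 1]), torch.tensor([1, 0, 0, 1])

collection.update(preds, target)
first = collection.compute()   # each member now caches its result in `_computed`

collection.update(1 - preds, target)
second = collection.compute()  # before this fix, stale cached values could be
                               # returned here instead of freshly computed ones

print(first, second, sep="\n")
```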
6 changes: 5 additions & 1 deletion src/torchmetrics/collections.py
@@ -206,6 +206,11 @@ def update(self, *args: Any, **kwargs: Any) -> None:
         """
         # Use compute groups if already initialized and checked
         if self._groups_checked:
+            # Invalidate the cached result of every metric so that recent compute calls are
+            # discarded and any new compute call recomputes from the updated state
+            for k in self.keys(keep_base=True):
+                mi = getattr(self, k)
+                mi._computed = None
             for cg in self._groups.values():
                 # only update the first member
                 m0 = getattr(self, cg[0])
@@ -304,7 +309,6 @@ def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
                         # Determine if we just should set a reference or a full copy
                         setattr(mi, state, deepcopy(m0_state) if copy else m0_state)
                     mi._update_count = deepcopy(m0._update_count) if copy else m0._update_count
-                    mi._computed = deepcopy(m0._computed) if copy else m0._computed
         self._state_is_copy = copy

     def compute(self) -> Dict[str, Any]:
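The loop added above runs before the compute-group update. Only the first member of each group receives the actual `update()` call (which resets that metric's own cache), so every other member would otherwise keep a stale `_computed`. A standalone sketch of the same invalidation step, assuming `collection` is a `MetricCollection` whose compute groups are already established:

```python
# Drop every member's cached result so subsequent compute() calls recompute
# from the freshly updated states; `collection` is an assumed MetricCollection.
for name in collection.keys(keep_base=True):
    metric = getattr(collection, name)
    metric._computed = None
```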
5 changes: 3 additions & 2 deletions src/torchmetrics/metric.py
@@ -619,7 +619,7 @@ def wrapped_func(*args: Any, **kwargs: Any) -> Any:

             # return cached value
             if self._computed is not None:
-                return self._computed
+                return deepcopy(self._computed)

             # compute relies on the sync context manager to gather the states across processes and apply reduction
             # if synchronization happened, the current rank accumulated states will be restored to keep
@@ -634,7 +634,8 @@ def wrapped_func(*args: Any, **kwargs: Any) -> Any:
             if self.compute_with_cache:
                 self._computed = value

-            return value
+            # Return a deep copy to avoid side effects for non-scalar values, e.g. ConfusionMatrix
+            return deepcopy(value)

         return wrapped_func
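The two `deepcopy` calls guard against callers mutating the cached result in place, which matters for metrics whose result is a tensor rather than a scalar. A minimal sketch of the hazard, with an illustrative confusion-matrix metric and made-up inputs:

```python
# Sketch of the aliasing hazard that returning a deep copy prevents;
# the metric choice and inputs are illustrative assumptions.
import torch
from torchmetrics.classification import BinaryConfusionMatrix

metric = BinaryConfusionMatrix()
metric.update(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 0]))

result = metric.compute()
result += 1  # in-place edit: without the deepcopy this would also mutate the
             # cached `metric._computed` tensor, corrupting later compute() calls

print(metric.compute())  # with this PR the cache is untouched by the edit above
```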
19 changes: 19 additions & 0 deletions tests/unittests/bases/test_collections.py
@@ -452,10 +452,29 @@ def test_check_compute_groups_correctness(self, metrics, expected, preds, target):
             for key in res_cg:
                 assert torch.allclose(res_cg[key], res_without_cg[key])

+            # Check that a second compute returns the same result
+            res_cg2 = m.compute()
+            for key in res_cg2:
+                assert torch.allclose(res_cg[key], res_cg2[key])
+
             if with_reset:
                 m.reset()
                 m2.reset()

+        # Test that a second compute without a reset differs after new updates
+        m.reset()
+        m.update(preds, target)
+        res_cg = m.compute()
+        # Simulate different preds by simply inverting them
+        m.update(1 - preds, target)
+        res_cg2 = m.compute()
+        # Now check that the results from the first compute differ from the second
+        for key in res_cg:
+            # A different shape is okay, therefore skip (this happens for multidim_average="samplewise")
+            if res_cg[key].shape != res_cg2[key].shape:
+                continue
+            assert not torch.all(res_cg[key] == res_cg2[key])
+
     @pytest.mark.parametrize("method", ["items", "values", "keys"])
     def test_check_compute_groups_items_and_values(self, metrics, expected, preds, target, method):
         """Check states are copied instead of passed by ref when a single metric in the collection is accessed."""