Skip to content

Commit

Permalink
fix: pearson changes inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Oct 1, 2024
1 parent 20b4d3a commit fcd42e9
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
3 changes: 2 additions & 1 deletion src/torchmetrics/functional/image/rmse_sw.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ def _rmse_sw_compute(
"""
rmse = rmse_val_sum / total_images if rmse_val_sum is not None else None
if rmse_map is not None:
rmse_map /= total_images
# prevent overwrite the inputs
rmse_map = rmse_map / total_images
return rmse, rmse_map


Expand Down
7 changes: 4 additions & 3 deletions src/torchmetrics/functional/regression/pearson.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,10 @@ def _pearson_corrcoef_compute(
nb: number of observations
"""
var_x /= nb - 1
var_y /= nb - 1
corr_xy /= nb - 1
# prevent overwrite the inputs
var_x = var_x / (nb - 1)
var_y = var_y / (nb - 1)
corr_xy = corr_xy / (nb - 1)
# if var_x, var_y is float16 and on cpu, make it bfloat16 as sqrt is not supported for float16
# on cpu, remove this after https://github.com/pytorch/pytorch/issues/54774 is fixed
if var_x.dtype == torch.float16 and var_x.device == torch.device("cpu"):
Expand Down
20 changes: 20 additions & 0 deletions tests/unittests/regression/test_pearson.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,23 @@ def test_single_sample_update():
metric(torch.tensor([7.0]), torch.tensor([8.0]))
res2 = metric.compute()
assert torch.allclose(res1, res2)

def test_overwrite_reference_inputs():
"""Test that the normalizations does not overwrite inputs.
Variables var_x, var_y, corr_xy are references to the object variables and get incorrectly scaled down
such that when you update again and compute you get very wrong values.
"""
y = torch.randn(100)
y_pred = y + torch.randn(y.shape) / 5
# Initialize Pearson correlation coefficient metric
pearson = PearsonCorrCoef()
# Compute the Pearson correlation coefficient
correlation = pearson(y, y_pred)

pearson = PearsonCorrCoef()
for lower, upper in [(0, 33), (33, 66), (66, 99), (99, 100)]:
pearson.update(torch.tensor(y[lower:upper]), torch.tensor(y_pred[lower:upper]))
pearson.compute()

assert torch.isclose(pearson.compute(), correlation)

0 comments on commit fcd42e9

Please sign in to comment.