
Explicitly initializing tensors in float32 in MeanMetric goes against torch.set_default_dtype, leading to numerical errors #2365

Closed
viktor-ktorvi opened this issue Feb 8, 2024 · 6 comments · Fixed by #2366 or #2386
Labels
bug / fix Something isn't working help wanted Extra attention is needed v1.2.x

Comments

@viktor-ktorvi

🐛 Bug

When working in torch.float64 (set by torch.set_default_dtype(torch.float64)), the result of MeanMetric is slightly incorrect because, internally, the weight and value tensors are explicitly initialized to torch.float32 (see these lines).

Naturally, if someone is working in float64, it is usually because the problem cannot tolerate even a slight mismatch.
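The size of the discrepancy matches single-precision rounding. Independent of torch, the effect can be sketched with only the standard library by round-tripping a value through 32-bit storage (the helper name to_float32 is mine, for illustration only):

```python
import struct

def to_float32(x: float) -> float:
    """Round-trip a Python float (an IEEE-754 double) through 32-bit
    storage, mimicking a float64 value being held in a float32 tensor."""
    return struct.unpack("f", struct.pack("f", x))[0]

x = 0.01880504383120193   # a typical float64 mean
print(to_float32(x))      # differs from x in the low-order digits
print(to_float32(x) == x) # False: the value cannot survive the cast
```

This is exactly the kind of low-order error that the float32 state tensors introduce when the rest of the computation runs in float64.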

To Reproduce

import torch
from torchmetrics.aggregation import MeanMetric

torch.set_default_dtype(torch.float64)

metric = MeanMetric()

values = torch.randn(10000)
metric.update(values)
result = metric.compute()

print(f"{result} = Result\n{values.mean()} = Actual mean")

The output of the snippet is:

0.018805044555664063 = Result
0.01880504383120193 = Actual mean

Notice that the two differ. Whereas if we switch the default back to torch.float32:

0.013675613328814507 = Result
0.013675613328814507 = Actual mean

they're exactly the same.

Expected behavior

Ideally, the result of MeanMetric (and of any other metrics where this might occur) should be correct irrespective of the default dtype.

Environment

pip freeze:

torch==2.1.0+cu121
torchmetrics==1.2.1

Python 3.9

Windows 11

@viktor-ktorvi viktor-ktorvi added bug / fix Something isn't working help wanted Extra attention is needed labels Feb 8, 2024

github-actions bot commented Feb 8, 2024

Hi! Thanks for your contribution, great first issue!

@Borda Borda changed the title Explicitly initializing tensors in float32 in MeanMetric goes against torch.set_default_dtype, leading to numerical errors Explicitly initializing tensors in float32 in MeanMetric goes against torch.set_default_dtype, leading to numerical errors Feb 9, 2024
@Borda Borda added the v1.2.x label Feb 9, 2024
@Borda
Member

Borda commented Feb 9, 2024

So honoring the default if it is set would help, right? Mind sending a fix PR... 🐰

@viktor-ktorvi
Author

Something like that.

I could try; it would be my first time, but I'll give it a go.

@SkafteNicki
Member

@viktor-ktorvi I had a bit of time to look at it and it should be fixed in PR #2366

@viktor-ktorvi
Author

viktor-ktorvi commented Feb 14, 2024

Hi,

I just found the time to check this. I've upgraded and the issue still isn't fixed.

What's wrong

The example I posted above still doesn't behave as expected.

Why the test passed

The implemented test passes because it uses torch.allclose; however, the values should not merely be close, but exactly equal. There is no reason for the two to differ if all the calculations are performed in the same dtype.

Let me demonstrate why it's wrong:

import torch
from torchmetrics.aggregation import MeanMetric

torch.set_default_dtype(torch.float64)

metric = MeanMetric()

values = torch.randn(10000)
metric.update(values)
result = metric.compute()

actual_mean = values.mean()

print(f"{result} = Result\n{actual_mean} = Actual mean")

print(f"\nAll close = {torch.allclose(result, actual_mean, atol=1e-12)}")
print(f"Exactly equal = {result == actual_mean}")
The output:

-0.0041637580871582034 = Result
-0.004163758971599815 = Actual mean

All close = True
Exactly equal = False

Motivation

It might feel like I'm nitpicking, but these sorts of errors add up in complex problem formulations. For context, I'm working on approximating optimization problems with ML, and in my particular case, when casting from float64 to float32 and recalculating the values, the equality and inequality constraints are no longer fulfilled and the objective function is off.

How to fix

I've narrowed it down to the _cast_and_nan_check_input function, which gets called in update (line 564). At the end of _cast_and_nan_check_input (line 104), x.float() gets called, explicitly casting to float32, so that would need changing.
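The unconditional cast is easy to see in isolation: .float() always produces float32, regardless of the ambient default dtype.

```python
import torch

torch.set_default_dtype(torch.float64)

x = torch.randn(3)
print(x.dtype)          # torch.float64, following the default
print(x.float().dtype)  # torch.float32: .float() always casts to float32

torch.set_default_dtype(torch.float32)  # restore the usual default
```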

Additionally, lots of these

if not isinstance(x, Tensor):
    x = torch.as_tensor(x, dtype=torch.float32, device=self.device)
if weight is not None and not isinstance(weight, Tensor):
    weight = torch.as_tensor(weight, dtype=torch.float32, device=self.device)

statements, where the dtype is explicitly set, exist elsewhere, e.g., line 79 or line 559. So, each of those would need to be replaced with dtype=torch.get_default_dtype().
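A minimal sketch of the proposed replacement (the helper name as_default_dtype is mine for illustration, not torchmetrics API):

```python
import torch
from torch import Tensor

def as_default_dtype(x, device=None):
    """Wrap a non-tensor input using the ambient default dtype
    (torch.get_default_dtype()) instead of a hard-coded torch.float32."""
    if not isinstance(x, Tensor):
        x = torch.as_tensor(x, dtype=torch.get_default_dtype(), device=device)
    return x
```

Under torch.set_default_dtype(torch.float64), as_default_dtype(1.5) would then produce a float64 tensor, so the metric state follows the caller's precision.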

Finally, the test needs to check for exact equality, i.e., result == compare_function(values), instead of using torch.allclose.

Thanks for your time! @SkafteNicki @Borda

@SkafteNicki
Member

Hi @viktor-ktorvi, thanks for getting back. I was too quick on the trigger and did not correctly verify that everything was in order. Sorry about that. I have now created #2386, which should be correct. It uses a hard equality check (==) in the comparison instead of torch.allclose, as you wrote. Tensors should now be kept in whatever dtype the metric is initialized with.
