Explicitly initializing tensors in float32 in MeanMetric goes against torch.set_default_dtype, leading to numerical errors #2365
Comments
Hi! Thanks for your contribution! Great first issue!
So honoring the default dtype, if it is set, would help, right? Mind sending a fix PR? 🐰
Something like that. I could try; it would be my first time, but I'll give it a go.
@viktor-ktorvi I had a bit of time to look at it, and it should be fixed in PR #2366.
Hi, I just found the time to check this. I've upgraded, and the issue still isn't fixed.

What's wrong
The example I gave still doesn't behave as expected.

Why the test passed
The implemented test passes because of the use of …

Let me demonstrate why it's wrong:

    import torch
    from torchmetrics.aggregation import MeanMetric

    # Work in double precision globally.
    torch.set_default_dtype(torch.float64)

    metric = MeanMetric()
    values = torch.randn(10000)  # created as float64 because of the default dtype
    metric.update(values)
    result = metric.compute()
    actual_mean = values.mean()

    print(f"{result} = Result\n{actual_mean} = Actual mean")
    print(f"\nAll close = {torch.allclose(result, actual_mean, atol=1e-12)}")
    print(f"Exactly equal = {result == actual_mean}")
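A tolerance-based check can accept an error that an exact check rejects. The sketch below is a hypothetical illustration that does not call MeanMetric: it imitates a hard-coded float32 state by casting float64 values to float32 and back, then compares the two means with torch.allclose at its default tolerances (rtol=1e-5, atol=1e-8) and with torch.equal.

```python
import torch

# Hypothetical illustration: float64 values and the same values after a
# float32 round trip (imitating what a hard-coded float32 state would keep).
values = torch.tensor([0.1, 0.2, 0.3], dtype=torch.float64)
roundtrip = values.to(torch.float32).to(torch.float64)

exact_mean = values.mean()
lossy_mean = roundtrip.mean()

# Default tolerances accept the float32 rounding error (a few 1e-9 here) ...
tolerant_ok = torch.allclose(exact_mean, lossy_mean)
# ... while an exact comparison rejects it.
exact_ok = torch.equal(exact_mean, lossy_mean)

print(f"allclose with default tolerances: {tolerant_ok}")  # True
print(f"exactly equal: {exact_ok}")  # False
```

Note that torch.allclose keeps its default rtol even when only atol is passed, so a small atol on its own does not make the comparison strict.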
Motivation
It might feel like I'm nitpicking, but these sorts of errors add up in complex problem formulations. For context, I'm working on approximating optimization problems with ML, and in my particular case, when casting from …

How to fix
I've narrowed it down to the …

Additionally, there are lots of statements like these, where the dtype is set explicitly (e.g., line 79 or line 559):

    if not isinstance(x, Tensor):
        x = torch.as_tensor(x, dtype=torch.float32, device=self.device)
    if weight is not None and not isinstance(weight, Tensor):
        weight = torch.as_tensor(weight, dtype=torch.float32, device=self.device)

So, each of those would need to be replaced with …

Finally, the test needs to check for exact equality, i.e., …

Thanks for your time! @SkafteNicki @Borda
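One way to honor the default dtype, sketched under stated assumptions: look up torch.get_default_dtype() wherever the torch.as_tensor(..., dtype=torch.float32, ...) calls quoted above hard-code the dtype, and check the result with exact equality. The helper to_default_dtype_tensor and the class HypotheticalMean are hypothetical stand-ins written for this illustration; they are not TorchMetrics code, and the replacement actually proposed in the comment is not preserved in the thread.

```python
import torch
from torch import Tensor


def to_default_dtype_tensor(x, device=None) -> Tensor:
    """Hypothetical helper: wrap non-tensor input using the current default dtype
    (torch.get_default_dtype()) instead of a hard-coded torch.float32."""
    if not isinstance(x, Tensor):
        x = torch.as_tensor(x, dtype=torch.get_default_dtype(), device=device)
    return x


class HypotheticalMean:
    """Hypothetical weighted running mean whose states follow the default dtype.
    A minimal stand-in for illustration, not the TorchMetrics implementation."""

    def __init__(self) -> None:
        dtype = torch.get_default_dtype()  # looked up at construction time
        self.weighted_sum = torch.tensor(0.0, dtype=dtype)
        self.total_weight = torch.tensor(0.0, dtype=dtype)

    def update(self, value, weight=1.0) -> None:
        value = to_default_dtype_tensor(value)
        weight = to_default_dtype_tensor(weight)
        # Expand a scalar weight to one weight per value.
        weight = torch.broadcast_to(weight, value.shape)
        self.weighted_sum = self.weighted_sum + (value * weight).sum()
        self.total_weight = self.total_weight + weight.sum()

    def compute(self) -> Tensor:
        return self.weighted_sum / self.total_weight


if __name__ == "__main__":
    torch.set_default_dtype(torch.float64)
    values = torch.randn(10000)  # float64 because of the default set above
    metric = HypotheticalMean()
    metric.update(values)
    result = metric.compute()
    # Exact comparison rather than a tolerance-based one.
    print(result.dtype, torch.equal(result, values.sum() / values.numel()))
```

Reading the default dtype at construction time means a metric created after torch.set_default_dtype(torch.float64) keeps its states in float64; a metric created before the call would keep float32 states, which is a design choice the real fix may handle differently.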
Hi @viktor-ktorvi, thanks for getting back to me. I was too quick on the trigger and did not properly verify that everything was in order; sorry about that. I have now created #2386, which should be correct. It uses a hard equality check.
🐛 Bug
When working in torch.float64 (set by torch.set_default_dtype(torch.float64)), the result of MeanMetric is slightly incorrect because, internally, the weight and value tensors are explicitly initialized to torch.float32 (see these lines). Naturally, if someone is working in float64, their problem can't tolerate even a slight mismatch.
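The mechanism can be reproduced without TorchMetrics. The sketch below is a hypothetical simplification, not the library's code: mean_via_float32_state only imitates the effect of a torch.float32 state by casting the float64 inputs down before summing, while mean_in_float64 stays in double precision throughout.

```python
import torch


def mean_in_float64(values: torch.Tensor) -> torch.Tensor:
    """Reference: sum and divide entirely in float64."""
    return values.to(torch.float64).sum() / values.numel()


def mean_via_float32_state(values: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a state created as torch.float32: the values are
    cast down before summing, and the sum is cast back to float64 at the end."""
    state = values.to(torch.float32).sum()
    return state.to(torch.float64) / values.numel()


values = torch.full((3,), 0.1, dtype=torch.float64)
reference = mean_in_float64(values)
degraded = mean_via_float32_state(values)

print(f"float64 throughout: {reference.item():.17g}")
print(f"via float32 state:  {degraded.item():.17g}")
print(f"exactly equal:      {torch.equal(reference, degraded)}")
```

When the default dtype is float32, the inputs are already float32, so the down-cast in mean_via_float32_state loses nothing; that is consistent with the report that the results match exactly in that case.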
To Reproduce
The output of the snippet is:
Notice that the two differ, whereas if we switch the default to torch.float32, they're exactly the same.
Expected behavior
Ideally, MeanMetric (and possibly other metrics where this might have happened) should give the correct result irrespective of the default dtype.
Environment
pip freeze:
Python 3.9
Windows 11