MMD metric occasionally returns nan
#855
Comments
It appears the root of the problem is catastrophic cancellation, which happens when subtracting two approximations that are very close in value. The result of the subtraction can have a very large relative error even if the individual approximations are good. For example, take two numbers with 8 significant figures that coincide in the first 5 s.f.: 0.20802966 and 0.20802812. Their difference is 0.00000154, so we went from 8 s.f. to 3 s.f. If the original numbers are approximations, the result of the subtraction can vary a lot. In our case we have

I am still not sure what the best solution here would be. Some ideas:
Logs for the example above:
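To make the cancellation argument in this comment concrete, here is a minimal sketch using the two numbers quoted above. The perturbation `eps` is an assumption for illustration (it pretends each value is only accurate to its last printed digit); it is not taken from the actual MMD computation.

```python
# The two values from the comment: they agree in their first 5 significant figures.
a = 0.20802966
b = 0.20802812

# Hypothetical uncertainty: assume each value is only good to its 8th decimal place.
eps = 1e-8

diff = a - b                              # about 1.54e-6, only 3 s.f. survive
diff_perturbed = (a + eps) - (b - eps)    # worst case: the errors push in opposite directions

rel_err_input = eps / a                                   # relative error of one input
rel_err_diff = abs(diff_perturbed - diff) / abs(diff)     # relative error of the difference

print(f"relative error of an input:       {rel_err_input:.2e}")
print(f"relative error of the difference: {rel_err_diff:.2e}")
print(f"amplification factor:             {rel_err_diff / rel_err_input:.0f}")
```

A relative error of a few parts in 10^8 on the inputs becomes an error of roughly one percent on the difference, which is why the sign of a near-zero result cannot be trusted.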
Setting the floor to zero (option 3) seems like the obvious choice to me now that numerical analysis explains the negative values: we can't keep tuning a threshold to fight floating-point arithmetic. We may still want some safety check in place to guard against a coding bug.
I agree that if we go that route, just truncating to 0 is the best option. I assume the threshold was there originally to allow some precision tolerance but detect if something weird is happening - and I think we understand what is happening now.
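A rough sketch of what "truncate to 0, but keep a safety check" could look like, written in plain Python rather than against the real `MMD.compute`. The function name `mmd_from_kernel_means` and the value of `NEGATIVE_TOLERANCE` are hypothetical; the issue does not settle on either. In JAX the clamp would presumably be `jnp.sqrt(jnp.maximum(squared_mmd, 0.0))`.

```python
import math
import warnings

# Hypothetical tolerance: anything more negative than this is treated as suspicious.
NEGATIVE_TOLERANCE = 1e-6


def mmd_from_kernel_means(kernel_xx_mean, kernel_yy_mean, kernel_xy_mean,
                          tolerance=NEGATIVE_TOLERANCE):
    """Return the MMD from the three kernel means, flooring the squared value at zero."""
    squared_mmd = kernel_xx_mean + kernel_yy_mean - 2.0 * kernel_xy_mean
    if squared_mmd < -tolerance:
        # Safety check: a large negative value is unlikely to be rounding noise.
        warnings.warn(
            f"Squared MMD is {squared_mmd}; this is more negative than rounding "
            "error should allow and may indicate a bug.",
            RuntimeWarning,
        )
    # Floor at zero so the square root never sees a negative argument.
    return math.sqrt(max(squared_mmd, 0.0))


# Slightly negative because of cancellation: silently floored to 0.0.
small_negative = mmd_from_kernel_means(0.5, 0.5, 0.50000001)

# Ordinary, well-separated case.
regular = mmd_from_kernel_means(1.0, 1.0, 0.5)
```

The warning keeps the diagnostic value of the original threshold without letting a `nan` propagate to the caller.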
What's the problem?
The MMD metric between a dataset and a coreset sometimes returns `nan` when the dataset is relatively large (order of 10^5 points).
This is likely a precision error: `kernel_xx_mean + kernel_yy_mean - 2 * kernel_xy_mean` in `MMD.compute` evaluates to a small negative number in these cases, but one whose magnitude is larger than the precision threshold that was put in place to catch this. The square root then turns it into `nan`.
The issue goes away when double precision is used: `from jax import config; config.update("jax_enable_x64", True)`.
The simplest solution is to increase the precision tolerance threshold to catch these cases, but I am not sure what would be an appropriate value.
We could just truncate the expression at 0, but that could mask a bug down the line. Also, this can happen even when the coreset is relatively small (<10%), and it would be strange to have an MMD of 0 in that case.
Another solution is to increase precision to float64, but I am not sure about the performance and other implications of this.
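On the question of what an appropriate threshold might be, one way to get a feel for the scale is to compare the machine epsilon of each precision against the size of the terms being combined. The sketch below is only a rough guide under stated assumptions: the helper `rounding_noise_scale` is hypothetical, the kernel mean values of 0.2 are illustrative, and the estimate ignores error accumulated while averaging the kernel matrices, which may well be larger.

```python
import sys

FLOAT32_EPS = 2.0 ** -23               # machine epsilon for IEEE 754 single precision
FLOAT64_EPS = sys.float_info.epsilon   # machine epsilon for double precision (2**-52)


def rounding_noise_scale(kernel_xx_mean, kernel_yy_mean, kernel_xy_mean, eps):
    """Rough scale of the rounding error in kxx + kyy - 2 * kxy (hypothetical helper)."""
    magnitude = abs(kernel_xx_mean) + abs(kernel_yy_mean) + 2.0 * abs(kernel_xy_mean)
    return eps * magnitude


# Illustrative kernel means of similar size, as happens when the MMD is near zero.
scale32 = rounding_noise_scale(0.2, 0.2, 0.2, FLOAT32_EPS)
scale64 = rounding_noise_scale(0.2, 0.2, 0.2, FLOAT64_EPS)

print(f"float32 noise scale: {scale32:.2e}")
print(f"float64 noise scale: {scale64:.2e}")
```

With these numbers the float32 scale is on the order of 1e-7, about nine orders of magnitude above the float64 one, which is consistent with the problem disappearing under `jax_enable_x64`.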
How can we reproduce the issue?
Python version
3.12
Package version
0.3.0
Operating system
Windows 10
Other packages
No response
Relevant log output
No response