float precision check in rational quadratic spline discriminant #71
Conversation
Hey Francesco,
Thanks for taking the time to submit this. Yeah, I see how this could be a problem when $b^2$ and $4ac$ are nearly equal.
Left some comments.
Artur
```diff
@@ -139,6 +139,10 @@ def rational_quadratic_spline(
     c = -input_delta * (inputs - input_cumheights)

     discriminant = b.pow(2) - 4 * a * c

+    float_precision_mask = (torch.abs(discriminant) / (b.pow(2) + 1e-8)) < 1e-6
```
Could you add a comment explaining where this particular formula comes from, and define constants/parameters for the magic numbers?
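For illustration, one way to satisfy that request might look as follows; the constant names here are hypothetical, not from the PR:

```python
# Hypothetical names for the magic numbers in the mask above.
B2_DENOMINATOR_SHIFT = 1e-8       # keeps the denominator nonzero when b == 0
RELATIVE_DISCRIMINANT_TOL = 1e-6  # relative size below which |discriminant| is treated as 0

float_precision_mask = (
    torch.abs(discriminant) / (b.pow(2) + B2_DENOMINATOR_SHIFT)
    < RELATIVE_DISCRIMINANT_TOL
)
```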
I have added a comment trying to explain the logic of the operation without being too verbose.
I am afraid the numbers are mainly motivated by heuristics to manage numerical stability in our case.
P.S. Re. the failing tests: unfortunately, tests in this package can be flaky due to numerical instability. Re-running usually helps.
Thanks for reviewing! I have incorporated the requested changes; let me know if I can do anything else before we can proceed with a merge!
From a quick look, I have picky comments suggesting that a deeper dive might be useful, but not answers: I'm not seeing the need to modify small discriminants. It seems safer to only modify negative ones? Perhaps (without having worked out the details) it would be better to first assert that none of them are "too" negative, and then zero out the negative ones.
Hello, and thank you for taking the time to review the pull request and for your insightful comments!
Would love to hear your thoughts on these points. Once again, thank you for your time and expertise. Best regards,
My advice would be to try to understand why you are hitting this issue when others haven't. Numerical problems often hint at trying to do something unreasonable or strange, and understanding that can encourage you to improve what you're doing. It's also possible you're in a regime that the code could generally serve better, for example by solving the quadratic differently, but we'd want to understand it. Regardless, I think we can address this issue without the magic constants: first assert that the discriminant is not "too" negative,
and then zero out any negative values as you suggest. I think we can then remove the later assert. This new assert allows the discriminant to be "slightly" negative, due to round-off errors of a few machine epsilons. I think this should solve your problem? But I haven't tried it, so please do. It's possible you'll need to replace the single machine epsilon with a small multiple of it. I think @arturbekasov was asking to name the constants rather than leave magic numbers inline.
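A sketch of what that suggestion might look like in the spline code; the use of `torch.finfo` and the factor of 3 are my assumptions, not text from the review:

```python
import torch

discriminant = b.pow(2) - 4 * a * c

# Allow the discriminant to be "slightly" negative, i.e. within a few machine
# epsilons of zero relative to the natural scale b^2.
eps = torch.finfo(discriminant.dtype).eps
assert (discriminant >= -3 * eps * b.pow(2)).all()

# Zero out any remaining tiny negative values before taking the square root.
discriminant = discriminant.clamp(min=0)
```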
Hello, and thanks again for the terrific package!
This pull request addresses an issue my group and I have been facing frequently when using `model.sample()`. When inverting a rational quadratic spline, one must calculate the value $\sqrt{b^{2} - 4ac}$, where $a$, $b$, and $c$ are defined as in Eq. 6 and onward in the original paper (Neural Spline Flows).
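For context, inverting the spline amounts to solving a quadratic $a\xi^2 + b\xi + c = 0$ for the bin-relative coordinate $\xi$; if I recall the paper correctly, the root is taken in the numerically stable form

$$\xi = \frac{2c}{-b - \sqrt{b^{2} - 4ac}},$$

which is where the square root of the discriminant enters.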
The rational quadratic splines code has a check to ensure that this discriminant satisfies $b^2 - 4ac \geq 0$:
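Based on the diff above and the AssertionError mentioned below, the check is presumably along these lines:

```python
# Discriminant of the quadratic (visible in the diff above), followed by the
# non-negativity check whose failure motivates this PR.
discriminant = b.pow(2) - 4 * a * c
assert (discriminant >= 0).all()
```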
However, if the two terms are equal up to float precision, $b^2 \approx 4ac$, their difference is not always 0 as expected: it is sometimes a tiny value of arbitrary sign (e.g. `-3.0518e-5`), which can trigger the AssertionError and cause a crash. To avoid this unwanted behaviour, we implemented a simple check on the relative magnitude of the discriminant, which seems to solve the issue effectively in our use case:
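The check is the one-line mask shown in the diff above; a minimal sketch of how it might be applied follows (the zeroing step is my assumption, since only the mask itself appears in the diff):

```python
import torch

# Relative-magnitude test from the diff: the 1e-8 in the denominator guards
# against division by zero when b == 0, and 1e-6 is the relative tolerance.
float_precision_mask = (torch.abs(discriminant) / (b.pow(2) + 1e-8)) < 1e-6

# Assumed application: snap near-zero discriminants to exactly 0 so that both
# the assert and the subsequent sqrt behave.
discriminant = torch.where(
    float_precision_mask, torch.zeros_like(discriminant), discriminant
)
```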
Please let me know if you need anything else from my side.
Best regards,
Francesco