float precision check in rational quadratic spline discriminant #71

Open
wants to merge 3 commits into base: master

Conversation

francesco-vaselli (Contributor)

Hello, and thanks again for the terrific package!

The present pull request addresses an issue my group and I have been facing frequently when using model.sample().
When inverting a rational quadratic spline one must compute $\sqrt{b^{2} - 4ac}$, where $a$, $b$, $c$ are defined as in Eq. 6 and onward of the original paper (Neural Spline Flows).

The rational quadratic spline code has a check to ensure that this discriminant satisfies $b^2 - 4ac \geq 0$:

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()
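
For context, the square root of this discriminant is then used to recover the spline root; a rough sketch, assuming the eq. 29 form from the paper rather than the exact code, is:

        root = (2 * c) / (-b - torch.sqrt(discriminant))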

However, if the two terms are equal up to float precision ($b^2 = 4ac$), their difference is not always exactly 0 as expected: it can instead come out as a tiny value of arbitrary sign (e.g. -3.0518e-5), which can trigger the AssertionError and crash the sampling.

To avoid this unwanted behaviour we implemented a simple check on the relative magnitude of the discriminant, which seems to solve the issue effectively in our use case:

        discriminant = b.pow(2) - 4 * a * c

        float_precision_mask = (torch.abs(discriminant)/(b.pow(2) + 1e-8)) < 1e-6
        discriminant[float_precision_mask] = 0
        
        assert (discriminant >= 0).all()
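
As a quick standalone illustration of what the mask does, with made-up numbers (not taken from an actual failure):

        import torch

        # Hypothetical values: the first discriminant "should" be 0 but comes out
        # slightly negative due to round-off; the second is genuinely positive.
        b = torch.tensor([10.0, 3.0])
        discriminant = torch.tensor([-3.0518e-5, 5.0])

        float_precision_mask = (torch.abs(discriminant) / (b.pow(2) + 1e-8)) < 1e-6
        discriminant[float_precision_mask] = 0

        print(discriminant)  # tensor([0., 5.]) -- only the tiny negative value is zeroed
        assert (discriminant >= 0).all()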

Please let me know if you need anything else from me,
Best regards,
Francesco

PLEASE NOTE: on my fork this commit causes two tests to fail with the following errors:
FAILED tests/transforms/autoregressive_test.py::MaskedPiecewiseQuadraticAutoregressiveTranformTest::test_forward_inverse_are_consistent - AssertionError: The tensors are different!
FAILED tests/transforms/splines/cubic_test.py::CubicSplineTest::test_forward_inverse_are_consistent - AssertionError: The tensors are different!
However, these errors seem to be related to two transforms left untouched by my pull request, so I am not sure whether they are actually connected to my modification. Do you have any idea why they may be failing?

@arturbekasov (Contributor) left a comment:

Hey Francesco,

Thanks for taking the time to submit this.

Yeah, I see how this could be a problem when $b^2 \approx 4ac$, although I have never run into this myself.

Left some comments.

Artur

@@ -139,6 +139,10 @@ def rational_quadratic_spline(
c = -input_delta * (inputs - input_cumheights)

discriminant = b.pow(2) - 4 * a * c

float_precision_mask = (torch.abs(discriminant)/(b.pow(2) + 1e-8)) < 1e-6
@arturbekasov (Contributor):

Could you add a comment explaining where this particular formula comes from, and define constants/parameters for the magic numbers?

@francesco-vaselli (Contributor, Author):

I have added a comment trying to explain the logic of the operation without being too verbose.
I am afraid the numbers are mainly motivated by heuristics to manage numerical stability in our case.

(A further review comment on nflows/transforms/splines/rational_quadratic.py is outdated and marked resolved.)
@arturbekasov (Contributor)

P.S. Re. failing tests: unfortunately, the tests in this package can be flaky due to numerical instability. Re-running usually helps.

@francesco-vaselli (Contributor, Author)

Thanks for reviewing! I have incorporated the requested changes; let me know if I can do anything else before we proceed with a merge!
Best,
Francesco

@imurray (Contributor) commented Oct 3, 2023

From a quick look, I have picky comments that suggest a deeper dive might be useful, but no answers:

The 1e-6 is a few times single-precision floating-point eps (float32 eps is about 1.19e-7) and makes sense given how the discriminant is then used. The 1e-8 seems like a hack: there doesn't seem to be an obvious absolute scale that determines what makes b "small" here?

If b is zero then this operation could introduce a new divide-by-zero that wasn't there before. If both b and the discriminant are zero, the correct root is usually 0, not NaN; the other solution to the quadratic would pick that up. I wonder if there is a reason that the paper assumed the solution form in the code (eq 29, for 4ac small) was the correct one, whereas apparently you're hitting cases where 4ac is as large as it can be?
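
For example (toy values; assuming the code solves the quadratic in the eq. 29 form, $2c / (-b - \sqrt{b^2 - 4ac})$):

        import torch

        a = torch.tensor([0.5]); b = torch.tensor([0.0]); c = torch.tensor([0.0])
        discriminant = b.pow(2) - 4 * a * c  # exactly 0 here

        print((2 * c) / (-b - torch.sqrt(discriminant)))  # tensor([nan]): 0/0
        print((-b + torch.sqrt(discriminant)) / (2 * a))  # tensor([0.]): the correct root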

I'm not seeing the need to modify small discriminants. It seems safer to only modify negative ones? Perhaps (without having worked out the details) it would be better to first assert that none of them are "too" negative, and then zero out the negative ones.

@francesco-vaselli (Contributor, Author)

Hello and thank you for taking the time to review the pull request and for your insightful comments!

  1. On Magic Constants:
    I understand the concern regarding the arbitrary choice of the constants (1e-6 and 1e-8). These were selected primarily based on empirical observations to ensure numerical stability in our specific use-cases. I'm open to exploring a more principled approach. Would you recommend any alternatives?

  2. Division by Zero:
    Your point about introducing a potential division by zero is valid. My intention was not to override the cases where both b and the discriminant are zero, as the correct root in such a scenario should indeed be zero, not NaN.

  3. Modifying Small Discriminants:
    The primary reason for modifying small discriminants was to mitigate the effects of floating-point arithmetic errors that can sometimes cause the discriminant to be a tiny negative number, even when b**2 and 4*a*c are theoretically identical. I opted to alter small discriminants to prevent these tiny errors from propagating further into calculations.

    However, I do agree that a more conservative approach of only modifying the negative ones could be more appropriate. Just to check my understanding: would something like this be preferable?

     discriminant = torch.where(discriminant < 0,
                                torch.zeros_like(discriminant), discriminant)

Would love to hear your thoughts on these points. Once again, thank you for your time and expertise.

Best regards,
Francesco

@imurray (Contributor) commented Oct 12, 2023

My advice would be to try to understand why you are hitting this issue when others haven't. Numerical problems often hint at trying to do something unreasonable or strange, which can encourage you to improve what you're doing. It's also possible you're in a regime that the code could generally serve better, for example by solving the quadratic differently, but we'd want to understand it.

Regardless, I think we can address this issue without the magic 1e-8 by asserting:

        assert (discriminant >= -1e-6 * b.pow(2)).all()

and then zeroing out any negative values as you suggest. I think we can then remove the later assert.

This new assert allows the discriminant to be "slightly" negative, due to round-off errors of a few eps times the final numbers involved. After that we can zero out those negative values as you suggest, so we don't crash. We'll still get a divide by zero if both b and discriminant are zero. But by leaving all positive discriminants alone, at least we haven't introduced any new crashes.
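
Concretely, something along these lines (untested sketch; the constant name is just a placeholder):

        REL_TOLERANCE = 1e-6  # a few times single-precision eps

        discriminant = b.pow(2) - 4 * a * c

        # Allow slightly negative values caused by round-off, relative to b^2.
        assert (discriminant >= -REL_TOLERANCE * b.pow(2)).all()

        # Zero out the slightly negative values before taking the square root;
        # positive discriminants are left untouched.
        discriminant = torch.clamp(discriminant, min=0)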

I think this should solve your problem? But I haven't tried it, so please do. It's possible you'll need to replace b.pow(2) with something like torch.maximum(b.pow(2), torch.abs(4*a*c)) if you're getting really tiny values where b.pow(2) could be zero, and the other part tiny but non-zero. Erm, but then you'd probably be seeing divide by zero problems too, which we'd need to address, so I doubt it.

I think @arturbekasov was asking to name the 1e-6. I'm not sure exactly what he wanted: eps_like_tol? Or whether some existing tolerance should be re-used.
