diff --git a/nflows/transforms/splines/rational_quadratic.py b/nflows/transforms/splines/rational_quadratic.py index bb6a8c4..5d752cf 100644 --- a/nflows/transforms/splines/rational_quadratic.py +++ b/nflows/transforms/splines/rational_quadratic.py @@ -139,6 +139,10 @@ def rational_quadratic_spline( c = -input_delta * (inputs - input_cumheights) discriminant = b.pow(2) - 4 * a * c + + float_precision_mask = (torch.abs(discriminant)/(b.pow(2) + 1e-8)) < 1e-6 + discriminant[float_precision_mask] = 0 + assert (discriminant >= 0).all() root = (2 * c) / (-b - torch.sqrt(discriminant))