diff --git a/nflows/transforms/splines/rational_quadratic.py b/nflows/transforms/splines/rational_quadratic.py index bb6a8c4..57cd8a5 100644 --- a/nflows/transforms/splines/rational_quadratic.py +++ b/nflows/transforms/splines/rational_quadratic.py @@ -139,6 +139,16 @@ def rational_quadratic_spline( c = -input_delta * (inputs - input_cumheights) discriminant = b.pow(2) - 4 * a * c + + # Correcting for floating-point errors in the discriminant calculation. + # The float_precision_mask identifies elements where the discriminant is essentially zero, + # compared to the magnitude of b.pow(2), + # but appears nonzero due to machine precision limitations. + # Threshold values (1e-8 and 1e-6) are heuristic-based to manage numerical stability. + float_precision_mask = (torch.abs(discriminant) / (b.pow(2) + 1e-8)) < 1e-6 + discriminant = torch.where(float_precision_mask, + torch.zeros_like(discriminant), discriminant) + assert (discriminant >= 0).all() root = (2 * c) / (-b - torch.sqrt(discriminant))