diff --git a/nflows/transforms/splines/rational_quadratic.py b/nflows/transforms/splines/rational_quadratic.py index 5d752cf..eccbb61 100644 --- a/nflows/transforms/splines/rational_quadratic.py +++ b/nflows/transforms/splines/rational_quadratic.py @@ -140,8 +140,13 @@ def rational_quadratic_spline( discriminant = b.pow(2) - 4 * a * c - float_precision_mask = (torch.abs(discriminant)/(b.pow(2) + 1e-8)) < 1e-6 - discriminant[float_precision_mask] = 0 + # Correcting for floating-point errors in the discriminant calculation. + # The float_precision_mask identifies elements where the discriminant is essentially zero, + # but appears nonzero due to machine precision limitations. + # Threshold values (1e-8 and 1e-6) are heuristic-based to manage numerical stability. + float_precision_mask = (torch.abs(discriminant) / (b.pow(2) + 1e-8)) < 1e-6 + discriminant = torch.where(float_precision_mask, + torch.zeros_like(discriminant), discriminant) assert (discriminant >= 0).all()