From 4934a85b71752640a70b00a8c41e54e53562de78 Mon Sep 17 00:00:00 2001 From: Francesco Vaselli Date: Sat, 30 Sep 2023 10:45:48 +0200 Subject: [PATCH] comments and torch.where --- nflows/transforms/splines/rational_quadratic.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/nflows/transforms/splines/rational_quadratic.py b/nflows/transforms/splines/rational_quadratic.py index 5d752cf..eccbb61 100644 --- a/nflows/transforms/splines/rational_quadratic.py +++ b/nflows/transforms/splines/rational_quadratic.py @@ -140,8 +140,13 @@ def rational_quadratic_spline( discriminant = b.pow(2) - 4 * a * c - float_precision_mask = (torch.abs(discriminant)/(b.pow(2) + 1e-8)) < 1e-6 - discriminant[float_precision_mask] = 0 + # Correcting for floating-point errors in the discriminant calculation. + # The float_precision_mask identifies elements where the discriminant is essentially zero, + # but appears nonzero due to machine precision limitations. + # Threshold values (1e-8 and 1e-6) are heuristic-based to manage numerical stability. + float_precision_mask = (torch.abs(discriminant) / (b.pow(2) + 1e-8)) < 1e-6 + discriminant = torch.where(float_precision_mask, + torch.zeros_like(discriminant), discriminant) assert (discriminant >= 0).all()