Skip to content

Commit

Permalink
comments and torch.where
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-vaselli committed Sep 30, 2023
1 parent d8b0858 commit 4934a85
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions nflows/transforms/splines/rational_quadratic.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,13 @@ def rational_quadratic_spline(

discriminant = b.pow(2) - 4 * a * c

float_precision_mask = (torch.abs(discriminant)/(b.pow(2) + 1e-8)) < 1e-6
discriminant[float_precision_mask] = 0
# Correcting for floating-point errors in the discriminant calculation.
# The float_precision_mask identifies elements where the discriminant is essentially zero,
# but appears nonzero due to machine precision limitations.
# Threshold values (1e-8 and 1e-6) are heuristic-based to manage numerical stability.
float_precision_mask = (torch.abs(discriminant) / (b.pow(2) + 1e-8)) < 1e-6
discriminant = torch.where(float_precision_mask,
torch.zeros_like(discriminant), discriminant)

assert (discriminant >= 0).all()

Expand Down

0 comments on commit 4934a85

Please sign in to comment.