Skip to content

Commit

Permalink
float precision check
Browse files Browse the repository at this point in the history
  • Loading branch information
francesco-vaselli committed Mar 15, 2023
1 parent 569c8ad commit d8b0858
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions nflows/transforms/splines/rational_quadratic.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ def rational_quadratic_spline(
c = -input_delta * (inputs - input_cumheights)

discriminant = b.pow(2) - 4 * a * c

float_precision_mask = (torch.abs(discriminant)/(b.pow(2) + 1e-8)) < 1e-6
discriminant[float_precision_mask] = 0

assert (discriminant >= 0).all()

root = (2 * c) / (-b - torch.sqrt(discriminant))
Expand Down

0 comments on commit d8b0858

Please sign in to comment.