Skip to content

Commit

Permalink
implement logabsdetjac for Inverse{<:TruncatedBijector} for better nu…
Browse files Browse the repository at this point in the history
…merical stability

Example of previous badness:  logabsdetjac(inverse(bijector(Uniform(-1,1))), 80) = -Inf (is now -79.30685281944005)
  • Loading branch information
acertain committed Aug 19, 2024
1 parent dc6b21f commit 213328d
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/bijectors/truncated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,25 @@ end

with_logabsdet_jacobian(b::TruncatedBijector, x) = transform(b, x), logabsdetjac(b, x)

function truncated_inv_logabsdetjac(y, a, b)
lowerbounded, upperbounded = isfinite(a), isfinite(b)
if lowerbounded && upperbounded
abs_y = abs(y)
return log(b - a) - abs_y + 2 * LogExpFunctions.log1pexp(-abs_y)
elseif lowerbounded || upperbounded
return convert(promote_type(typeof(y), typeof(a), typeof(b)), y)
else
return zero(y)
end
end

function logabsdetjac(ib::Inverse{<:TruncatedBijector}, y)
a, b = ib.orig.lb, ib.orig.ub
return truncated_inv_logabsdetjac.(y, a, b)
end

with_logabsdet_jacobian(ib::Inverse{<:TruncatedBijector}, y) = transform(ib, y), logabsdetjac(ib, y)

# It's only monotonically decreasing if it's only upper-bounded.
# In the multivariate case, we can only say something reasonable if entries are monotonic.
function is_monotonically_increasing(b::TruncatedBijector)
Expand Down

0 comments on commit 213328d

Please sign in to comment.