Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make corr_cholesky_factor deal with large inputs correctly. #118

Merged
merged 2 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TransformVariables"
uuid = "84d833dd-6860-57f9-a1a7-6da5db126cff"
authors = ["Tamas K. Papp <tkpapp@gmail.com>"]
version = "0.8.9"
version = "0.8.10"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
57 changes: 37 additions & 20 deletions src/special_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,36 +7,53 @@ export UnitVector, UnitSimplex, CorrCholeskyFactor, corr_cholesky_factor
"""
$(SIGNATURES)

`log(abs(…))` of the derivative of `tanh`, calculated accurately.
Return a `NamedTuple` of

- `log_l2_rem`, for `log(1 - tanh(x)^2)`,

- `logjac`, for `log(abs( ∂(log(abs(tanh(x))) / ∂x ))`

Caller ensures that `x ≥ 0`. `x == 0` is handled correctly, but results in infinities.
"""
function _tanh_logabsderiv(x)
function tanh_helpers(x)
d = 2*x
log(4) + d - 2 * log1pexp(d)
log_denom = log1pexp(d) # log(exp(2x) + 1)
logjac = log(4) + d - 2 * log_denom # log(ab
log_l2_rem = 2*(log(2) + x - log_denom) # log(2exp(x) / (exp(2x) + 1))
(; logjac, log_l2_rem)
end

"""
(y, r, ℓ) = $SIGNATURES
(y, log_r, ℓ) = $SIGNATURES

Given ``x ∈ ℝ`` and ``0 ≤ r ≤ 1``, return `(y, r′)` such that
Given ``x ∈ ℝ`` and ``0 ≤ r ≤ 1``, we define `(y, r′)` such that

1. ``y² + (r′)² = r²``,

2. ``y: |y| ≤ r`` is mapped with a bijection from `x`.
2. ``y: |y| ≤ r`` is mapped with a bijection from `x`, with the sign depending on `x`,

but use `log(r)` for actual calculations so that large `y`s still give nonsingular results.

`ℓ` is the log Jacobian (whether it is evaluated depends on `flag`).
"""
@inline function l2_remainder_transform(flag::LogJacFlag, x, r)
@inline function l2_remainder_transform(flag::LogJacFlag, x, log_r)
@unpack logjac, log_l2_rem = tanh_helpers(x)
# note that 1-tanh(x)^2 = sech(x)^2
(tanh(x) * √r, r*sech(x)^2,
flag isa NoLogJac ? flag : _tanh_logabsderiv(x) + 0.5*log(r))
(tanh(x) * exp(log_r / 2),
log_r + log_l2_rem,
flag isa NoLogJac ? flag : logjac + 0.5*log_r)
end

"""
(x, r′) = $SIGNATURES

Inverse of [`l2_remainder_transform`](@ref) in `x` and `y`.
"""
@inline l2_remainder_inverse(y, r) = atanh(y/√r), r-y^2
@inline function l2_remainder_inverse(y, log_r)
x = atanh(y / exp(log_r / 2))
log_r′ = logsubexp(log_r, 2 * log(abs(y)))
x, log_r′
end

####
#### UnitVector
Expand Down Expand Up @@ -65,16 +82,16 @@ end
function transform_with(flag::LogJacFlag, t::UnitVector, x::AbstractVector, index)
@unpack n = t
T = robust_eltype(x)
r = one(T)
log_r = zero(T)
y = Vector{T}(undef, n)
ℓ = logjac_zero(flag, T)
@inbounds for i in 1:(n - 1)
xi = x[index]
index += 1
y[i], r, ℓi = l2_remainder_transform(flag, xi, r)
y[i], log_r, ℓi = l2_remainder_transform(flag, xi, log_r)
ℓ += ℓi
end
y[end] = √r
y[end] = exp(log_r / 2)
y, ℓ, index
end

Expand All @@ -83,9 +100,9 @@ inverse_eltype(t::UnitVector, y::AbstractVector) = robust_eltype(y)
function inverse_at!(x::AbstractVector, index, t::UnitVector, y::AbstractVector)
@unpack n = t
@argcheck length(y) == n
r = one(eltype(y))
log_r = zero(eltype(y))
@inbounds for yi in axes(y, 1)[1:(end-1)]
x[index], r = l2_remainder_inverse(y[yi], r)
x[index], log_r = l2_remainder_inverse(y[yi], log_r)
index += 1
end
index
Expand Down Expand Up @@ -244,14 +261,14 @@ function calculate_corr_cholesky_factor!(U::AbstractMatrix{T}, flag::LogJacFlag,
n = size(U, 1)
ℓ = logjac_zero(flag, T)
@inbounds for col_index in 1:n
r = one(T)
log_r = zero(T)
for row_index in 1:(col_index-1)
xi = x[index]
U[row_index, col_index], r, ℓi = l2_remainder_transform(flag, xi, r)
U[row_index, col_index], log_r, ℓi = l2_remainder_transform(flag, xi, log_r)
ℓ += ℓi
index += 1
end
U[col_index, col_index] = √r
U[col_index, col_index] = exp(log_r / 2)
end
U, ℓ, index
end
Expand Down Expand Up @@ -285,9 +302,9 @@ function inverse_at!(x::AbstractVector, index,
n = result_size(t)
@argcheck size(U, 1) == n
@inbounds for col in 1:n
r = one(eltype(U))
log_r = zero(eltype(U))
for row in 1:(col-1)
x[index], r = l2_remainder_inverse(U[row, col], r)
x[index], log_r = l2_remainder_inverse(U[row, col], log_r)
index += 1
end
end
Expand Down
25 changes: 22 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ using LogDensityProblems: logdensity, logdensity_and_gradient
using LogDensityProblemsAD
using TransformVariables:
AbstractTransform, ScalarTransform, VectorTransform, ArrayTransformation,
unit_triangular_dimension, logistic, logistic_logjac, logit, inverse_and_logjac, NOLOGJAC, transform_with
unit_triangular_dimension, logistic, logistic_logjac, logit, inverse_and_logjac,
NOLOGJAC, transform_with
import ChangesOfVariables, InverseFunctions
using Enzyme: autodiff, ReverseWithPrimal, Active, Const

Expand Down Expand Up @@ -136,9 +137,18 @@ end
end
end

@testset "tanh helpers" begin
for _ in 1:10000
x = (rand() - 0.5) * 100
@unpack log_l2_rem, logjac = TransformVariables.tanh_helpers(x)
@test Float64(AD_logjac(tanh, BigFloat(x))) ≈ logjac atol = 1e-4
@test Float64(log(sech(BigFloat(x))^2)) ≈ log_l2_rem atol = 1e-4
end
end

@testset "to correlation cholesky factor" begin
@testset "dimension checks" begin
C = CorrCholeskyFactor(3)
C = corr_cholesky_factor(3)
wrong_x = zeros(dimension(C) + 1)

@test_throws ArgumentError transform(C, wrong_x)
Expand All @@ -147,7 +157,7 @@ end

@testset "consistency checks" begin
for K in 1:8
t = CorrCholeskyFactor(K)
t = corr_cholesky_factor(K)
@test dimension(t) == (K - 1)*K/2
CIENV && @info "testing correlation cholesky K = $(K)"
if K > 1
Expand Down Expand Up @@ -615,6 +625,15 @@ end
end
end

@testset "corr cholesky factor large inputs" begin
t = corr_cholesky_factor(7)
d = dimension(t)
for _ in 1:100
x = sign.(rand(d) .- 0.5) .* 100
@test isfinite(logdet(transform(t, x)) )
end
end

@testset "pretty printing" begin
t = as((a = asℝ₊,
b = as(Array, asℝ₋, 3, 3),
Expand Down
5 changes: 3 additions & 2 deletions test/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ $(SIGNATURES)

Log jacobian abs determinant via automatic differentiation. For testing.
"""
AD_logjac(f, x) = log(abs(ForwardDiff.derivative(f, x)))

AD_logjac(t::VectorTransform, x, vec_y) =
logabsdet(ForwardDiff.jacobian(x -> vec_y(transform(t, x)), x))[1]

AD_logjac(t::ScalarTransform, x) =
log(abs(ForwardDiff.derivative(x -> transform(t, x), x)))
AD_logjac(t::ScalarTransform, x) = AD_logjac(x -> transform(t, x), x)

"""
$(SIGNATURES)
Expand Down
Loading