From 706a08160921c97768257933a28fac8c93623232 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Fri, 15 Dec 2023 15:20:44 +0100 Subject: [PATCH 1/2] Make corr_cholesky_factor deal with large inputs correctly. --- Project.toml | 2 +- src/special_arrays.jl | 57 ++++++++++++++++++++++++++++--------------- test/runtests.jl | 25 ++++++++++++++++--- test/utilities.jl | 5 ++-- 4 files changed, 63 insertions(+), 26 deletions(-) diff --git a/Project.toml b/Project.toml index 05078da..6d73afe 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TransformVariables" uuid = "84d833dd-6860-57f9-a1a7-6da5db126cff" authors = ["Tamas K. Papp "] -version = "0.8.9" +version = "0.8.10" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/special_arrays.jl b/src/special_arrays.jl index dc19350..cd79874 100644 --- a/src/special_arrays.jl +++ b/src/special_arrays.jl @@ -7,28 +7,41 @@ export UnitVector, UnitSimplex, CorrCholeskyFactor, corr_cholesky_factor """ $(SIGNATURES) -`log(abs(…))` of the derivative of `tanh`, calculated accurately. +Return a `NamedTuple` of + +- `log_l2_rem`, for `log(1 - tanh(x)^2)`, + +- `logjac`, for `log(abs( ∂(log(abs(tanh(x))) / ∂x ))` + +Caller ensures that `x ≥ 0`. `x == 0` is handled correctly, but results in infinities. """ -function _tanh_logabsderiv(x) +function tanh_helpers(x) d = 2*x - log(4) + d - 2 * log1pexp(d) + log_denom = log1pexp(d) # log(exp(2x) + 1) + logjac = log(4) + d - 2 * log_denom # log(ab + log_l2_rem = 2*(log(2) + x - log_denom) # log(2exp(x) / (exp(2x) + 1)) + (; logjac, log_l2_rem) end """ - (y, r, ℓ) = $SIGNATURES + (y, log_r, ℓ) = $SIGNATURES -Given ``x ∈ ℝ`` and ``0 ≤ r ≤ 1``, return `(y, r′)` such that +Given ``x ∈ ℝ`` and ``0 ≤ r ≤ 1``, we define `(y, r′)` such that 1. ``y² + (r′)² = r²``, -2. ``y: |y| ≤ r`` is mapped with a bijection from `x`. +2. ``y: |y| ≤ r`` is mapped with a bijection from `x`, with the sign depending on `x`, + +but use `log(r)` for actual calculations so that large `y`s still give nonsingular results. `ℓ` is the log Jacobian (whether it is evaluated depends on `flag`). """ -@inline function l2_remainder_transform(flag::LogJacFlag, x, r) +@inline function l2_remainder_transform(flag::LogJacFlag, x, log_r) + (; logjac, log_l2_rem) = tanh_helpers(x) # note that 1-tanh(x)^2 = sech(x)^2 - (tanh(x) * √r, r*sech(x)^2, - flag isa NoLogJac ? flag : _tanh_logabsderiv(x) + 0.5*log(r)) + (tanh(x) * exp(log_r / 2), + log_r + log_l2_rem, + flag isa NoLogJac ? flag : logjac + 0.5*log_r) end """ @@ -36,7 +49,11 @@ end Inverse of [`l2_remainder_transform`](@ref) in `x` and `y`. """ -@inline l2_remainder_inverse(y, r) = atanh(y/√r), r-y^2 +@inline function l2_remainder_inverse(y, log_r) + x = atanh(y / exp(log_r / 2)) + log_r′ = logsubexp(log_r, 2 * log(abs(y))) + x, log_r′ +end #### #### UnitVector @@ -65,16 +82,16 @@ end function transform_with(flag::LogJacFlag, t::UnitVector, x::AbstractVector, index) @unpack n = t T = robust_eltype(x) - r = one(T) + log_r = zero(T) y = Vector{T}(undef, n) ℓ = logjac_zero(flag, T) @inbounds for i in 1:(n - 1) xi = x[index] index += 1 - y[i], r, ℓi = l2_remainder_transform(flag, xi, r) + y[i], log_r, ℓi = l2_remainder_transform(flag, xi, log_r) ℓ += ℓi end - y[end] = √r + y[end] = exp(log_r / 2) y, ℓ, index end @@ -83,9 +100,9 @@ inverse_eltype(t::UnitVector, y::AbstractVector) = robust_eltype(y) function inverse_at!(x::AbstractVector, index, t::UnitVector, y::AbstractVector) @unpack n = t @argcheck length(y) == n - r = one(eltype(y)) + log_r = zero(eltype(y)) @inbounds for yi in axes(y, 1)[1:(end-1)] - x[index], r = l2_remainder_inverse(y[yi], r) + x[index], log_r = l2_remainder_inverse(y[yi], log_r) index += 1 end index @@ -244,14 +261,14 @@ function calculate_corr_cholesky_factor!(U::AbstractMatrix{T}, flag::LogJacFlag, n = size(U, 1) ℓ = logjac_zero(flag, T) @inbounds for col_index in 1:n - r = one(T) + log_r = zero(T) for row_index in 1:(col_index-1) xi = x[index] - U[row_index, col_index], r, ℓi = l2_remainder_transform(flag, xi, r) + U[row_index, col_index], log_r, ℓi = l2_remainder_transform(flag, xi, log_r) ℓ += ℓi index += 1 end - U[col_index, col_index] = √r + U[col_index, col_index] = exp(log_r / 2) end U, ℓ, index end @@ -285,9 +302,9 @@ function inverse_at!(x::AbstractVector, index, n = result_size(t) @argcheck size(U, 1) == n @inbounds for col in 1:n - r = one(eltype(U)) + log_r = zero(eltype(U)) for row in 1:(col-1) - x[index], r = l2_remainder_inverse(U[row, col], r) + x[index], log_r = l2_remainder_inverse(U[row, col], log_r) index += 1 end end diff --git a/test/runtests.jl b/test/runtests.jl index 9c80888..56e8edb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,7 +6,8 @@ using LogDensityProblems: logdensity, logdensity_and_gradient using LogDensityProblemsAD using TransformVariables: AbstractTransform, ScalarTransform, VectorTransform, ArrayTransformation, - unit_triangular_dimension, logistic, logistic_logjac, logit, inverse_and_logjac, NOLOGJAC, transform_with + unit_triangular_dimension, logistic, logistic_logjac, logit, inverse_and_logjac, + NOLOGJAC, transform_with import ChangesOfVariables, InverseFunctions using Enzyme: autodiff, ReverseWithPrimal, Active, Const @@ -136,9 +137,18 @@ end end end +@testset "tanh helpers" begin + for _ in 1:10000 + x = (rand() - 0.5) * 100 + @unpack log_l2_rem, logjac = TransformVariables.tanh_helpers(x) + @test Float64(AD_logjac(tanh, BigFloat(x))) ≈ logjac atol = 1e-4 + @test Float64(log(sech(BigFloat(x))^2)) ≈ log_l2_rem atol = 1e-4 + end +end + @testset "to correlation cholesky factor" begin @testset "dimension checks" begin - C = CorrCholeskyFactor(3) + C = corr_cholesky_factor(3) wrong_x = zeros(dimension(C) + 1) @test_throws ArgumentError transform(C, wrong_x) @@ -147,7 +157,7 @@ end @testset "consistency checks" begin for K in 1:8 - t = CorrCholeskyFactor(K) + t = corr_cholesky_factor(K) @test dimension(t) == (K - 1)*K/2 CIENV && @info "testing correlation cholesky K = $(K)" if K > 1 @@ -615,6 +625,15 @@ end end end +@testset "corr cholesky factor large inputs" begin + t = corr_cholesky_factor(7) + d = dimension(t) + for _ in 1:100 + x = sign.(rand(d) .- 0.5) .* 100 + @test isfinite(logdet(transform(t, x)) ) + end +end + @testset "pretty printing" begin t = as((a = asℝ₊, b = as(Array, asℝ₋, 3, 3), diff --git a/test/utilities.jl b/test/utilities.jl index 83a7d99..18549b8 100644 --- a/test/utilities.jl +++ b/test/utilities.jl @@ -3,11 +3,12 @@ $(SIGNATURES) Log jacobian abs determinant via automatic differentiation. For testing. """ +AD_logjac(f, x) = log(abs(ForwardDiff.derivative(f, x))) + AD_logjac(t::VectorTransform, x, vec_y) = logabsdet(ForwardDiff.jacobian(x -> vec_y(transform(t, x)), x))[1] -AD_logjac(t::ScalarTransform, x) = - log(abs(ForwardDiff.derivative(x -> transform(t, x), x))) +AD_logjac(t::ScalarTransform, x) = AD_logjac(x -> transform(t, x), x) """ $(SIGNATURES) From 344b8b695c726c6a76f5997427dc232d7f057a4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Fri, 15 Dec 2023 15:29:21 +0100 Subject: [PATCH 2/2] use @unpack instead of (; ...) syntax --- src/special_arrays.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/special_arrays.jl b/src/special_arrays.jl index cd79874..108fe1d 100644 --- a/src/special_arrays.jl +++ b/src/special_arrays.jl @@ -37,7 +37,7 @@ but use `log(r)` for actual calculations so that large `y`s still give nonsingul `ℓ` is the log Jacobian (whether it is evaluated depends on `flag`). """ @inline function l2_remainder_transform(flag::LogJacFlag, x, log_r) - (; logjac, log_l2_rem) = tanh_helpers(x) + @unpack logjac, log_l2_rem = tanh_helpers(x) # note that 1-tanh(x)^2 = sech(x)^2 (tanh(x) * exp(log_r / 2), log_r + log_l2_rem,