From 86f0195966295d6c3b2edd3615bedcb9c7691d5a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 16 Oct 2021 01:45:09 +0200 Subject: [PATCH 1/2] Replace NonlinearSolve with Roots --- Project.toml | 4 ++-- src/Bijectors.jl | 4 ++-- src/bijectors/planar_layer.jl | 12 ++++-------- src/compat/reversediff.jl | 5 +++-- 4 files changed, 11 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index a3ff85ae..4fd4b2c0 100644 --- a/Project.toml +++ b/Project.toml @@ -12,10 +12,10 @@ IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MappedArrays = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" -NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -28,7 +28,7 @@ Functors = "0.1, 0.2" IrrationalConstants = "0.1" LogExpFunctions = "0.3.3" MappedArrays = "0.2.2, 0.3, 0.4" -NonlinearSolve = "0.3" Reexport = "0.2, 1" Requires = "0.5, 1" +Roots = "1.3.4" julia = "1.3" diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 5c02ec5f..f61fce87 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -35,11 +35,11 @@ using MappedArrays using Base.Iterators: drop using LinearAlgebra: AbstractTriangular +import ChainRulesCore import Functors import IrrationalConstants import LogExpFunctions -import NonlinearSolve -import ChainRulesCore +import Roots export TransformDistribution, PositiveDistribution, diff --git a/src/bijectors/planar_layer.jl b/src/bijectors/planar_layer.jl index 2ad0d488..2d7cace8 100644 --- a/src/bijectors/planar_layer.jl +++ b/src/bijectors/planar_layer.jl @@ -157,16 +157,12 @@ function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:Real} # Compute the initial bracket (see above). initial_bracket = (wt_y - abs(wt_u_hat), wt_y + abs(wt_u_hat)) - # Try to solve the root-finding problem, i.e., compute a final bracket - prob = NonlinearSolve.NonlinearProblem{false}(initial_bracket) do α, _ - α + wt_u_hat * tanh(α + b) - wt_y - end - sol = NonlinearSolve.solve(prob, NonlinearSolve.Falsi()) - if sol.retcode === NonlinearSolve.MAXITERS_EXCEED - @warn "Planar layer: root finding algorithm did not converge" sol + # Solve the root-finding problem + α0 = Roots.find_zero(initial_bracket) do α + return α + wt_u_hat * tanh(α + b) - wt_y end - return sol.left + return α0 end logabsdetjac(flow::PlanarLayer, x) = forward(flow, x).logabsdetjac diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 62c0691d..116d8531 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -9,7 +9,8 @@ using ..Bijectors: Log, SimplexBijector, maphcat, simplex_link_jacobian, ReverseDiffAD, Inverse import ..Bijectors: _eps, logabsdetjac, _logabsdetjac_scale, _simplex_bijector, _simplex_inv_bijector, replace_diag, jacobian, getpd, lower, - _inv_link_chol_lkj, _link_chol_lkj, _transform_ordered, _transform_inverse_ordered + _inv_link_chol_lkj, _link_chol_lkj, _transform_ordered, _transform_inverse_ordered, + find_alpha import ChainRulesCore @@ -187,7 +188,7 @@ function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:TrackedReal} return track(find_alpha, wt_y, wt_u_hat, b) end @grad function find_alpha(wt_y::TrackedReal, wt_u_hat::TrackedReal, b::TrackedReal) - α = find_alpha(data(wt_y), data(wt_u_hat), data(b)) + α = find_alpha(value(wt_y), value(wt_u_hat), value(b)) ∂wt_y = inv(1 + wt_u_hat * sech(α + b)^2) ∂wt_u_hat = - tanh(α + b) * ∂wt_y From 5cb4c1645ea9ca466659cbfcda7c3615f7a39b1a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 16 Oct 2021 01:45:31 +0200 Subject: [PATCH 2/2] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4fd4b2c0..9328a979 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.9.8" +version = "0.9.9" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"