From 31adbe8150e0722d524d272d053462d1bd96e04f Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 10 Dec 2021 22:29:20 +0100 Subject: [PATCH 01/29] Add ChangesOfVariables and InverseFunctions to deps --- Project.toml | 4 ++++ src/Bijectors.jl | 2 ++ test/Project.toml | 4 ++++ test/runtests.jl | 3 +++ 4 files changed, 13 insertions(+) diff --git a/Project.toml b/Project.toml index 2c0e7ab3..1f296958 100644 --- a/Project.toml +++ b/Project.toml @@ -5,9 +5,11 @@ version = "0.9.11" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" @@ -22,9 +24,11 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ArgCheck = "1, 2" ChainRulesCore = "0.10.11, 1" +ChangesOfVariables = "0.1" Compat = "3" Distributions = "0.23.3, 0.24, 0.25" Functors = "0.1, 0.2" +InverseFunctions = "0.1" IrrationalConstants = "0.1" LogExpFunctions = "0.3.3" MappedArrays = "0.2.2, 0.3, 0.4" diff --git a/src/Bijectors.jl b/src/Bijectors.jl index f61fce87..36779fd4 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -36,7 +36,9 @@ using Base.Iterators: drop using LinearAlgebra: AbstractTriangular import ChainRulesCore +import ChangesOfVariables import Functors +import InverseFunctions import IrrationalConstants import LogExpFunctions import Roots diff --git a/test/Project.toml b/test/Project.toml index dd834763..82d08ffc 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,10 +1,12 @@ [deps] ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -15,11 +17,13 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ChainRulesTestUtils = "0.7, 1" +ChangesOfVariables = "0.1" Combinatorics = "1.0.2" DistributionsAD = "0.6.3" FiniteDifferences = "0.11, 0.12" ForwardDiff = "0.10.12" Functors = "0.1, 0.2" +InverseFunctions = "0.1" LogExpFunctions = "0.3.1" ReverseDiff = "1.4.2" Tracker = "0.2.11" diff --git a/test/runtests.jl b/test/runtests.jl index 2794abf9..4f44833e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,16 +1,19 @@ using Bijectors using ChainRulesTestUtils +using ChangesOfVariables using Combinatorics using DistributionsAD using FiniteDifferences using ForwardDiff using Functors +using InverseFunctions using LogExpFunctions using ReverseDiff using Tracker using Zygote + using Random, LinearAlgebra, Test using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, From 000c83ff73f36f5cd91a6496010de9bc7971766d Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 10 Dec 2021 23:49:06 +0100 Subject: [PATCH 
02/29] Replace forward by with_logabsdet_jacobian --- README.md | 33 +++++------ src/Bijectors.jl | 5 +- src/bijectors/composed.jl | 20 +++---- src/bijectors/leaky_relu.jl | 14 ++--- src/bijectors/named_bijector.jl | 22 ++++---- src/bijectors/normalise.jl | 12 ++-- src/bijectors/planar_layer.jl | 4 +- src/bijectors/radial_layer.jl | 4 +- src/bijectors/rational_quadratic_spline.jl | 6 +- src/bijectors/stacked.jl | 6 +- src/interface.jl | 6 +- src/transformed_distribution.jl | 24 ++++---- test/bijectors/coupling.jl | 4 +- test/bijectors/leaky_relu.jl | 16 +++--- test/bijectors/utils.jl | 26 ++++----- test/interface.jl | 64 +++++++++++----------- test/norm_flows.jl | 8 +-- 17 files changed, 139 insertions(+), 135 deletions(-) diff --git a/README.md b/README.md index e255fe88..7fdc4424 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ The following table lists mathematical operations for a bijector and the corresp | `x ↦ b(x)` | `b(x)` | × | | `y ↦ b⁻¹(y)` | `inv(b)(y)` | × | | `x ↦ log|det J(b, x)|` | `logabsdetjac(b, x)` | AD | -| `x ↦ b(x), log|det J(b, x)|` | `forward(b, x)` | ✓ | +| `x ↦ b(x), log|det J(b, x)|` | `with_logabsdet_jacobian(b, x)` | ✓ | | `p ↦ q := b_* p` | `q = transformed(p, b)` | ✓ | | `y ∼ q` | `y = rand(q)` | ✓ | | `p ↦ b` such that `support(b_* p) = ℝᵈ` | `bijector(p)` | ✓ | @@ -221,18 +221,18 @@ true which is always the case for a differentiable bijection with differentiable inverse. Therefore if you want to compute `logabsdetjac(b⁻¹, y)` and we know that `logabsdetjac(b, b⁻¹(y))` is actually more efficient, we'll return `-logabsdetjac(b, b⁻¹(y))` instead. For some bijectors it might be easy to compute, say, the forward pass `b(x)`, but expensive to compute `b⁻¹(y)`. Because of this you might want to avoid doing anything "backwards", i.e. using `b⁻¹`. This is where `forward` comes to good use: ```julia -julia> forward(b, x) -(rv = -0.5369949942509267, logabsdetjac = 1.4575353795716655) +julia> with_logabsdet_jacobian(b, x) +(-0.5369949942509267, 1.4575353795716655) ``` Similarily ```julia julia> forward(inv(b), y) -(rv = 0.3688868996596376, logabsdetjac = -1.4575353795716655) +(0.3688868996596376, -1.4575353795716655) ``` -In fact, the purpose of `forward` is to just _do the right thing_, not necessarily "forward". In this function we'll have access to both the original value `x` and the transformed value `y`, so we can compute `logabsdetjac(b, x)` in either direction. Furthermore, in a lot of cases we can re-use a lot of the computation from `b(x)` in the computation of `logabsdetjac(b, x)`, or vice-versa. `forward(b, x)` will take advantage of such opportunities (if implemented). +In fact, the purpose of `with_logabsdet_jacobian` is to just _do the right thing_, not necessarily "forward". In this function we'll have access to both the original value `x` and the transformed value `y`, so we can compute `logabsdetjac(b, x)` in either direction. Furthermore, in a lot of cases we can re-use a lot of the computation from `b(x)` in the computation of `logabsdetjac(b, x)`, or vice-versa. `with_logabsdet_jacobian(b, x)` will take advantage of such opportunities (if implemented). #### Sampling from `TransformedDistribution` At this point we've only shown that we can replicate the existing functionality. 
But we said `TransformedDistribution isa Distribution`, so we also have `rand`: @@ -481,7 +481,7 @@ julia> Flux.params(flow) Params([[-1.05099; 0.502079] (tracked), [-0.216248; -0.706424] (tracked), [-4.33747] (tracked)]) ``` -Another useful function is the `forward(d::Distribution)` method. It is similar to `forward(b::Bijector)` in the sense that it does a forward pass of the entire process "sample then transform" and returns all the most useful quantities in process using the most efficent computation path. +Another useful function is the `forward(d::Distribution)` method. It is similar to `with_logabsdet_jacobian(b::Bijector, x)` in the sense that it does a forward pass of the entire process "sample then transform" and returns all the most useful quantities in process using the most efficent computation path. ```julia julia> x, y, logjac, logpdf_y = forward(flow) # sample + transform and returns all the useful quantities in one pass @@ -555,28 +555,29 @@ Tracked 2-element Array{Float64,1}: -1.546158373866469 -1.6098711387913573 -julia> forward(b, 0.6) # defaults to `(rv=b(x), logabsdetjac=logabsdetjac(b, x))` -(rv = 0.4054651081081642, logabsdetjac = 1.4271163556401458) +julia> with_logabsdet_jacobian(b, 0.6) # defaults to `(b(x), logabsdetjac(b, x))` +(0.4054651081081642, 1.4271163556401458) ``` -For further efficiency, one could manually implement `forward(b::Logit, x)`: +For further efficiency, one could manually implement `with_logabsdet_jacobian(b::Logit, x)`: ```julia julia> import Bijectors: forward, Logit +julia> import ChangesOfVariables: with_logabsdet_jacobian -julia> function forward(b::Logit{<:Real}, x) +julia> function with_logabsdet_jacobian(b::Logit{<:Real}, x) totally_worth_saving = @. (x - b.a) / (b.b - b.a) # spoiler: it's probably not y = logit.(totally_worth_saving) logjac = @. - log((b.b - x) * totally_worth_saving) - return (rv=y, logabsdetjac = logjac) + return (y, logjac) end forward (generic function with 16 methods) -julia> forward(b, 0.6) -(rv = 0.4054651081081642, logabsdetjac = 1.4271163556401458) +julia> with_logabsdet_jacobian(b, 0.6) +(0.4054651081081642, 1.4271163556401458) -julia> @which forward(b, 0.6) -forward(b::Logit{#s4} where #s4<:Real, x) in Main at REPL[43]:2 +julia> @which with_logabsdet_jacobian(b, 0.6) +with_logabsdet_jacobian(b::Logit{#s4} where #s4<:Real, x) in Main at REPL[43]:2 ``` As you can see it's a very contrived example, but you get the idea. @@ -715,7 +716,7 @@ The following methods are implemented by all subtypes of `Bijector`, this also i - `(b::Bijector)(x)`: implements the transform of the `Bijector` - `inv(b::Bijector)`: returns the inverse of `b`, i.e. `ib::Bijector` s.t. `(ib ∘ b)(x) ≈ x`. In most cases this is `Inverse{<:Bijector}`. - `logabsdetjac(b::Bijector, x)`: computes log(abs(det(jacobian(b, x)))). -- `forward(b::Bijector, x)`: returns named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))` in the most efficient manner. +- `with_logabsdet_jacobian(b::Bijector, x)`: returns named tuple `(b(x), logabsdetjac(b, x))` in the most efficient manner. - `∘`, `composel`, `composer`: convenient and type-safe constructors for `Composed`. `composel(bs...)` composes s.t. the resulting composition is evaluated left-to-right, while `composer(bs...)` is evaluated right-to-left. `∘` is right-to-left, as excepted from standard mathematical notation. - `jacobian(b::Bijector, x)` [OPTIONAL]: returns the Jacobian of the transformation. In some cases the analytical Jacobian has been implemented for efficiency. 
- `dimension(b::Bijector)`: returns the dimensionality of `b`. diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 36779fd4..2393283d 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -35,8 +35,9 @@ using MappedArrays using Base.Iterators: drop using LinearAlgebra: AbstractTriangular +import ChangesOfVariables: with_logabsdet_jacobian + import ChainRulesCore -import ChangesOfVariables import Functors import InverseFunctions import IrrationalConstants @@ -251,6 +252,8 @@ include("utils.jl") include("interface.jl") include("chainrules.jl") +Base.@deprecate forward(b::AbstractBijector, x) with_logabsdet_jacobian(b, x) + # Broadcasting here breaks Tracker for some reason maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...) maporbroadcast(f, x::AbstractArray...) = f.(x...) diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index f05b819c..6a02432e 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -179,8 +179,8 @@ function logabsdetjac(cb::Composed, x) y, logjac = forward(cb.ts[1], x) for i = 2:length(cb.ts) res = forward(cb.ts[i], y) - y = res.rv - logjac += res.logabsdetjac + y = res[1] + logjac += res[2] end return logjac @@ -195,8 +195,8 @@ end for i = 2:N - 1 temp = gensym(:res) push!(expr.args, :($temp = forward(cb.ts[$i], y))) - push!(expr.args, :(y = $temp.rv)) - push!(expr.args, :(logjac += $temp.logabsdetjac)) + push!(expr.args, :(y = $temp[1])) + push!(expr.args, :(logjac += $temp[2])) end # don't need to evaluate the last bijector, only it's `logabsdetjac` push!(expr.args, :(logjac += logabsdetjac(cb.ts[$N], y))) @@ -212,10 +212,10 @@ function forward(cb::Composed, x) for t in cb.ts[2:end] res = forward(t, rv) - rv = res.rv - logjac = res.logabsdetjac + logjac + rv = res[1] + logjac = res[2] + logjac end - return (rv=rv, logabsdetjac=logjac) + return (rv, logjac) end @@ -225,10 +225,10 @@ end for i = 2:length(T.parameters) temp = gensym(:temp) push!(expr.args, :($temp = forward(cb.ts[$i], y))) - push!(expr.args, :(y = $temp.rv)) - push!(expr.args, :(logjac += $temp.logabsdetjac)) + push!(expr.args, :(y = $temp[1])) + push!(expr.args, :(logjac += $temp[2])) end - push!(expr.args, :(return (rv = y, logabsdetjac = logjac))) + push!(expr.args, :(return (y, logjac))) return expr end diff --git a/src/bijectors/leaky_relu.jl b/src/bijectors/leaky_relu.jl index 65060c14..62413316 100644 --- a/src/bijectors/leaky_relu.jl +++ b/src/bijectors/leaky_relu.jl @@ -44,21 +44,21 @@ end logabsdetjac(b::LeakyReLU{<:Real, 0}, x::AbstractVector{<:Real}) = map(x -> logabsdetjac(b, x), x) -# We implement `forward` by hand since we can re-use the computation of +# We implement `with_logabsdet_jacobian` by hand since we can re-use the computation of # the Jacobian of the transformation. This will lead to faster sampling # when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. -function forward(b::LeakyReLU{<:Any, 0}, x::Real) +function with_logabsdet_jacobian(b::LeakyReLU{<:Any, 0}, x::Real) mask = x < zero(x) J = mask * b.α + !mask * one(x) - return (rv=J * x, logabsdetjac=log(abs(J))) + return (J * x, log(abs(J))) end # Batched version -function forward(b::LeakyReLU{<:Any, 0}, x::AbstractVector) +function with_logabsdet_jacobian(b::LeakyReLU{<:Any, 0}, x::AbstractVector) J = let T = eltype(x), z = zero(T), o = one(T) @. 
(x < z) * b.α + (x > z) * o end - return (rv=J .* x, logabsdetjac=log.(abs.(J))) + return (J .* x, log.(abs.(J))) end # (N=1) Multivariate case @@ -84,7 +84,7 @@ end # We implement `forward` by hand since we can re-use the computation of # the Jacobian of the transformation. This will lead to faster sampling # when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. -function forward(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) +function with_logabsdet_jacobian(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) # Is really diagonal of jacobian J = let T = eltype(x), z = zero(T), o = one(T) @. (x < z) * b.α + (x > z) * o @@ -97,5 +97,5 @@ function forward(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) end y = J .* x - return (rv=y, logabsdetjac=logjac) + return (y, logjac) end diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index 36f8691c..0543f91a 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -1,6 +1,6 @@ abstract type AbstractNamedBijector <: AbstractBijector end -forward(b::AbstractNamedBijector, x) = (rv = b(x), logabsdetjac = logabsdetjac(b, x)) +with_logabsdet_jacobian(b::AbstractNamedBijector, x) = (b(x), logabsdetjac(b, x)) ####################### ### `NamedBijector` ### @@ -125,8 +125,8 @@ function logabsdetjac(cb::NamedComposition, x) y, logjac = forward(cb.bs[1], x) for i = 2:length(cb.bs) res = forward(cb.bs[i], y) - y = res.rv - logjac += res.logabsdetjac + y = res[1] + logjac += res[2] end return logjac @@ -141,8 +141,8 @@ end for i = 2:N - 1 temp = gensym(:res) push!(expr.args, :($temp = forward(cb.bs[$i], y))) - push!(expr.args, :(y = $temp.rv)) - push!(expr.args, :(logjac += $temp.logabsdetjac)) + push!(expr.args, :(y = $temp[1])) + push!(expr.args, :(logjac += $temp[2])) end # don't need to evaluate the last bijector, only it's `logabsdetjac` push!(expr.args, :(logjac += logabsdetjac(cb.bs[$N], y))) @@ -158,10 +158,10 @@ function forward(cb::NamedComposition, x) for t in cb.bs[2:end] res = forward(t, rv) - rv = res.rv - logjac = res.logabsdetjac + logjac + rv = res[1] + logjac = res[2] + logjac end - return (rv=rv, logabsdetjac=logjac) + return (rv, logjac) end @@ -171,10 +171,10 @@ end for i = 2:length(T.parameters) temp = gensym(:temp) push!(expr.args, :($temp = forward(cb.bs[$i], y))) - push!(expr.args, :(y = $temp.rv)) - push!(expr.args, :(logjac += $temp.logabsdetjac)) + push!(expr.args, :(y = $temp[1])) + push!(expr.args, :(logjac += $temp[2])) end - push!(expr.args, :(return (rv = y, logabsdetjac = logjac))) + push!(expr.args, :(return (y, logjac))) return expr end diff --git a/src/bijectors/normalise.jl b/src/bijectors/normalise.jl index 81496468..defb447e 100644 --- a/src/bijectors/normalise.jl +++ b/src/bijectors/normalise.jl @@ -48,7 +48,7 @@ function Functors.functor(::Type{<:InvertibleBatchNorm}, x) return (b = x.b, logs = x.logs), reconstruct_invertiblebatchnorm end -function forward(bn::InvertibleBatchNorm, x) +function with_logabsdet_jacobian(bn::InvertibleBatchNorm, x) dims = ndims(x) size(x, dims - 1) == length(bn.b) || error("InvertibleBatchNorm expected $(length(bn.b)) channels, got $(size(x, dims - 1))") @@ -76,12 +76,12 @@ function forward(bn::InvertibleBatchNorm, x) logabsdetjac = ( fill(sum(logs - log.(v .+ bn.eps) / 2), size(x, dims)) ) - return (rv=rv, logabsdetjac=logabsdetjac) + return (rv, logabsdetjac) end -logabsdetjac(bn::InvertibleBatchNorm, x) = forward(bn, x).logabsdetjac +logabsdetjac(bn::InvertibleBatchNorm, x) = with_logabsdet_jacobian(bn, x)[2] 
-(bn::InvertibleBatchNorm)(x) = forward(bn, x).rv +(bn::InvertibleBatchNorm)(x) = with_logabsdet_jacobian(bn, x)[1] function forward(invbn::Inverse{<:InvertibleBatchNorm}, y) @assert !istraining() "`forward(::Inverse{InvertibleBatchNorm})` is only available in test mode." @@ -94,10 +94,10 @@ function forward(invbn::Inverse{<:InvertibleBatchNorm}, y) v = reshape(bn.v, as...) x = (y .- b) ./ s .* sqrt.(v .+ bn.eps) .+ m - return (rv=x, logabsdetjac=-logabsdetjac(bn, x)) + return (x, -logabsdetjac(bn, x)) end -(bn::Inverse{<:InvertibleBatchNorm})(y) = forward(bn, y).rv +(bn::Inverse{<:InvertibleBatchNorm})(y) = with_logabsdet_jacobian(bn, y)[1] function Base.show(io::IO, l::InvertibleBatchNorm) print(io, "InvertibleBatchNorm($(join(size(l.b), ", ")))") diff --git a/src/bijectors/planar_layer.jl b/src/bijectors/planar_layer.jl index 50070396..f2dbefff 100644 --- a/src/bijectors/planar_layer.jl +++ b/src/bijectors/planar_layer.jl @@ -101,7 +101,7 @@ function forward(flow::PlanarLayer, z::AbstractVecOrMat{<:Real}) b = first(flow.b) log_det_jacobian = log1p.(wT_û .* abs2.(sech.(_vec(wT_z) .+ b))) - return (rv = transformed, logabsdetjac = log_det_jacobian) + return (transformed, log_det_jacobian) end function (ib::Inverse{<:PlanarLayer})(y::AbstractVecOrMat{<:Real}) @@ -175,5 +175,5 @@ function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:Real} return α0 end -logabsdetjac(flow::PlanarLayer, x) = forward(flow, x).logabsdetjac +logabsdetjac(flow::PlanarLayer, x) = forward(flow, x)[2] isclosedform(b::Inverse{<:PlanarLayer}) = false diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index 7c79712c..11c3d799 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -63,7 +63,7 @@ function forward(flow::RadialLayer, z::AbstractVecOrMat) (d - 1) * log(1 + β_hat * h_) + log(1 + β_hat * h_ + β_hat * (- h_ ^ 2) * r) ) # from eq(14) - return (rv = transformed, logabsdetjac = log_det_jacobian) + return (transformed, log_det_jacobian) end function (ib::Inverse{<:RadialLayer})(y::AbstractVector{<:Real}) @@ -123,4 +123,4 @@ function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat) return r end -logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat) = forward(flow, x).logabsdetjac +logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat) = forward(flow, x)[2] diff --git a/src/bijectors/rational_quadratic_spline.jl b/src/bijectors/rational_quadratic_spline.jl index 8f081fcd..ef34c436 100644 --- a/src/bijectors/rational_quadratic_spline.jl +++ b/src/bijectors/rational_quadratic_spline.jl @@ -343,7 +343,7 @@ function rqs_forward( T = promote_type(eltype(widths), eltype(heights), eltype(derivatives), eltype(x)) if (x ≤ -widths[end]) || (x ≥ widths[end]) - return (rv = one(T) * x, logabsdetjac = zero(T) * x) + return (one(T) * x, zero(T) * x) end # Find which bin `x` is in @@ -376,9 +376,9 @@ function rqs_forward( numerator_y = Δy * (s * ξ^2 + d_k * ξ * (1 - ξ)) y = h_k + numerator_y / denominator - return (rv = y, logabsdetjac = logjac) + return (y, logjac) end -function forward(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real) +function with_logabsdet_jacobian(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real) return rqs_forward(b.widths, b.heights, b.derivatives, x) end diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 5c3fdb6b..144c7817 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -136,7 +136,7 @@ end # logjac = sum(_logjac) # (y_2, _logjac) = forward(b.bs[2], x[b.ranges[2]]) # logjac += 
sum(_logjac) -# return (rv = vcat(y_1, y_2), logabsdetjac = logjac) +# return (vcat(y_1, y_2), logjac) # end @generated function forward(b::Stacked{<:Tuple{Vararg{<:Any, N}}, <:Tuple{Vararg{<:Any, N}}}, x::AbstractVector) where {N} expr = Expr(:block) @@ -156,7 +156,7 @@ end push!(y_names, y_name) end - push!(expr.args, :(return (rv = vcat($(y_names...)), logabsdetjac = logjac))) + push!(expr.args, :(return (vcat($(y_names...)), logjac))) return expr end @@ -169,5 +169,5 @@ function forward(sb::Stacked, x::AbstractVector) logjac += sum(l) y end - return (rv = vcat(yinit, ys), logabsdetjac = logjac) + return (vcat(yinit, ys), logjac) end diff --git a/src/interface.jl b/src/interface.jl index c1ca115c..d2e8be7b 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -89,17 +89,17 @@ Default implementation for `Inverse{<:Bijector}` is implemented as logabsdetjac(ib::Inverse{<:Bijector}, y) = - logabsdetjac(ib.orig, ib(y)) """ - forward(b::Bijector, x) + with_logabsdet_jacobian(b::Bijector, x) Computes both `transform` and `logabsdetjac` in one forward pass, and -returns a named tuple `(rv=b(x), logabsdetjac=logabsdetjac(b, x))`. +returns a named tuple `(b(x), logabsdetjac(b, x))`. This defaults to the call above, but often one can re-use computation in the computation of the forward pass and the computation of the `logabsdetjac`. `forward` allows the user to take advantange of such efficiencies, if they exist. """ -forward(b::Bijector, x) = (rv=b(x), logabsdetjac=logabsdetjac(b, x)) +with_logabsdet_jacobian(b::Bijector, x) = (b(x), logabsdetjac(b, x)) """ logabsdetjacinv(b::Bijector, y) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index 1712ba2e..48afd2ff 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -86,14 +86,14 @@ Base.size(td::Transformed) = size(td.dist) function logpdf(td::UnivariateTransformed, y::Real) res = forward(inv(td.transform), y) - return logpdf(td.dist, res.rv) + res.logabsdetjac + return logpdf(td.dist, res[1]) + res[2] end # TODO: implement more efficiently for flows in the case of `Matrix` function logpdf(td::MvTransformed, y::AbstractMatrix{<:Real}) # batch-implementation for multivariate res = forward(inv(td.transform), y) - return logpdf(td.dist, res.rv) + res.logabsdetjac + return logpdf(td.dist, res[1]) + res[2] end function logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractMatrix{<:Real}) @@ -101,12 +101,12 @@ function logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractMatrix{<:Real}) ϵ = _eps(T) res = forward(inv(td.transform), y) - return logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) + res.logabsdetjac + return logpdf(td.dist, mappedarray(x->x+ϵ, res[1])) + res[2] end function _logpdf(td::MvTransformed, y::AbstractVector{<:Real}) res = forward(inv(td.transform), y) - return logpdf(td.dist, res.rv) + res.logabsdetjac + return logpdf(td.dist, res[1]) + res[2] end function _logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) @@ -114,12 +114,12 @@ function _logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) ϵ = _eps(T) res = forward(inv(td.transform), y) - return logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) + res.logabsdetjac + return logpdf(td.dist, mappedarray(x->x+ϵ, res[1])) + res[2] end # TODO: should eventually drop using `logpdf_with_trans` and replace with # res = forward(inv(td.transform), y) -# logpdf(td.dist, res.rv) .- res.logabsdetjac +# logpdf(td.dist, res[1]) .- res[2] function _logpdf(td::MatrixTransformed, y::AbstractMatrix{<:Real}) return 
logpdf_with_trans(td.dist, inv(td.transform)(y), true) end @@ -164,18 +164,18 @@ and returns a tuple `(logpdf, logabsdetjac)`. """ function logpdf_with_jac(td::UnivariateTransformed, y::Real) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) + res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res[1]) + res[2], res[2]) end # TODO: implement more efficiently for flows in the case of `Matrix` function logpdf_with_jac(td::MvTransformed, y::AbstractVector{<:Real}) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) + res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res[1]) + res[2], res[2]) end function logpdf_with_jac(td::MvTransformed, y::AbstractMatrix{<:Real}) res = forward(inv(td.transform), y) - return (logpdf(td.dist, res.rv) + res.logabsdetjac, res.logabsdetjac) + return (logpdf(td.dist, res[1]) + res[2], res[2]) end function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) @@ -183,14 +183,14 @@ function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Rea ϵ = _eps(T) res = forward(inv(td.transform), y) - lp = logpdf(td.dist, mappedarray(x->x+ϵ, res.rv)) + res.logabsdetjac - return (lp, res.logabsdetjac) + lp = logpdf(td.dist, mappedarray(x->x+ϵ, res[1])) + res[2] + return (lp, res[2]) end # TODO: should eventually drop using `logpdf_with_trans` function logpdf_with_jac(td::MatrixTransformed, y::AbstractMatrix{<:Real}) res = forward(inv(td.transform), y) - return (logpdf_with_trans(td.dist, res.rv, true), res.logabsdetjac) + return (logpdf_with_trans(td.dist, res[1], true), res[2]) end """ diff --git a/test/bijectors/coupling.jl b/test/bijectors/coupling.jl index fcf1c402..ead12bdb 100644 --- a/test/bijectors/coupling.jl +++ b/test/bijectors/coupling.jl @@ -45,8 +45,8 @@ using Bijectors: @test logabsdetjac(cl1, x) == logabsdetjac(b, x[1:1]) # forward - @test forward(cl1, x) == (rv = cl1(x), logabsdetjac = logabsdetjac(cl1, x)) - @test forward(icl1, cl1(x)) == (rv = x, logabsdetjac = - logabsdetjac(cl1, x)) + @test forward(cl1, x) == (cl1(x), logabsdetjac(cl1, x)) + @test forward(icl1, cl1(x)) == (x, - logabsdetjac(cl1, x)) end @testset "Classic" begin diff --git a/test/bijectors/leaky_relu.jl b/test/bijectors/leaky_relu.jl index 63ba8c18..5a98f5d9 100644 --- a/test/bijectors/leaky_relu.jl +++ b/test/bijectors/leaky_relu.jl @@ -41,12 +41,12 @@ true_logabsdetjac(b::Bijector{1}, xs::AbstractMatrix) = mapreduce(z -> true_loga # Forward f = forward(b, xs) - @test f.logabsdetjac ≈ logabsdetjac(b, xs) - @test f.rv ≈ b(xs) + @test f[2] ≈ logabsdetjac(b, xs) + @test f[1] ≈ b(xs) f = forward(b, Float32.(xs)) - @test f.logabsdetjac == logabsdetjac(b, Float32.(xs)) - @test f.rv ≈ b(Float32.(xs)) + @test f[2] == logabsdetjac(b, Float32.(xs)) + @test f[1] ≈ b(Float32.(xs)) end @testset "0-dim parameter, 1-dim input" begin @@ -67,12 +67,12 @@ end # Forward f = forward(b, xs) - @test f.logabsdetjac ≈ logabsdetjac(b, xs) - @test f.rv ≈ b(xs) + @test f[2] ≈ logabsdetjac(b, xs) + @test f[1] ≈ b(xs) f = forward(b, Float32.(xs)) - @test f.logabsdetjac == logabsdetjac(b, Float32.(xs)) - @test f.rv ≈ b(Float32.(xs)) + @test f[2] == logabsdetjac(b, Float32.(xs)) + @test f[1] ≈ b(Float32.(xs)) # Mixing of types # 1. Changes in input-type diff --git a/test/bijectors/utils.jl b/test/bijectors/utils.jl index a0fdb6f2..98919af9 100644 --- a/test/bijectors/utils.jl +++ b/test/bijectors/utils.jl @@ -17,21 +17,21 @@ function test_bijector_reals( ires = isequal ? 
@inferred(forward(inv(b), y_true)) : @inferred(forward(inv(b), y)) # Always want the following to hold - @test ires.rv ≈ x_true atol=tol - @test ires.logabsdetjac ≈ -logjac atol=tol + @test ires[1] ≈ x_true atol=tol + @test ires[2] ≈ -logjac atol=tol if isequal @test y ≈ y_true atol=tol # forward @test (@inferred ib(y_true)) ≈ x_true atol=tol # inverse @test logjac ≈ logjac_true # logjac forward - @test res.rv ≈ y_true atol=tol # forward using `forward` - @test res.logabsdetjac ≈ logjac_true atol=tol # logjac using `forward` + @test res[1] ≈ y_true atol=tol # forward using `forward` + @test res[2] ≈ logjac_true atol=tol # logjac using `forward` else @test y ≠ y_true # forward @test (@inferred ib(y)) ≈ x_true atol=tol # inverse @test logjac ≠ logjac_true # logjac forward - @test res.rv ≠ y_true # forward using `forward` - @test res.logabsdetjac ≠ logjac_true # logjac using `forward` + @test res[1] ≠ y_true # forward using `forward` + @test res[2] ≠ logjac_true # logjac using `forward` end end @@ -54,25 +54,25 @@ function test_bijector_arrays( # always want the following to hold @test ys isa typeof(ys_true) @test logjacs isa typeof(logjacs_true) - @test mean(abs, ires.rv - xs_true) ≤ tol - @test mean(abs, ires.logabsdetjac + logjacs) ≤ tol + @test mean(abs, ires[1] - xs_true) ≤ tol + @test mean(abs, ires[2] + logjacs) ≤ tol if isequal @test mean(abs, ys - ys_true) ≤ tol # forward @test mean(abs, (ib(ys_true)) - xs_true) ≤ tol # inverse @test mean(abs, logjacs - logjacs_true) ≤ tol # logjac forward - @test mean(abs, res.rv - ys_true) ≤ tol # forward using `forward` - @test mean(abs, res.logabsdetjac - logjacs_true) ≤ tol # logjac `forward` - @test mean(abs, ires.logabsdetjac + logjacs_true) ≤ tol # inverse logjac `forward` + @test mean(abs, res[1] - ys_true) ≤ tol # forward using `forward` + @test mean(abs, res[2] - logjacs_true) ≤ tol # logjac `forward` + @test mean(abs, ires[2] + logjacs_true) ≤ tol # inverse logjac `forward` else # Don't want the following to be equal to their "true" values @test mean(abs, ys - ys_true) > tol # forward @test mean(abs, logjacs - logjacs_true) > tol # logjac forward - @test mean(abs, res.rv - ys_true) > tol # forward using `forward` + @test mean(abs, res[1] - ys_true) > tol # forward using `forward` # Still want the following to be equal to the COMPUTED values @test mean(abs, ib(ys) - xs_true) ≤ tol # inverse - @test mean(abs, res.logabsdetjac - logjacs) ≤ tol # logjac forward using `forward` + @test mean(abs, res[2] - logjacs) ≤ tol # logjac forward using `forward` end end diff --git a/test/interface.jl b/test/interface.jl index cee597f6..27da4517 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -200,15 +200,15 @@ end @test size(x_) == size(x) @test size(xs_) == size(xs) - @test size(result.rv) == size(x) - @test size(results.rv) == size(xs) + @test size(result[1]) == size(x) + @test size(results[1]) == size(xs) - @test size(iresult.rv) == size(y) - @test size(iresults.rv) == size(ys) + @test size(iresult[1]) == size(y) + @test size(iresults[1]) == size(ys) # Values @test ys ≈ hcat([b(xs[:, i]) for i = 1:size(xs, 2)]...) 
- @test ys ≈ results.rv + @test ys ≈ results[1] if D == 0 # Sizes @@ -220,8 +220,8 @@ end @test @inferred(logabsdetjac(b, param(xs))) isa Union{Array, TrackedArray} @test @inferred(logabsdetjac(ib, param(ys))) isa Union{Array, TrackedArray} - @test size(results.logabsdetjac) == size(xs, ) - @test size(iresults.logabsdetjac) == size(ys, ) + @test size(results[2]) == size(xs, ) + @test size(iresults[2]) == size(ys, ) # Values b_logjac_ad = [(log ∘ abs)(ForwardDiff.derivative(b, xs[i])) for i = 1:length(xs)] @@ -234,8 +234,8 @@ end @test logabsdetjac.(b, param(xs)) == @inferred(logabsdetjac(b, param(xs))) @test logabsdetjac.(ib, param(ys)) == @inferred(logabsdetjac(ib, param(ys))) - @test results.logabsdetjac ≈ vec(logabsdetjac.(b, xs)) - @test iresults.logabsdetjac ≈ vec(logabsdetjac.(ib, ys)) + @test results[2] ≈ vec(logabsdetjac.(b, xs)) + @test iresults[2] ≈ vec(logabsdetjac.(ib, ys)) elseif D == 1 @test y == ys[:, 1] # Comparing sizes instead of lengths ensures we catch errors s.t. @@ -247,15 +247,15 @@ end @test @inferred(logabsdetjac(b, param(xs))) isa Union{Array, TrackedArray} @test @inferred(logabsdetjac(ib, param(ys))) isa Union{Array, TrackedArray} - @test size(results.logabsdetjac) == (size(xs, 2), ) - @test size(iresults.logabsdetjac) == (size(ys, 2), ) + @test size(results[2]) == (size(xs, 2), ) + @test size(iresults[2]) == (size(ys, 2), ) # Test all values @test @inferred(logabsdetjac(b, xs)) ≈ vec(mapslices(z -> logabsdetjac(b, z), xs; dims = 1)) @test @inferred(logabsdetjac(ib, ys)) ≈ vec(mapslices(z -> logabsdetjac(ib, z), ys; dims = 1)) - @test results.logabsdetjac ≈ vec(mapslices(z -> logabsdetjac(b, z), xs; dims = 1)) - @test iresults.logabsdetjac ≈ vec(mapslices(z -> logabsdetjac(ib, z), ys; dims = 1)) + @test results[2] ≈ vec(mapslices(z -> logabsdetjac(b, z), xs; dims = 1)) + @test iresults[2] ≈ vec(mapslices(z -> logabsdetjac(ib, z), ys; dims = 1)) # FIXME: `SimplexBijector` results in ∞ gradient if not in the domain if !contains(t -> t isa SimplexBijector, b) @@ -575,17 +575,17 @@ end res1 = forward(sb1, [x, x, y, y]) @test sb1(param([x, x, y, y])) isa TrackedArray - @test sb1([x, x, y, y]) ≈ res1.rv + @test sb1([x, x, y, y]) ≈ res1[1] @test logabsdetjac(sb1, [x, x, y, y]) ≈ 0 atol=1e-6 - @test res1.logabsdetjac ≈ 0 atol=1e-6 + @test res1[2] ≈ 0 atol=1e-6 sb2 = Stacked([b, b, inv(b), inv(b)]) # <= Array res2 = forward(sb2, [x, x, y, y]) @test sb2(param([x, x, y, y])) isa TrackedArray - @test sb2([x, x, y, y]) ≈ res2.rv + @test sb2([x, x, y, y]) ≈ res2[1] @test logabsdetjac(sb2, [x, x, y, y]) ≈ 0.0 atol=1e-12 - @test res2.logabsdetjac ≈ 0.0 atol=1e-12 + @test res2[2] ≈ 0.0 atol=1e-12 # `logabsdetjac` with AD b = MyADBijector(d) @@ -595,17 +595,17 @@ end res1 = forward(sb1, [x, x, y, y]) @test sb1(param([x, x, y, y])) isa TrackedArray - @test sb1([x, x, y, y]) == res1.rv + @test sb1([x, x, y, y]) == res1[1] @test logabsdetjac(sb1, [x, x, y, y]) ≈ 0 atol=1e-12 - @test res1.logabsdetjac ≈ 0.0 atol=1e-12 + @test res1[2] ≈ 0.0 atol=1e-12 sb2 = Stacked([b, b, inv(b), inv(b)]) # <= Array res2 = forward(sb2, [x, x, y, y]) @test sb2(param([x, x, y, y])) isa TrackedArray - @test sb2([x, x, y, y]) == res2.rv + @test sb2([x, x, y, y]) == res2[1] @test logabsdetjac(sb2, [x, x, y, y]) ≈ 0.0 atol=1e-12 - @test res2.logabsdetjac ≈ 0.0 atol=1e-12 + @test res2[2] ≈ 0.0 atol=1e-12 # value-test x = ones(3) @@ -613,9 +613,9 @@ end res = forward(sb, x) @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), log(x[2]), x[3] + 5.0] - @test res.rv == [exp(x[1]), log(x[2]), x[3] + 5.0] 
+ @test res[1] == [exp(x[1]), log(x[2]), x[3] + 5.0] @test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:3]) - @test res.logabsdetjac == logabsdetjac(sb, x) + @test res[2] == logabsdetjac(sb, x) # TODO: change when we have dimensionality in the type @@ -624,9 +624,9 @@ end res = @inferred forward(sb, x) @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] - @test res.rv == [exp(x[1]), sb.bs[2](x[2:3])...] + @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] @test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:2]) - @test res.logabsdetjac == logabsdetjac(sb, x) + @test res[2] == logabsdetjac(sb, x) x = ones(4) ./ 4.0 @test_throws AssertionError sb(x) @@ -637,9 +637,9 @@ end res = forward(sb, x) @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] - @test res.rv == [exp(x[1]), sb.bs[2](x[2:3])...] + @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] @test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:2]) - @test res.logabsdetjac == logabsdetjac(sb, x) + @test res[2] == logabsdetjac(sb, x) x = ones(4) ./ 4.0 @test_throws AssertionError sb(x) @@ -651,9 +651,9 @@ end res = forward(sb, x) @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] - @test res.rv == [exp(x[1]), sb.bs[2](x[2:3])...] + @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] @test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:2]) - @test res.logabsdetjac == logabsdetjac(sb, x) + @test res[2] == logabsdetjac(sb, x) x = ones(4) ./ 4.0 @test_throws AssertionError sb(x) @@ -664,9 +664,9 @@ end res = forward(sb, x) @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] - @test res.rv == [exp(x[1]), sb.bs[2](x[2:3])...] + @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] @test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:2]) - @test res.logabsdetjac == logabsdetjac(sb, x) + @test res[2] == logabsdetjac(sb, x) x = ones(4) ./ 4.0 @test_throws AssertionError sb(x) @@ -748,7 +748,7 @@ end x = [.5, 1.] 
@test sb(x) == x @test logabsdetjac(sb, x) == 0 - @test forward(sb, x) == (rv = x, logabsdetjac = zero(eltype(x))) + @test forward(sb, x) == (x, zero(eltype(x))) end end diff --git a/test/norm_flows.jl b/test/norm_flows.jl index dbbbc36f..9fadf573 100644 --- a/test/norm_flows.jl +++ b/test/norm_flows.jl @@ -26,7 +26,7 @@ end flow = PlanarLayer(2) z = randn(2, 20) forward_diff = log(abs(det(ForwardDiff.jacobian(t -> flow(t), z)))) - our_method = sum(forward(flow, z).logabsdetjac) + our_method = sum(forward(flow, z)[2]) @test our_method ≈ forward_diff @test inv(flow)(flow(z)) ≈ z @@ -74,7 +74,7 @@ end flow = RadialLayer(2) z = randn(2, 20) forward_diff = log(abs(det(ForwardDiff.jacobian(t -> flow(t), z)))) - our_method = sum(forward(flow, z).logabsdetjac) + our_method = sum(forward(flow, z)[2]) @test our_method ≈ forward_diff @test inv(flow)(flow(z)) ≈ z rtol=0.2 @@ -103,9 +103,9 @@ end x = rand(d) y = flow.transform(x) res = forward(flow.transform, x) - lp = logpdf_forward(flow, x, res.logabsdetjac) + lp = logpdf_forward(flow, x, res[2]) - @test res.rv ≈ y + @test res[1] ≈ y @test logpdf(flow, y) ≈ lp rtol=0.1 # flow with unconstrained-to-constrained From 0c1bf4871a6a8f528e2c0a2dff5d29b5bf809c14 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 10 Dec 2021 23:49:06 +0100 Subject: [PATCH 03/29] Replace Base.inv with InverseFunctions.inverse --- README.md | 38 ++++++------- src/Bijectors.jl | 11 ++-- src/bijectors/composed.jl | 6 +- src/bijectors/corr.jl | 2 +- src/bijectors/coupling.jl | 4 +- src/bijectors/exp_log.jl | 4 +- src/bijectors/leaky_relu.jl | 2 +- src/bijectors/named_bijector.jl | 14 ++--- src/bijectors/normalise.jl | 2 +- src/bijectors/permute.jl | 6 +- src/bijectors/shift.jl | 2 +- src/bijectors/simplex.jl | 4 +- src/bijectors/stacked.jl | 8 +-- src/compat/distributionsad.jl | 4 +- src/compat/reversediff.jl | 4 +- src/interface.jl | 14 ++--- src/transformed_distribution.jl | 30 +++++----- test/ad/flows.jl | 4 +- test/bijectors/coupling.jl | 4 +- test/bijectors/leaky_relu.jl | 8 +-- test/bijectors/named_bijector.jl | 10 ++-- test/bijectors/permute.jl | 16 +++--- test/bijectors/utils.jl | 12 ++-- test/interface.jl | 94 ++++++++++++++++---------------- test/norm_flows.jl | 30 +++++----- 25 files changed, 168 insertions(+), 165 deletions(-) diff --git a/README.md b/README.md index 7fdc4424..4cbbf576 100644 --- a/README.md +++ b/README.md @@ -18,11 +18,11 @@ The following table lists mathematical operations for a bijector and the corresp | Operation | Method | Automatic | |:------------------------------------:|:-----------------:|:-----------:| -| `b ↦ b⁻¹` | `inv(b)` | ✓ | +| `b ↦ b⁻¹` | `inverse(b)` | ✓ | | `(b₁, b₂) ↦ (b₁ ∘ b₂)` | `b₁ ∘ b₂` | ✓ | | `(b₁, b₂) ↦ [b₁, b₂]` | `stack(b₁, b₂)` | ✓ | | `x ↦ b(x)` | `b(x)` | × | -| `y ↦ b⁻¹(y)` | `inv(b)(y)` | × | +| `y ↦ b⁻¹(y)` | `inverse(b)(y)` | × | | `x ↦ log|det J(b, x)|` | `logabsdetjac(b, x)` | AD | | `x ↦ b(x), log|det J(b, x)|` | `with_logabsdet_jacobian(b, x)` | ✓ | | `p ↦ q := b_* p` | `q = transformed(p, b)` | ✓ | @@ -123,7 +123,7 @@ true What about `invlink`? ```julia -julia> b⁻¹ = inv(b) +julia> b⁻¹ = inverse(b) Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)) julia> b⁻¹(y) @@ -133,7 +133,7 @@ julia> b⁻¹(y) ≈ invlink(dist, y) true ``` -Pretty neat, huh? `Inverse{Logit}` is also a `Bijector` where we've defined `(ib::Inverse{<:Logit})(y)` as the inverse transformation of `(b::Logit)(x)`. Note that it's not always the case that `inv(b) isa Inverse`, e.g. 
the inverse of `Exp` is simply `Log` so `inv(Exp()) isa Log` is true. +Pretty neat, huh? `Inverse{Logit}` is also a `Bijector` where we've defined `(ib::Inverse{<:Logit})(y)` as the inverse transformation of `(b::Logit)(x)`. Note that it's not always the case that `inverse(b) isa Inverse`, e.g. the inverse of `Exp` is simply `Log` so `inverse(Exp()) isa Log` is true. #### Dimensionality One more thing. See the `0` in `Inverse{Logit{Float64}, 0}`? It represents the *dimensionality* of the bijector, in the same sense as for an `AbstractArray` with the exception of `0` which means it expects 0-dim input and output, i.e. `<:Real`. This can also be accessed through `dimension(b)`: @@ -162,7 +162,7 @@ true And since `Composed isa Bijector`: ```julia -julia> id_x = inv(id_y) +julia> id_x = inverse(id_y) Composed{Tuple{Inverse{Logit{Float64},0},Logit{Float64}},0}((Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)), Logit{Float64}(0.0, 1.0))) julia> id_x(x) ≈ x @@ -201,7 +201,7 @@ julia> logpdf_forward(td, x) #### `logabsdetjac` and `forward` -In the computation of both `logpdf` and `logpdf_forward` we need to compute `log(abs(det(jacobian(inv(b), y))))` and `log(abs(det(jacobian(b, x))))`, respectively. This computation is available using the `logabsdetjac` method +In the computation of both `logpdf` and `logpdf_forward` we need to compute `log(abs(det(jacobian(inverse(b), y))))` and `log(abs(det(jacobian(b, x))))`, respectively. This computation is available using the `logabsdetjac` method ```julia julia> logabsdetjac(b⁻¹, y) @@ -228,7 +228,7 @@ julia> with_logabsdet_jacobian(b, x) Similarily ```julia -julia> forward(inv(b), y) +julia> forward(inverse(b), y) (0.3688868996596376, -1.4575353795716655) ``` @@ -241,7 +241,7 @@ At this point we've only shown that we can replicate the existing functionality. julia> y = rand(td) # ∈ ℝ 0.999166054552483 -julia> x = inv(td.transform)(y) # transform back to interval [0, 1] +julia> x = inverse(td.transform)(y) # transform back to interval [0, 1] 0.7308945834125756 ``` @@ -261,7 +261,7 @@ Beta{Float64}(α=2.0, β=2.0) julia> b = bijector(dist) # (0, 1) → ℝ Logit{Float64}(0.0, 1.0) -julia> b⁻¹ = inv(b) # ℝ → (0, 1) +julia> b⁻¹ = inverse(b) # ℝ → (0, 1) Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)) julia> td = transformed(Normal(), b⁻¹) # x ∼ 𝓝(0, 1) then b(x) ∈ (0, 1) @@ -280,7 +280,7 @@ It's worth noting that `support(Beta)` is the _closed_ interval `[0, 1]`, while ```julia td = transformed(Beta()) -inv(td.transform)(rand(td)) +inverse(td.transform)(rand(td)) ``` will never result in `0` or `1` though any sample arbitrarily close to either `0` or `1` is possible. 
_Disclaimer: numerical accuracy is limited, so you might still see `0` and `1` if you're lucky._ @@ -335,7 +335,7 @@ julia> # Construct the transform bs = bijector.(dists) # constrained-to-unconstrained bijectors for dists (Logit{Float64}(0.0, 1.0), Log{0}(), SimplexBijector{true}()) -julia> ibs = inv.(bs) # invert, so we get unconstrained-to-constrained +julia> ibs = inverse.(bs) # invert, so we get unconstrained-to-constrained (Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0)), Exp{0}(), Inverse{SimplexBijector{true},1}(SimplexBijector{true}())) julia> sb = Stacked(ibs, ranges) # => Stacked <: Bijector @@ -411,7 +411,7 @@ Similarily to the multivariate ADVI example, we could use `Stacked` to get a _bo ```julia julia> d = MvNormal(zeros(2), ones(2)); -julia> ibs = inv.(bijector.((InverseGamma(2, 3), Beta()))); +julia> ibs = inverse.(bijector.((InverseGamma(2, 3), Beta()))); julia> sb = stack(ibs...) # == Stacked(ibs) == Stacked(ibs, [i:i for i = 1:length(ibs)] Stacked{Tuple{Exp{0},Inverse{Logit{Float64},0}},2}((Exp{0}(), Inverse{Logit{Float64},0}(Logit{Float64}(0.0, 1.0))), (1:1, 2:2)) @@ -542,7 +542,7 @@ Logit{Float64}(0.0, 1.0) julia> b(0.6) 0.4054651081081642 -julia> inv(b)(y) +julia> inverse(b)(y) Tracked 2-element Array{Float64,1}: 0.3078149833748082 0.72380041667891 @@ -550,7 +550,7 @@ Tracked 2-element Array{Float64,1}: julia> logabsdetjac(b, 0.6) 1.4271163556401458 -julia> logabsdetjac(inv(b), y) # defaults to `- logabsdetjac(b, inv(b)(x))` +julia> logabsdetjac(inverse(b), y) # defaults to `- logabsdetjac(b, inverse(b)(x))` Tracked 2-element Array{Float64,1}: -1.546158373866469 -1.6098711387913573 @@ -614,10 +614,10 @@ julia> logabsdetjac(b_ad, 0.6) julia> y = b_ad(0.6) 0.4054651081081642 -julia> inv(b_ad)(y) +julia> inverse(b_ad)(y) 0.6 -julia> logabsdetjac(inv(b_ad), y) +julia> logabsdetjac(inverse(b_ad), y) -1.4271163556401458 ``` @@ -666,7 +666,7 @@ help?> Bijectors.Composed A Bijector representing composition of bijectors. composel and composer results in a Composed for which application occurs from left-to-right and right-to-left, respectively. - Note that all the alternative ways of constructing a Composed returns a Tuple of bijectors. This ensures type-stability of implementations of all relating methdos, e.g. inv. + Note that all the alternative ways of constructing a Composed returns a Tuple of bijectors. This ensures type-stability of implementations of all relating methods, e.g. inverse. If you want to use an Array as the container instead you can do @@ -714,7 +714,7 @@ The distribution interface consists of: #### Methods The following methods are implemented by all subtypes of `Bijector`, this also includes bijectors such as `Composed`. - `(b::Bijector)(x)`: implements the transform of the `Bijector` -- `inv(b::Bijector)`: returns the inverse of `b`, i.e. `ib::Bijector` s.t. `(ib ∘ b)(x) ≈ x`. In most cases this is `Inverse{<:Bijector}`. +- `inverse(b::Bijector)`: returns the inverse of `b`, i.e. `ib::Bijector` s.t. `(ib ∘ b)(x) ≈ x`. In most cases this is `Inverse{<:Bijector}`. - `logabsdetjac(b::Bijector, x)`: computes log(abs(det(jacobian(b, x)))). - `with_logabsdet_jacobian(b::Bijector, x)`: returns named tuple `(b(x), logabsdetjac(b, x))` in the most efficient manner. - `∘`, `composel`, `composer`: convenient and type-safe constructors for `Composed`. `composel(bs...)` composes s.t. the resulting composition is evaluated left-to-right, while `composer(bs...)` is evaluated right-to-left. `∘` is right-to-left, as excepted from standard mathematical notation. 
@@ -726,7 +726,7 @@ For `TransformedDistribution`, together with default implementations for `Distri - `bijector(d::Distribution)`: returns the default constrained-to-unconstrained bijector for `d` - `transformed(d::Distribution)`, `transformed(d::Distribution, b::Bijector)`: constructs a `TransformedDistribution` from `d` and `b`. - `logpdf_forward(d::Distribution, x)`, `logpdf_forward(d::Distribution, x, logjac)`: computes the `logpdf(td, td.transform(x))` using the forward pass, which is potentially faster depending on the transform at hand. -- `forward(d::Distribution)`: returns `(x = rand(dist), y = b(x), logabsdetjac = logabsdetjac(b, x), logpdf = logpdf_forward(td, x))` where `b = td.transform`. This combines sampling from base distribution and transforming into one function. The intention is that this entire process should be performed in the most efficient manner, e.g. the `logabsdetjac(b, x)` call might instead be implemented as `- logabsdetjac(inv(b), b(x))` depending on which is most efficient. +- `forward(d::Distribution)`: returns `(x = rand(dist), y = b(x), logabsdetjac = logabsdetjac(b, x), logpdf = logpdf_forward(td, x))` where `b = td.transform`. This combines sampling from base distribution and transforming into one function. The intention is that this entire process should be performed in the most efficient manner, e.g. the `logabsdetjac(b, x)` call might instead be implemented as `- logabsdetjac(inverse(b), b(x))` depending on which is most efficient. # Bibliography 1. Rezende, D. J., & Mohamed, S. (2015). Variational Inference With Normalizing Flows. [arXiv:1505.05770](https://arxiv.org/abs/1505.05770v6). diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 2393283d..a540f435 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -36,10 +36,10 @@ using Base.Iterators: drop using LinearAlgebra: AbstractTriangular import ChangesOfVariables: with_logabsdet_jacobian +import InverseFunctions: inverse import ChainRulesCore import Functors -import InverseFunctions import IrrationalConstants import LogExpFunctions import Roots @@ -124,7 +124,7 @@ end # Distributions link(d::Distribution, x) = bijector(d)(x) -invlink(d::Distribution, y) = inv(bijector(d))(y) +invlink(d::Distribution, y) = inverse(bijector(d))(y) function logpdf_with_trans(d::Distribution, x, transform::Bool) if ispd(d) return pd_logpdf_with_trans(d, x, transform) @@ -191,14 +191,14 @@ function invlink( y::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true), ) where {proj} - return inv(SimplexBijector{proj}())(y) + return inverse(SimplexBijector{proj}())(y) end function invlink_jacobian( d::Dirichlet, y::AbstractVector{<:Real}, ::Val{proj}=Val(true), ) where {proj} - return jacobian(inv(SimplexBijector{proj}()), y) + return jacobian(inverse(SimplexBijector{proj}()), y) end ## Matrix @@ -254,6 +254,9 @@ include("chainrules.jl") Base.@deprecate forward(b::AbstractBijector, x) with_logabsdet_jacobian(b, x) +import Base.inv +Base.@deprecate inv(b::AbstractBijector) inverse(b) + # Broadcasting here breaks Tracker for some reason maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...) maporbroadcast(f, x::AbstractArray...) = f.(x...) diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index 6a02432e..482a2760 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -17,7 +17,7 @@ A `Bijector` representing composition of bijectors. `composel` and `composer` re `Composed` for which application occurs from left-to-right and right-to-left, respectively. 
Note that all the alternative ways of constructing a `Composed` returns a `Tuple` of bijectors. -This ensures type-stability of implementations of all relating methdos, e.g. `inv`. +This ensures type-stability of implementations of all relating methdos, e.g. `inverse`. If you want to use an `Array` as the container instead you can do @@ -41,7 +41,7 @@ Composed{Tuple{Exp{0},Exp{0}},0}((Exp{0}(), Exp{0}())) julia> (b ∘ b)(1.0) == exp(exp(1.0)) # evaluation true -julia> inv(b ∘ b)(exp(exp(1.0))) == 1.0 # inversion +julia> inverse(b ∘ b)(exp(exp(1.0))) == 1.0 # inversion true julia> logabsdetjac(b ∘ b, 1.0) # determinant of jacobian @@ -153,7 +153,7 @@ end ∘(::Identity{N}, b::Bijector{N}) where {N} = b ∘(b::Bijector{N}, ::Identity{N}) where {N} = b -inv(ct::Composed) = Composed(reverse(map(inv, ct.ts))) +inverse(ct::Composed) = Composed(reverse(map(inv, ct.ts))) # # TODO: should arrays also be using recursive implementation instead? function (cb::Composed{<:AbstractArray{<:Bijector}})(x) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 18363db1..5ec999db 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -100,7 +100,7 @@ function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real}) `logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})` if possible. =# - return -logabsdetjac(inv(b), (b(X))) + return -logabsdetjac(inverse(b), (b(X))) end function logabsdetjac(b::CorrBijector, X::AbstractArray{<:AbstractMatrix{<:Real}}) return mapvcat(X) do x diff --git a/src/bijectors/coupling.jl b/src/bijectors/coupling.jl index 088caf2a..03b00ba5 100644 --- a/src/bijectors/coupling.jl +++ b/src/bijectors/coupling.jl @@ -151,7 +151,7 @@ julia> cl(x) 2.0 3.0 -julia> inv(cl)(cl(x)) +julia> inverse(cl)(cl(x)) 3-element Array{Float64,1}: 1.0 2.0 @@ -214,7 +214,7 @@ function (icl::Inverse{<:Coupling})(y::AbstractVector) y_1, y_2, y_3 = partition(cl.mask, y) b = cl.θ(y_2) - ib = inv(b) + ib = inverse(b) return combine(cl.mask, ib(y_1), y_2, y_3) end diff --git a/src/bijectors/exp_log.jl b/src/bijectors/exp_log.jl index 0e2da142..0f5f4683 100644 --- a/src/bijectors/exp_log.jl +++ b/src/bijectors/exp_log.jl @@ -27,8 +27,8 @@ Log() = Log{0}() (b::Exp{2})(y::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, y) (b::Log{2})(x::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, x) -inv(b::Exp{N}) where {N} = Log{N}() -inv(b::Log{N}) where {N} = Exp{N}() +inverse(b::Exp{N}) where {N} = Log{N}() +inverse(b::Log{N}) where {N} = Exp{N}() logabsdetjac(b::Exp{0}, x::Real) = x logabsdetjac(b::Exp{0}, x::AbstractVector) = x diff --git a/src/bijectors/leaky_relu.jl b/src/bijectors/leaky_relu.jl index 62413316..c91e0faf 100644 --- a/src/bijectors/leaky_relu.jl +++ b/src/bijectors/leaky_relu.jl @@ -31,7 +31,7 @@ function (b::LeakyReLU{<:Any, 0})(x::Real) end (b::LeakyReLU{<:Any, 0})(x::AbstractVector{<:Real}) = map(b, x) -function Base.inv(b::LeakyReLU{<:Any,N}) where N +function inverse(b::LeakyReLU{<:Any,N}) where N invα = inv.(b.α) return LeakyReLU{typeof(invα),N}(invα) end diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index 0543f91a..8f90464e 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -55,8 +55,8 @@ names_to_bijectors(b::NamedBijector) = b.bs return :($(exprs...), ) end -@generated function Base.inv(b::NamedBijector{names}) where {names} - return :(NamedBijector(($([:($n = inv(b.bs.$n)) for n in names]...), ))) +@generated function inverse(b::NamedBijector{names}) where {names} + return :(NamedBijector(($([:($n 
= inverse(b.bs.$n)) for n in names]...), ))) end @generated function logabsdetjac(b::NamedBijector{names}, x::NamedTuple) where {names} @@ -78,10 +78,10 @@ See also: [`Inverse`](@ref) struct NamedInverse{B<:AbstractNamedBijector} <: AbstractNamedBijector orig::B end -Base.inv(nb::AbstractNamedBijector) = NamedInverse(nb) -Base.inv(ni::NamedInverse) = ni.orig +inverse(nb::AbstractNamedBijector) = NamedInverse(nb) +inverse(ni::NamedInverse) = ni.orig -logabsdetjac(ni::NamedInverse, y::NamedTuple) = -logabsdetjac(inv(ni), ni(y)) +logabsdetjac(ni::NamedInverse, y::NamedTuple) = -logabsdetjac(inverse(ni), ni(y)) ########################## ### `NamedComposition` ### @@ -107,7 +107,7 @@ composel(bs::AbstractNamedBijector...) = NamedComposition(bs) composer(bs::AbstractNamedBijector...) = NamedComposition(reverse(bs)) ∘(b1::AbstractNamedBijector, b2::AbstractNamedBijector) = composel(b2, b1) -inv(ct::NamedComposition) = NamedComposition(reverse(map(inv, ct.bs))) +inverse(ct::NamedComposition) = NamedComposition(reverse(map(inv, ct.bs))) function (cb::NamedComposition{<:AbstractArray{<:AbstractNamedBijector}})(x) @assert length(cb.bs) > 0 @@ -232,7 +232,7 @@ end ) where {target, deps, F} return quote b = ni.orig.f($([:(x.$d) for d in deps]...)) - return merge(x, ($target = inv(b)(x.$target), )) + return merge(x, ($target = inverse(b)(x.$target), )) end end diff --git a/src/bijectors/normalise.jl b/src/bijectors/normalise.jl index defb447e..43972eeb 100644 --- a/src/bijectors/normalise.jl +++ b/src/bijectors/normalise.jl @@ -87,7 +87,7 @@ function forward(invbn::Inverse{<:InvertibleBatchNorm}, y) @assert !istraining() "`forward(::Inverse{InvertibleBatchNorm})` is only available in test mode." dims = ndims(y) as = ntuple(i -> i == ndims(y) - 1 ? size(y, i) : 1, dims) - bn = inv(invbn) + bn = inverse(invbn) s = reshape(exp.(bn.logs), as...) b = reshape(bn.b, as...) m = reshape(bn.m, as...) 
diff --git a/src/bijectors/permute.jl b/src/bijectors/permute.jl index d4fdef7b..5ba9e5cf 100644 --- a/src/bijectors/permute.jl +++ b/src/bijectors/permute.jl @@ -71,10 +71,10 @@ julia> b4([1., 2., 3.]) 1.0 3.0 -julia> inv(b1) +julia> inverse(b1) Permute{LinearAlgebra.Transpose{Int64,Array{Int64,2}}}([0 1 0; 1 0 0; 0 0 1]) -julia> inv(b1)(b1([1., 2., 3.])) +julia> inverse(b1)(b1([1., 2., 3.])) 3-element Array{Float64,1}: 1.0 2.0 @@ -151,7 +151,7 @@ end @inline (b::Permute)(x::AbstractVecOrMat) = b.A * x -@inline inv(b::Permute) = Permute(transpose(b.A)) +@inline inverse(b::Permute) = Permute(transpose(b.A)) logabsdetjac(b::Permute, x::AbstractVector) = zero(eltype(x)) logabsdetjac(b::Permute, x::AbstractMatrix) = zero(eltype(x), size(x, 2)) diff --git a/src/bijectors/shift.jl b/src/bijectors/shift.jl index f6c39f82..e4e9960c 100644 --- a/src/bijectors/shift.jl +++ b/src/bijectors/shift.jl @@ -24,7 +24,7 @@ up1(b::Shift{T, N}) where {T, N} = Shift{T, N + 1}(b.a) (b::Shift)(x) = b.a .+ x (b::Shift{<:Any, 2})(x::AbstractArray{<:AbstractMatrix}) = map(b, x) -inv(b::Shift{T, N}) where {T, N} = Shift{T, N}(-b.a) +inverse(b::Shift{T, N}) where {T, N} = Shift{T, N}(-b.a) # FIXME: implement custom adjoint to ensure we don't get tracking logabsdetjac(b::Shift{T, N}, x) where {T, N} = _logabsdetjac_shift(b.a, x, Val(N)) diff --git a/src/bijectors/simplex.jl b/src/bijectors/simplex.jl index 48d19b73..10ba4db3 100644 --- a/src/bijectors/simplex.jl +++ b/src/bijectors/simplex.jl @@ -127,10 +127,10 @@ function (ib::Inverse{<:SimplexBijector{1}})( _simplex_inv_bijector!(X, Y, ib.orig) end function (ib::Inverse{<:SimplexBijector{2, proj}})(Y::AbstractMatrix) where {proj} - inv(SimplexBijector{1, proj}())(Y) + inverse(SimplexBijector{1, proj}())(Y) end function (ib::Inverse{<:SimplexBijector{2, proj}})(X::AbstractMatrix, Y::AbstractMatrix) where {proj} - inv(SimplexBijector{1, proj}())(X, Y) + inverse(SimplexBijector{1, proj}())(X, Y) end (ib::Inverse{<:SimplexBijector{2}})(Y::AbstractArray{<:AbstractMatrix}) = map(ib, Y) function _simplex_inv_bijector(Y::AbstractMatrix, b::SimplexBijector{1}) diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 144c7817..9a49c34d 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -50,14 +50,14 @@ isclosedform(b::Stacked) = all(isclosedform, b.bs) stack(bs::Bijector{0}...) = Stacked(bs) -# For some reason `inv.(sb.bs)` was unstable... This works though. -inv(sb::Stacked) = Stacked(map(inv, sb.bs), sb.ranges) +# For some reason `inverse.(sb.bs)` was unstable... This works though. 
+inverse(sb::Stacked) = Stacked(map(inverse, sb.bs), sb.ranges) # map is not type stable for many stacked bijectors as a large tuple # hence the generated function -@generated function inv(sb::Stacked{A}) where {A <: Tuple} +@generated function inverse(sb::Stacked{A}) where {A <: Tuple} exprs = [] for i = 1:length(A.parameters) - push!(exprs, :(inv(sb.bs[$i]))) + push!(exprs, :(inverse(sb.bs[$i]))) end :(Stacked(($(exprs...), ), sb.ranges)) end diff --git a/src/compat/distributionsad.jl b/src/compat/distributionsad.jl index a2fea928..85ef72ad 100644 --- a/src/compat/distributionsad.jl +++ b/src/compat/distributionsad.jl @@ -57,14 +57,14 @@ function invlink( y::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true), ) where {proj} - return inv(SimplexBijector{proj}())(y) + return inverse(SimplexBijector{proj}())(y) end function invlink_jacobian( d::TuringDirichlet, y::AbstractVector{<:Real}, ::Val{proj}=Val(true), ) where {proj} - return jacobian(inv(SimplexBijector{proj}()), y) + return jacobian(inverse(SimplexBijector{proj}()), y) end ispd(::TuringWishart) = true diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 116d8531..5c11a4db 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -60,7 +60,7 @@ function _logabsdetjac_scale(a::TrackedReal, x::Real, ::Val{0}) return track(_logabsdetjac_scale, a, value(x), Val(0)) end @grad function _logabsdetjac_scale(a::Real, x::Real, v::Val{0}) - return _logabsdetjac_scale(value(a), value(x), Val(0)), Δ -> (inv(value(a)) .* Δ, nothing, nothing) + return _logabsdetjac_scale(value(a), value(x), Val(0)), Δ -> (inverse(value(a)) .* Δ, nothing, nothing) end # Need to treat `AbstractVector` and `AbstractMatrix` separately due to ambiguity errors function _logabsdetjac_scale(a::TrackedReal, x::AbstractVector, ::Val{0}) @@ -68,7 +68,7 @@ function _logabsdetjac_scale(a::TrackedReal, x::AbstractVector, ::Val{0}) end @grad function _logabsdetjac_scale(a::Real, x::AbstractVector, v::Val{0}) da = value(a) - J = fill(inv.(da), length(x)) + J = fill(inverse.(da), length(x)) return _logabsdetjac_scale(da, value(x), Val(0)), Δ -> (transpose(J) * Δ, nothing, nothing) end function _logabsdetjac_scale(a::TrackedReal, x::AbstractMatrix, ::Val{0}) diff --git a/src/interface.jl b/src/interface.jl index d2e8be7b..3782dee6 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -1,4 +1,4 @@ -import Base: inv, ∘ +import Base: ∘ import Random: AbstractRNG import Distributions: logpdf, rand, rand!, _rand!, _logpdf @@ -56,7 +56,7 @@ requires an iterative procedure to evaluate. isclosedform(b::Bijector) = true """ - inv(b::Bijector) +inverse(b::Bijector) Inverse(b::Bijector) A `Bijector` representing the inverse transform of `b`. @@ -72,8 +72,8 @@ Functors.@functor Inverse up1(b::Inverse) = Inverse(up1(b.orig)) -inv(b::Bijector) = Inverse(b) -inv(ib::Inverse{<:Bijector}) = ib.orig +inverse(b::Bijector) = Inverse(b) +inverse(ib::Inverse{<:Bijector}) = ib.orig Base.:(==)(b1::Inverse{<:Bijector}, b2::Inverse{<:Bijector}) = b1.orig == b2.orig """ @@ -104,9 +104,9 @@ with_logabsdet_jacobian(b::Bijector, x) = (b(x), logabsdetjac(b, x)) """ logabsdetjacinv(b::Bijector, y) -Just an alias for `logabsdetjac(inv(b), y)`. +Just an alias for `logabsdetjac(inverse(b), y)`. 
""" -logabsdetjacinv(b::Bijector, y) = logabsdetjac(inv(b), y) +logabsdetjacinv(b::Bijector, y) = logabsdetjac(inverse(b), y) ############################## # Example bijector: Identity # @@ -114,7 +114,7 @@ logabsdetjacinv(b::Bijector, y) = logabsdetjac(inv(b), y) struct Identity{N} <: Bijector{N} end (::Identity)(x) = copy(x) -inv(b::Identity) = b +inverse(b::Identity) = b up1(::Identity{N}) where {N} = Identity{N + 1}() logabsdetjac(::Identity{0}, x::Real) = zero(eltype(x)) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index 48afd2ff..d7d40114 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -35,7 +35,7 @@ transformed(d) = transformed(d, bijector(d)) Returns the constrained-to-unconstrained bijector for distribution `d`. """ -bijector(td::TransformedDistribution) = bijector(td.dist) ∘ inv(td.transform) +bijector(td::TransformedDistribution) = bijector(td.dist) ∘ inverse(td.transform) bijector(d::DiscreteUnivariateDistribution) = Identity{0}() bijector(d::DiscreteMultivariateDistribution) = Identity{1}() bijector(d::ContinuousUnivariateDistribution) = TruncatedBijector(minimum(d), maximum(d)) @@ -85,14 +85,14 @@ Base.length(td::Transformed) = length(td.dist) Base.size(td::Transformed) = size(td.dist) function logpdf(td::UnivariateTransformed, y::Real) - res = forward(inv(td.transform), y) + res = forward(inverse(td.transform), y) return logpdf(td.dist, res[1]) + res[2] end # TODO: implement more efficiently for flows in the case of `Matrix` function logpdf(td::MvTransformed, y::AbstractMatrix{<:Real}) # batch-implementation for multivariate - res = forward(inv(td.transform), y) + res = forward(inverse(td.transform), y) return logpdf(td.dist, res[1]) + res[2] end @@ -100,12 +100,12 @@ function logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractMatrix{<:Real}) T = eltype(y) ϵ = _eps(T) - res = forward(inv(td.transform), y) + res = forward(inverse(td.transform), y) return logpdf(td.dist, mappedarray(x->x+ϵ, res[1])) + res[2] end function _logpdf(td::MvTransformed, y::AbstractVector{<:Real}) - res = forward(inv(td.transform), y) + res = forward(inverse(td.transform), y) return logpdf(td.dist, res[1]) + res[2] end @@ -113,15 +113,15 @@ function _logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) T = eltype(y) ϵ = _eps(T) - res = forward(inv(td.transform), y) + res = forward(inverse(td.transform), y) return logpdf(td.dist, mappedarray(x->x+ϵ, res[1])) + res[2] end # TODO: should eventually drop using `logpdf_with_trans` and replace with -# res = forward(inv(td.transform), y) +# res = forward(inverse(td.transform), y) # logpdf(td.dist, res[1]) .- res[2] function _logpdf(td::MatrixTransformed, y::AbstractMatrix{<:Real}) - return logpdf_with_trans(td.dist, inv(td.transform)(y), true) + return logpdf_with_trans(td.dist, inverse(td.transform)(y), true) end # rand @@ -163,18 +163,18 @@ Makes use of the `forward` method to potentially re-use computation and returns a tuple `(logpdf, logabsdetjac)`. 
""" function logpdf_with_jac(td::UnivariateTransformed, y::Real) - res = forward(inv(td.transform), y) + res = forward(inverse(td.transform), y) return (logpdf(td.dist, res[1]) + res[2], res[2]) end # TODO: implement more efficiently for flows in the case of `Matrix` function logpdf_with_jac(td::MvTransformed, y::AbstractVector{<:Real}) - res = forward(inv(td.transform), y) + res = forward(inverse(td.transform), y) return (logpdf(td.dist, res[1]) + res[2], res[2]) end function logpdf_with_jac(td::MvTransformed, y::AbstractMatrix{<:Real}) - res = forward(inv(td.transform), y) + res = forward(inverse(td.transform), y) return (logpdf(td.dist, res[1]) + res[2], res[2]) end @@ -182,14 +182,14 @@ function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Rea T = eltype(y) ϵ = _eps(T) - res = forward(inv(td.transform), y) + res = forward(inverse(td.transform), y) lp = logpdf(td.dist, mappedarray(x->x+ϵ, res[1])) + res[2] return (lp, res[2]) end # TODO: should eventually drop using `logpdf_with_trans` function logpdf_with_jac(td::MatrixTransformed, y::AbstractMatrix{<:Real}) - res = forward(inv(td.transform), y) + res = forward(inverse(td.transform), y) return (logpdf_with_trans(td.dist, res[1], true), res[2]) end @@ -293,7 +293,7 @@ logabsdetjacinv(d::MultivariateDistribution, x::AbstractVector{T}) where {T<:Rea Computes the `logabsdetjac` of the _inverse_ transformation, since `rand(td)` returns the _transformed_ random variable. """ -logabsdetjacinv(td::UnivariateTransformed, y::Real) = logabsdetjac(inv(td.transform), y) +logabsdetjacinv(td::UnivariateTransformed, y::Real) = logabsdetjac(inverse(td.transform), y) function logabsdetjacinv(td::MvTransformed, y::AbstractVector{<:Real}) - return logabsdetjac(inv(td.transform), y) + return logabsdetjac(inverse(td.transform), y) end diff --git a/test/ad/flows.jl b/test/ad/flows.jl index 351fa987..335f6333 100644 --- a/test/ad/flows.jl +++ b/test/ad/flows.jl @@ -14,12 +14,12 @@ # logpdf of a flow with the inverse of a planar layer and two-dimensional inputs test_ad(randn(7)) do θ layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5]) - flow = transformed(MvNormal(zeros(2), I), inv(layer)) + flow = transformed(MvNormal(zeros(2), I), inverse(layer)) return logpdf_forward(flow, θ[6:7]) end test_ad(randn(11)) do θ layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5]) - flow = transformed(MvNormal(zeros(2), I), inv(layer)) + flow = transformed(MvNormal(zeros(2), I), inverse(layer)) return sum(logpdf_forward(flow, reshape(θ[6:end], 2, :))) end end diff --git a/test/bijectors/coupling.jl b/test/bijectors/coupling.jl index ead12bdb..298eab47 100644 --- a/test/bijectors/coupling.jl +++ b/test/bijectors/coupling.jl @@ -34,9 +34,9 @@ using Bijectors: @test cl2(x) == cl1(x) # inversion - icl1 = inv(cl1) + icl1 = inverse(cl1) @test icl1(cl1(x)) == x - @test inv(cl2)(cl2(x)) == x + @test inverse(cl2)(cl2(x)) == x # This `cl2` should result in b = Shift(x[2:2]) diff --git a/test/bijectors/leaky_relu.jl b/test/bijectors/leaky_relu.jl index 5a98f5d9..d06a112a 100644 --- a/test/bijectors/leaky_relu.jl +++ b/test/bijectors/leaky_relu.jl @@ -14,8 +14,8 @@ true_logabsdetjac(b::Bijector{1}, xs::AbstractMatrix) = mapreduce(z -> true_loga @testset "0-dim parameter, 0-dim input" begin b = LeakyReLU(0.1; dim=Val(0)) x = 1. - @test inv(b)(b(x)) == x - @test inv(b)(b(-x)) == -x + @test inverse(b)(b(x)) == x + @test inverse(b)(b(-x)) == -x # Mixing of types # 1. 
Changes in input-type @@ -54,8 +54,8 @@ end b = LeakyReLU(0.1; dim=Val(1)) x = ones(d) - @test inv(b)(b(x)) == x - @test inv(b)(b(-x)) == -x + @test inverse(b)(b(x)) == x + @test inverse(b)(b(-x)) == -x # Batch xs = randn(d, 10) diff --git a/test/bijectors/named_bijector.jl b/test/bijectors/named_bijector.jl index 015e0a7f..a7248fae 100644 --- a/test/bijectors/named_bijector.jl +++ b/test/bijectors/named_bijector.jl @@ -18,7 +18,7 @@ end nc2 = b ∘ b @test nc1 == nc2 - inc2 = inv(nc2) + inc2 = inverse(nc2) @test (inc2 ∘ nc2)(x) == x @test logabsdetjac((inc2 ∘ nc2), x) ≈ 0.0 end @@ -37,18 +37,18 @@ end x = (a = 1.0, b = 0.5, c = 99999.) @test Bijectors.coupling(nc)(x.a) isa Logit - @test inv(nc)(nc(x)) == x + @test inverse(nc)(nc(x)) == x @test logabsdetjac(nc, x) == logabsdetjac(Logit(0., 1.), x.b) - @test logabsdetjac(inv(nc), nc(x)) == -logabsdetjac(nc, x) + @test logabsdetjac(inverse(nc), nc(x)) == -logabsdetjac(nc, x) x = (a = 0.0, b = 2.0, c = 1.0) nc = NamedCoupling(:c, (:a, :b), (a, b) -> Logit(a, b)) @test nc(x).c == 0.0 - @test inv(nc)(nc(x)) == x + @test inverse(nc)(nc(x)) == x x = (a = 0.0, b = 2.0, c = 1.0) nc = NamedCoupling(:c, (:b, ), b -> Shift(b)) @test nc(x).c == 3.0 - @test inv(nc)(nc(x)) == x + @test inverse(nc)(nc(x)) == x end diff --git a/test/bijectors/permute.jl b/test/bijectors/permute.jl index db9f20ac..6602bb03 100644 --- a/test/bijectors/permute.jl +++ b/test/bijectors/permute.jl @@ -21,10 +21,10 @@ using Bijectors: Permute @test b1.A == b2.A == b3.A == b4.A x = [1., 2.] - @test (inv(b1) ∘ b1)(x) == x - @test (inv(b2) ∘ b2)(x) == x - @test (inv(b3) ∘ b3)(x) == x - @test (inv(b4) ∘ b4)(x) == x + @test (inverse(b1) ∘ b1)(x) == x + @test (inverse(b2) ∘ b2)(x) == x + @test (inverse(b3) ∘ b3)(x) == x + @test (inverse(b4) ∘ b4)(x) == x # Slightly more complex case; one entry is not permuted b1 = Permute([ @@ -39,10 +39,10 @@ using Bijectors: Permute @test b1.A == b2.A == b3.A == b4.A x = [1., 2., 3.] - @test (inv(b1) ∘ b1)(x) == x - @test (inv(b2) ∘ b2)(x) == x - @test (inv(b3) ∘ b3)(x) == x - @test (inv(b4) ∘ b4)(x) == x + @test (inverse(b1) ∘ b1)(x) == x + @test (inverse(b2) ∘ b2)(x) == x + @test (inverse(b3) ∘ b3)(x) == x + @test (inverse(b4) ∘ b4)(x) == x # logabsdetjac @test logabsdetjac(b1, x) == 0.0 diff --git a/test/bijectors/utils.jl b/test/bijectors/utils.jl index 98919af9..ce2d03c3 100644 --- a/test/bijectors/utils.jl +++ b/test/bijectors/utils.jl @@ -6,7 +6,7 @@ function test_bijector_reals( isequal = true, tol = 1e-6 ) - ib = @inferred inv(b) + ib = @inferred inverse(b) y = @inferred b(x_true) logjac = @inferred logabsdetjac(b, x_true) ilogjac = @inferred logabsdetjac(ib, y_true) @@ -14,7 +14,7 @@ function test_bijector_reals( # If `isequal` is false, then we use the computed `y`, # but if it's true, we use the true `y`. - ires = isequal ? @inferred(forward(inv(b), y_true)) : @inferred(forward(inv(b), y)) + ires = isequal ? @inferred(forward(inverse(b), y_true)) : @inferred(forward(inverse(b), y)) # Always want the following to hold @test ires[1] ≈ x_true atol=tol @@ -43,13 +43,13 @@ function test_bijector_arrays( isequal = true, tol = 1e-6 ) - ib = @inferred inv(b) + ib = @inferred inverse(b) ys = @inferred b(xs_true) logjacs = @inferred logabsdetjac(b, xs_true) res = @inferred forward(b, xs_true) # If `isequal` is false, then we use the computed `y`, # but if it's true, we use the true `y`. - ires = isequal ? @inferred(forward(inv(b), ys_true)) : @inferred(forward(inv(b), ys)) + ires = isequal ? 
@inferred(forward(inverse(b), ys_true)) : @inferred(forward(inverse(b), ys)) # always want the following to hold @test ys isa typeof(ys_true) @@ -118,7 +118,7 @@ function test_bijector( logjacs_true::AbstractVector{<:Real}; kwargs... ) - ib = inv(b) + ib = inverse(b) # Batch test_bijector_arrays(b, xs_true, ys_true, logjacs_true; kwargs...) @@ -148,7 +148,7 @@ function test_bijector( logjacs_true::AbstractVector{<:Real}; kwargs... ) - ib = inv(b) + ib = inverse(b) # Batch test_bijector_arrays(b, xs_true, ys_true, logjacs_true; kwargs...) diff --git a/test/interface.jl b/test/interface.jl index 27da4517..d358d6cc 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -18,7 +18,7 @@ MyADBijector(d::Distribution) = MyADBijector{Bijectors.ADBackend()}(d) MyADBijector{AD}(d::Distribution) where {AD} = MyADBijector{AD}(bijector(d)) MyADBijector{AD}(b::B) where {AD, N, B <: Bijector{N}} = MyADBijector{AD, N, B}(b) (b::MyADBijector)(x) = b.b(x) -(b::Inverse{<:MyADBijector})(x) = inv(b.orig.b)(x) +(b::Inverse{<:MyADBijector})(x) = inverse(b.orig.b)(x) struct NonInvertibleBijector{AD} <: ADBijector{AD, 1} end @@ -73,7 +73,7 @@ end # single sample y = @inferred rand(td) - x = @inferred inv(td.transform)(y) + x = @inferred inverse(td.transform)(y) @test y ≈ @inferred td.transform(x) @test @inferred(logpdf(td, y)) ≈ @inferred(logpdf_with_trans(dist, x, true)) @@ -84,7 +84,7 @@ end # multi-sample y = @inferred rand(td, 10) - x = inv(td.transform).(y) + x = inverse(td.transform).(y) @test logpdf.(td, y) ≈ logpdf_with_trans.(dist, x, true) # logpdf corresponds to logpdf_with_trans @@ -92,12 +92,12 @@ end b = @inferred bijector(d) x = rand(d) y = @inferred b(x) - @test logpdf(d, inv(b)(y)) + logabsdetjacinv(b, y) ≈ logpdf_with_trans(d, x, true) + @test logpdf(d, inverse(b)(y)) + logabsdetjacinv(b, y) ≈ logpdf_with_trans(d, x, true) @test logpdf(d, x) - logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, true) # forward f = @inferred forward(td) - @test f.x ≈ inv(td.transform)(f.y) + @test f.x ≈ inverse(td.transform)(f.y) @test f.y ≈ td.transform(f.x) @test f.logabsdetjac ≈ logabsdetjac(td.transform, f.x) @test f.logpdf ≈ logpdf_with_trans(td.dist, f.x, true) @@ -111,7 +111,7 @@ end # `ForwardDiff.derivative` can lead to some numerical inaccuracy, # so we use a slightly higher `atol` than default. @test log(abs(ForwardDiff.derivative(b, x))) ≈ logabsdetjac(b, x) atol=1e-6 - @test log(abs(ForwardDiff.derivative(inv(b), y))) ≈ logabsdetjac(inv(b), y) atol=1e-6 + @test log(abs(ForwardDiff.derivative(inverse(b), y))) ≈ logabsdetjac(inverse(b), y) atol=1e-6 end @testset "$dist: ForwardDiff AD" begin @@ -122,7 +122,7 @@ end @test logabsdetjac(b, x) ≠ Inf y = b(x) - b⁻¹ = inv(b) + b⁻¹ = inverse(b) @test abs(det(Bijectors.jacobian(b⁻¹, y))) > 0 @test logabsdetjac(b⁻¹, y) ≠ Inf end @@ -135,7 +135,7 @@ end @test logabsdetjac(b, x) ≠ Inf y = b(x) - b⁻¹ = inv(b) + b⁻¹ = inverse(b) @test abs(det(Bijectors.jacobian(b⁻¹, y))) > 0 @test logabsdetjac(b⁻¹, y) ≠ Inf end @@ -153,7 +153,7 @@ end (Exp{0}(), randn(3)), (Exp{1}(), randn(2, 3)), (Log{1}() ∘ Exp{1}(), randn(2, 3)), - (inv(Logit(-1.0, 1.0)), randn(3)), + (inverse(Logit(-1.0, 1.0)), randn(3)), (Identity{0}(), randn(3)), (Identity{1}(), randn(2, 3)), (PlanarLayer(2), randn(2, 3)), @@ -173,7 +173,7 @@ end for (b, xs) in bs_xs @testset "$b" begin D = @inferred Bijectors.dimension(b) - ib = @inferred inv(b) + ib = @inferred inverse(b) @test Bijectors.dimension(ib) == D @@ -285,9 +285,9 @@ end @test logabsdetjac(cb1, 1.) isa Real @test logabsdetjac(cb1, 1.) == 1. 
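These composition tests exercise both the evaluation order of `∘` and inversion of a `Composed`; a compact sketch of the two properties (bijectors chosen arbitrarily):

```julia
using Bijectors
using InverseFunctions: inverse

cb = Bijectors.Logit(0.0, 1.0) ∘ Bijectors.Exp()  # applies Exp first: x ↦ logit(exp(x))
x = -0.5
cb(x) ≈ Bijectors.Logit(0.0, 1.0)(exp(x))         # ∘ composes right-to-left
inverse(cb)(cb(x)) ≈ x                            # the inverse reverses and inverts the parts
```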
- @test inv(cb1) isa Composed{<:Tuple} - @test inv(cb2) isa Composed{<:Tuple} - @test inv(cb3) isa Composed{<:Tuple} + @test inverse(cb1) isa Composed{<:Tuple} + @test inverse(cb2) isa Composed{<:Tuple} + @test inverse(cb3) isa Composed{<:Tuple} # Check that type-unstable composition stays type-unstable cb1 = Composed([Exp(), Log()]) ∘ Exp() @@ -300,9 +300,9 @@ end @test logabsdetjac(cb1, 1.) isa Real @test logabsdetjac(cb1, 1.) == 1. - @test inv(cb1) isa Composed{<:AbstractArray} - @test inv(cb2) isa Composed{<:AbstractArray} - @test inv(cb3) isa Composed{<:AbstractArray} + @test inverse(cb1) isa Composed{<:AbstractArray} + @test inverse(cb2) isa Composed{<:AbstractArray} + @test inverse(cb3) isa Composed{<:AbstractArray} # combining the two @test_throws ErrorException (Log() ∘ Exp()) ∘ cb1 @@ -376,7 +376,7 @@ end x = rand(d) y = b(x) @test y ≈ link(d, x) - @test inv(b)(y) ≈ x + @test inverse(b)(y) ≈ x @test logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, false) - logpdf_with_trans(d, x, true) d = truncated(Normal(), -Inf, 1) @@ -384,7 +384,7 @@ end x = rand(d) y = b(x) @test y ≈ link(d, x) - @test inv(b)(y) ≈ x + @test inverse(b)(y) ≈ x @test logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, false) - logpdf_with_trans(d, x, true) d = truncated(Normal(), 1, Inf) @@ -392,7 +392,7 @@ end x = rand(d) y = b(x) @test y ≈ link(d, x) - @test inv(b)(y) ≈ x + @test inverse(b)(y) ≈ x @test logabsdetjac(b, x) ≈ logpdf_with_trans(d, x, false) - logpdf_with_trans(d, x, true) end @@ -415,8 +415,8 @@ end # single sample y = rand(td) - x = inv(td.transform)(y) - @test inv(td.transform)(param(y)) isa TrackedArray + x = inverse(td.transform)(y) + @test inverse(td.transform)(param(y)) isa TrackedArray @test y ≈ td.transform(x) @test td.transform(param(x)) isa TrackedArray @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) @@ -428,13 +428,13 @@ end # multi-sample y = rand(td, 10) - x = inv(td.transform)(y) - @test inv(td.transform)(param(y)) isa TrackedArray + x = inverse(td.transform)(y) + @test inverse(td.transform)(param(y)) isa TrackedArray @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) # forward f = forward(td) - @test f.x ≈ inv(td.transform)(f.y) + @test f.x ≈ inverse(td.transform)(f.y) @test f.y ≈ td.transform(f.x) @test f.logabsdetjac ≈ logabsdetjac(td.transform, f.x) @test f.logpdf ≈ logpdf_with_trans(td.dist, f.x, true) @@ -447,7 +447,7 @@ end y = b(x) @test b(param(x)) isa TrackedArray @test log(abs(det(ForwardDiff.jacobian(b, x)))) ≈ logabsdetjac(b, x) - @test log(abs(det(ForwardDiff.jacobian(inv(b), y)))) ≈ logabsdetjac(inv(b), y) + @test log(abs(det(ForwardDiff.jacobian(inverse(b), y)))) ≈ logabsdetjac(inverse(b), y) else b = bijector(dist) x = rand(dist) @@ -456,7 +456,7 @@ end # so we use a slightly higher `atol` than default. 
@test b(param(x)) isa TrackedArray @test log(abs(det(ForwardDiff.jacobian(b, x)))) ≈ logabsdetjac(b, x) atol=1e-6 - @test log(abs(det(ForwardDiff.jacobian(inv(b), y)))) ≈ logabsdetjac(inv(b), y) atol=1e-6 + @test log(abs(det(ForwardDiff.jacobian(inverse(b), y)))) ≈ logabsdetjac(inverse(b), y) atol=1e-6 end end end @@ -481,8 +481,8 @@ end # single sample y = rand(td) - x = inv(td.transform)(y) - @test inv(td.transform)(param(y)) isa TrackedArray + x = inverse(td.transform)(y) + @test inverse(td.transform)(param(y)) isa TrackedArray @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) # TODO: implement `logabsdetjac` for these @@ -493,8 +493,8 @@ end # multi-sample y = rand(td, 10) - x = inv(td.transform)(y) - @test inv(td.transform)(param.(y)) isa Vector{<:TrackedArray} + x = inverse(td.transform)(y) + @test inverse(td.transform)(param.(y)) isa Vector{<:TrackedArray} @test logpdf(td, y) ≈ logpdf_with_trans(dist, x, true) end end @@ -508,10 +508,10 @@ end y = td.transform(x) b = @inferred Bijectors.composel(td.transform, Bijectors.Identity{0}()) - ib = @inferred inv(b) + ib = @inferred inverse(b) @test forward(b, x) == forward(td.transform, x) - @test forward(ib, y) == forward(inv(td.transform), y) + @test forward(ib, y) == forward(inverse(td.transform), y) @test forward(b, x) == forward(Bijectors.composer(b.ts...), x) @@ -524,26 +524,26 @@ end # ensures that the `logabsdetjac` is correct x = rand(d) - b = inv(bijector(d)) + b = inverse(bijector(d)) @test logabsdetjac(b ∘ b, x) ≈ logabsdetjac(b, b(x)) + logabsdetjac(b, x) # order of composed evaluation b1 = MyADBijector(d) b2 = MyADBijector(Gamma()) - cb = inv(b1) ∘ b2 - @test cb(x) ≈ inv(b1)(b2(x)) + cb = inverse(b1) ∘ b2 + @test cb(x) ≈ inverse(b1)(b2(x)) # contrived example b = bijector(d) - cb = @inferred inv(b) ∘ b + cb = @inferred inverse(b) ∘ b cb = @inferred cb ∘ cb @test @inferred(cb ∘ cb ∘ cb ∘ cb ∘ cb)(x) ≈ x # forward for tuple and array d = Beta() - b = @inferred inv(bijector(d)) - b⁻¹ = @inferred inv(b) + b = @inferred inverse(bijector(d)) + b⁻¹ = @inferred inverse(b) x = rand(d) cb_t = b⁻¹ ∘ b⁻¹ @@ -571,7 +571,7 @@ end x = rand(d) y = b(x) - sb1 = @inferred stack(b, b, inv(b), inv(b)) # <= Tuple + sb1 = @inferred stack(b, b, inverse(b), inverse(b)) # <= Tuple res1 = forward(sb1, [x, x, y, y]) @test sb1(param([x, x, y, y])) isa TrackedArray @@ -579,7 +579,7 @@ end @test logabsdetjac(sb1, [x, x, y, y]) ≈ 0 atol=1e-6 @test res1[2] ≈ 0 atol=1e-6 - sb2 = Stacked([b, b, inv(b), inv(b)]) # <= Array + sb2 = Stacked([b, b, inverse(b), inverse(b)]) # <= Array res2 = forward(sb2, [x, x, y, y]) @test sb2(param([x, x, y, y])) isa TrackedArray @@ -591,7 +591,7 @@ end b = MyADBijector(d) y = b(x) - sb1 = stack(b, b, inv(b), inv(b)) # <= Tuple + sb1 = stack(b, b, inverse(b), inverse(b)) # <= Tuple res1 = forward(sb1, [x, x, y, y]) @test sb1(param([x, x, y, y])) isa TrackedArray @@ -599,7 +599,7 @@ end @test logabsdetjac(sb1, [x, x, y, y]) ≈ 0 atol=1e-12 @test res1[2] ≈ 0.0 atol=1e-12 - sb2 = Stacked([b, b, inv(b), inv(b)]) # <= Array + sb2 = Stacked([b, b, inverse(b), inverse(b)]) # <= Array res2 = forward(sb2, [x, x, y, y]) @test sb2(param([x, x, y, y])) isa TrackedArray @@ -719,9 +719,9 @@ end # Stacked{<:Tuple} bs = bijector.(tuple(dists...)) - ibs = inv.(bs) + ibs = inverse.(bs) sb = @inferred Stacked(ibs, ranges) - isb = @inferred inv(sb) + isb = @inferred inverse(sb) @test sb isa Stacked{<:Tuple} # inverse @@ -756,7 +756,7 @@ end # Usage in ADVI d = Beta() b = bijector(d) # [0, 1] → ℝ - ib = inv(b) # ℝ → [0, 1] + ib = inverse(b) # ℝ → 
[0, 1] td = transformed(Normal(), ib) # x ∼ 𝓝(0, 1) then f(x) ∈ [0, 1] x = rand(td) # ∈ [0, 1] @test 0 ≤ x ≤ 1 @@ -764,7 +764,7 @@ end @testset "Jacobians of SimplexBijector" begin b = SimplexBijector() - ib = inv(b) + ib = inverse(b) x = ib(randn(10)) y = b(x) @@ -847,7 +847,7 @@ end for i in 1:length(bs), j in 1:length(bs) if i == j @test bs[i] == deepcopy(bs[j]) - @test inv(bs[i]) == inv(deepcopy(bs[j])) + @test inverse(bs[i]) == inverse(deepcopy(bs[j])) else @test bs[i] != bs[j] end diff --git a/test/norm_flows.jl b/test/norm_flows.jl index 9fadf573..38ea7e5f 100644 --- a/test/norm_flows.jl +++ b/test/norm_flows.jl @@ -8,15 +8,15 @@ seed!(1) x = randn(2, 20) bn = InvertibleBatchNorm(2) - @test inv(inv(bn)) == bn - @test inv(bn)(bn(x)) ≈ x - @test (inv(bn) ∘ bn)(x) ≈ x + @test inverse(inverse(bn)) == bn + @test inverse(bn)(bn(x)) ≈ x + @test (inverse(bn) ∘ bn)(x) ≈ x @test_throws ErrorException forward(bn, randn(10,2)) - @test logabsdetjac(inv(bn), bn(x)) ≈ - logabsdetjac(bn, x) + @test logabsdetjac(inverse(bn), bn(x)) ≈ - logabsdetjac(bn, x) y, ladj = forward(bn, x) @test log(abs(det(ForwardDiff.jacobian(bn, x)))) ≈ sum(ladj) - @test log(abs(det(ForwardDiff.jacobian(inv(bn), y)))) ≈ sum(logabsdetjac(inv(bn), y)) + @test log(abs(det(ForwardDiff.jacobian(inverse(bn), y)))) ≈ sum(logabsdetjac(inverse(bn), y)) test_functor(bn, (b = bn.b, logs = bn.logs)) end @@ -29,8 +29,8 @@ end our_method = sum(forward(flow, z)[2]) @test our_method ≈ forward_diff - @test inv(flow)(flow(z)) ≈ z - @test (inv(flow) ∘ flow)(z) ≈ z + @test inverse(flow)(flow(z)) ≈ z + @test (inverse(flow) ∘ flow)(z) ≈ z end w = ones(10) @@ -38,10 +38,10 @@ end b = 1.0 flow = PlanarLayer(w, u, b) z = ones(10, 100) - @test inv(flow)(flow(z)) ≈ z + @test inverse(flow)(flow(z)) ≈ z test_functor(flow, (w = w, u = u, b = b)) - test_functor(inv(flow), (orig = flow,)) + test_functor(inverse(flow), (orig = flow,)) @testset "find_alpha" begin for wt_y in (-20.3, -3, -3//2, 0.0, 5, 29//4, 12.3) @@ -77,8 +77,8 @@ end our_method = sum(forward(flow, z)[2]) @test our_method ≈ forward_diff - @test inv(flow)(flow(z)) ≈ z rtol=0.2 - @test (inv(flow) ∘ flow)(z) ≈ z rtol=0.2 + @test inverse(flow)(flow(z)) ≈ z rtol=0.2 + @test (inverse(flow) ∘ flow)(z) ≈ z rtol=0.2 end α_ = 1.0 @@ -86,10 +86,10 @@ end z_0 = zeros(10) z = ones(10, 100) flow = RadialLayer(α_, β, z_0) - @test inv(flow)(flow(z)) ≈ z + @test inverse(flow)(flow(z)) ≈ z test_functor(flow, (α_ = α_, β = β, z_0 = z_0)) - test_functor(inv(flow), (orig = flow,)) + test_functor(inverse(flow), (orig = flow,)) end @testset "Flows" begin @@ -110,9 +110,9 @@ end # flow with unconstrained-to-constrained d1 = Beta() - b1 = inv(bijector(d1)) + b1 = inverse(bijector(d1)) d2 = InverseGamma() - b2 = inv(bijector(d2)) + b2 = inverse(bijector(d2)) x = rand(d) .+ 10 y = b(x) From f6385ea4f2d2af8fe4fadfe06c4584411e95767b Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 11 Dec 2021 00:19:51 +0100 Subject: [PATCH 04/29] Improve deprecation scheme for forward Co-authored-by: David Widmann --- src/Bijectors.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index a540f435..59c4934f 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -252,7 +252,7 @@ include("utils.jl") include("interface.jl") include("chainrules.jl") -Base.@deprecate forward(b::AbstractBijector, x) with_logabsdet_jacobian(b, x) +Base.@deprecate forward(b::AbstractBijector, x) NamedTuple{(:rv,:logabsdetjac)}(with_logabsdet_jacobian(b, x)) import Base.inv Base.@deprecate 
inv(b::AbstractBijector) inverse(b) From 0aab88b5ba9fd531c673a326da9aace2076eb7b7 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 11 Dec 2021 00:24:50 +0100 Subject: [PATCH 05/29] Improve deprecation scheme for inv --- src/Bijectors.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 59c4934f..351c7095 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -254,8 +254,10 @@ include("chainrules.jl") Base.@deprecate forward(b::AbstractBijector, x) NamedTuple{(:rv,:logabsdetjac)}(with_logabsdet_jacobian(b, x)) -import Base.inv -Base.@deprecate inv(b::AbstractBijector) inverse(b) +@noinline function Base.inv(b::AbstractBijector) + Base.depwarn("`Base.inv(b::AbstractBijector)` is deprecated, use `InverseFunctions.inverse(b)` instead.", :(Base.inv)) + inverse(b) +end # Broadcasting here breaks Tracker for some reason maporbroadcast(f, x::AbstractArray{<:Any, N}...) where {N} = map(f, x...) From e6f549dba71438bab0dd877290f1f8c705d0106a Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 11 Dec 2021 00:33:23 +0100 Subject: [PATCH 06/29] Test forward and inv deprecations --- test/interface.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/interface.jl b/test/interface.jl index d358d6cc..36a039fc 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -854,3 +854,12 @@ end end end +@testset "deprecations" begin + b = Bijectors.Exp() + x = 0.3 + + @test let r = forward(b, x) + (r.rv, r.logabsdetjac) == with_logabsdet_jacobian(b, x) + end + @test inv(b) == inverse(b) +end From 4c7f7067b08a1716ef7aac0151e5ce8e8f6c977d Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 11 Dec 2021 01:42:41 +0100 Subject: [PATCH 07/29] Apply suggestions from code review Co-authored-by: David Widmann --- README.md | 8 +++--- src/Bijectors.jl | 4 +-- src/bijectors/composed.jl | 26 ++++++++----------- src/bijectors/named_bijector.jl | 26 ++++++++----------- src/bijectors/normalise.jl | 6 ++--- src/bijectors/planar_layer.jl | 2 +- src/bijectors/radial_layer.jl | 2 +- src/interface.jl | 2 +- src/transformed_distribution.jl | 46 ++++++++++++++++----------------- 9 files changed, 57 insertions(+), 65 deletions(-) diff --git a/README.md b/README.md index 4cbbf576..b23a60e7 100644 --- a/README.md +++ b/README.md @@ -199,7 +199,7 @@ julia> logpdf_forward(td, x) -1.123311289915276 ``` -#### `logabsdetjac` and `forward` +#### `logabsdetjac` and `with_logabsdet_jacobian` In the computation of both `logpdf` and `logpdf_forward` we need to compute `log(abs(det(jacobian(inverse(b), y))))` and `log(abs(det(jacobian(b, x))))`, respectively. This computation is available using the `logabsdetjac` method @@ -218,7 +218,7 @@ julia> logabsdetjac(b, x) ≈ -logabsdetjac(b⁻¹, y) true ``` -which is always the case for a differentiable bijection with differentiable inverse. Therefore if you want to compute `logabsdetjac(b⁻¹, y)` and we know that `logabsdetjac(b, b⁻¹(y))` is actually more efficient, we'll return `-logabsdetjac(b, b⁻¹(y))` instead. For some bijectors it might be easy to compute, say, the forward pass `b(x)`, but expensive to compute `b⁻¹(y)`. Because of this you might want to avoid doing anything "backwards", i.e. using `b⁻¹`. This is where `forward` comes to good use: +which is always the case for a differentiable bijection with differentiable inverse. Therefore if you want to compute `logabsdetjac(b⁻¹, y)` and we know that `logabsdetjac(b, b⁻¹(y))` is actually more efficient, we'll return `-logabsdetjac(b, b⁻¹(y))` instead. 
For some bijectors it might be easy to compute, say, the forward pass `b(x)`, but expensive to compute `b⁻¹(y)`. Because of this you might want to avoid doing anything "backwards", i.e. using `b⁻¹`. This is where `with_logabsdet_jacobian` comes to good use: ```julia julia> with_logabsdet_jacobian(b, x) @@ -228,7 +228,7 @@ julia> with_logabsdet_jacobian(b, x) Similarily ```julia -julia> forward(inverse(b), y) +julia> with_logabsdet_jacobian(inverse(b), y) (0.3688868996596376, -1.4575353795716655) ``` @@ -716,7 +716,7 @@ The following methods are implemented by all subtypes of `Bijector`, this also i - `(b::Bijector)(x)`: implements the transform of the `Bijector` - `inverse(b::Bijector)`: returns the inverse of `b`, i.e. `ib::Bijector` s.t. `(ib ∘ b)(x) ≈ x`. In most cases this is `Inverse{<:Bijector}`. - `logabsdetjac(b::Bijector, x)`: computes log(abs(det(jacobian(b, x)))). -- `with_logabsdet_jacobian(b::Bijector, x)`: returns named tuple `(b(x), logabsdetjac(b, x))` in the most efficient manner. +- `with_logabsdet_jacobian(b::Bijector, x)`: returns the tuple `(b(x), logabsdetjac(b, x))` in the most efficient manner. - `∘`, `composel`, `composer`: convenient and type-safe constructors for `Composed`. `composel(bs...)` composes s.t. the resulting composition is evaluated left-to-right, while `composer(bs...)` is evaluated right-to-left. `∘` is right-to-left, as excepted from standard mathematical notation. - `jacobian(b::Bijector, x)` [OPTIONAL]: returns the Jacobian of the transformation. In some cases the analytical Jacobian has been implemented for efficiency. - `dimension(b::Bijector)`: returns the dimensionality of `b`. diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 351c7095..b3a5f710 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -35,8 +35,8 @@ using MappedArrays using Base.Iterators: drop using LinearAlgebra: AbstractTriangular -import ChangesOfVariables: with_logabsdet_jacobian -import InverseFunctions: inverse +using ChangesOfVariables: ChangesOfVariables, with_logabsdet_jacobian +using InverseFunctions: InverseFunctions, inverse import ChainRulesCore import Functors diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index 482a2760..e6516181 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -176,11 +176,10 @@ end end function logabsdetjac(cb::Composed, x) - y, logjac = forward(cb.ts[1], x) + y, logjac = with_logabsdet_jacobian(cb.ts[1], x) for i = 2:length(cb.ts) - res = forward(cb.ts[i], y) - y = res[1] - logjac += res[2] + y, res_logjac = with_logabsdet_jacobian(cb.ts[i], y) + logjac += res_logjac end return logjac @@ -193,10 +192,9 @@ end push!(expr.args, :((y, logjac) = forward(cb.ts[1], x))) for i = 2:N - 1 - temp = gensym(:res) - push!(expr.args, :($temp = forward(cb.ts[$i], y))) - push!(expr.args, :(y = $temp[1])) - push!(expr.args, :(logjac += $temp[2])) + temp = gensym(:res_logjac) + push!(expr.args, :(y, $temp = with_logabsdet_jacobian(cb.ts[$i], y))) + push!(expr.args, :(logjac += $temp)) end # don't need to evaluate the last bijector, only it's `logabsdetjac` push!(expr.args, :(logjac += logabsdetjac(cb.ts[$N], y))) @@ -211,9 +209,8 @@ function forward(cb::Composed, x) rv, logjac = forward(cb.ts[1], x) for t in cb.ts[2:end] - res = forward(t, rv) - rv = res[1] - logjac = res[2] + logjac + rv, res_logjac = with_logabsdet_jacobian(t, rv) + logjac += res_logjac end return (rv, logjac) end @@ -223,10 +220,9 @@ end expr = Expr(:block) push!(expr.args, :((y, logjac) = forward(cb.ts[1], x))) for i = 
2:length(T.parameters) - temp = gensym(:temp) - push!(expr.args, :($temp = forward(cb.ts[$i], y))) - push!(expr.args, :(y = $temp[1])) - push!(expr.args, :(logjac += $temp[2])) + temp = gensym(:res_logjac) + push!(expr.args, :(y, $temp = with_logabsdet_jacobian(cb.ts[$i], y))) + push!(expr.args, :(logjac += $temp)) end push!(expr.args, :(return (y, logjac))) diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index 8f90464e..45af6b95 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -122,11 +122,10 @@ end (cb::NamedComposition{<:Tuple})(x) = foldl(|>, cb.bs; init=x) function logabsdetjac(cb::NamedComposition, x) - y, logjac = forward(cb.bs[1], x) + y, logjac = with_logabsdet_jacobian(cb.bs[1], x) for i = 2:length(cb.bs) - res = forward(cb.bs[i], y) - y = res[1] - logjac += res[2] + y, res_logjac = with_logabsdet_jacobian(cb.bs[i], y) + logjac += res_logjac end return logjac @@ -139,10 +138,9 @@ end push!(expr.args, :((y, logjac) = forward(cb.bs[1], x))) for i = 2:N - 1 - temp = gensym(:res) - push!(expr.args, :($temp = forward(cb.bs[$i], y))) - push!(expr.args, :(y = $temp[1])) - push!(expr.args, :(logjac += $temp[2])) + temp = gensym(:res_logjac) + push!(expr.args, :(y, $temp = with_logabsdet_jacobian(cb.bs[$i], y))) + push!(expr.args, :(logjac += $temp)) end # don't need to evaluate the last bijector, only it's `logabsdetjac` push!(expr.args, :(logjac += logabsdetjac(cb.bs[$N], y))) @@ -157,9 +155,8 @@ function forward(cb::NamedComposition, x) rv, logjac = forward(cb.bs[1], x) for t in cb.bs[2:end] - res = forward(t, rv) - rv = res[1] - logjac = res[2] + logjac + rv, res_logjac = with_logabsdet_jacobian(t, rv) + logjac += res_logjac end return (rv, logjac) end @@ -169,10 +166,9 @@ end expr = Expr(:block) push!(expr.args, :((y, logjac) = forward(cb.bs[1], x))) for i = 2:length(T.parameters) - temp = gensym(:temp) - push!(expr.args, :($temp = forward(cb.bs[$i], y))) - push!(expr.args, :(y = $temp[1])) - push!(expr.args, :(logjac += $temp[2])) + temp = gensym(:res_logjac) + push!(expr.args, :(y, $temp = with_logabsdet_jacobian(cb.bs[$i], y))) + push!(expr.args, :(logjac += $temp)) end push!(expr.args, :(return (y, logjac))) diff --git a/src/bijectors/normalise.jl b/src/bijectors/normalise.jl index 43972eeb..fcbaff91 100644 --- a/src/bijectors/normalise.jl +++ b/src/bijectors/normalise.jl @@ -79,9 +79,9 @@ function with_logabsdet_jacobian(bn::InvertibleBatchNorm, x) return (rv, logabsdetjac) end -logabsdetjac(bn::InvertibleBatchNorm, x) = with_logabsdet_jacobian(bn, x)[2] +logabsdetjac(bn::InvertibleBatchNorm, x) = last(with_logabsdet_jacobian(bn, x)) -(bn::InvertibleBatchNorm)(x) = with_logabsdet_jacobian(bn, x)[1] +(bn::InvertibleBatchNorm)(x) = first(with_logabsdet_jacobian(bn, x)) function forward(invbn::Inverse{<:InvertibleBatchNorm}, y) @assert !istraining() "`forward(::Inverse{InvertibleBatchNorm})` is only available in test mode." 
@@ -97,7 +97,7 @@ function forward(invbn::Inverse{<:InvertibleBatchNorm}, y) return (x, -logabsdetjac(bn, x)) end -(bn::Inverse{<:InvertibleBatchNorm})(y) = with_logabsdet_jacobian(bn, y)[1] +(bn::Inverse{<:InvertibleBatchNorm})(y) = first(with_logabsdet_jacobian(bn, y)) function Base.show(io::IO, l::InvertibleBatchNorm) print(io, "InvertibleBatchNorm($(join(size(l.b), ", ")))") diff --git a/src/bijectors/planar_layer.jl b/src/bijectors/planar_layer.jl index f2dbefff..1f6c022e 100644 --- a/src/bijectors/planar_layer.jl +++ b/src/bijectors/planar_layer.jl @@ -175,5 +175,5 @@ function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:Real} return α0 end -logabsdetjac(flow::PlanarLayer, x) = forward(flow, x)[2] +logabsdetjac(flow::PlanarLayer, x) = last(with_logabsdet_jacobian(flow, x)) isclosedform(b::Inverse{<:PlanarLayer}) = false diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index 11c3d799..f2d656e7 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -123,4 +123,4 @@ function compute_r(y_minus_z0::AbstractVector{<:Real}, α, α_plus_β_hat) return r end -logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat) = forward(flow, x)[2] +logabsdetjac(flow::RadialLayer, x::AbstractVecOrMat) = last(with_logabsdet_jacobian(flow, x)) diff --git a/src/interface.jl b/src/interface.jl index 3782dee6..276ce631 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -56,7 +56,7 @@ requires an iterative procedure to evaluate. isclosedform(b::Bijector) = true """ -inverse(b::Bijector) + inverse(b::Bijector) Inverse(b::Bijector) A `Bijector` representing the inverse transform of `b`. diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index d7d40114..10b9ff9d 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -85,41 +85,41 @@ Base.length(td::Transformed) = length(td.dist) Base.size(td::Transformed) = size(td.dist) function logpdf(td::UnivariateTransformed, y::Real) - res = forward(inverse(td.transform), y) - return logpdf(td.dist, res[1]) + res[2] + x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) + return logpdf(td.dist, x) + logjac end # TODO: implement more efficiently for flows in the case of `Matrix` function logpdf(td::MvTransformed, y::AbstractMatrix{<:Real}) # batch-implementation for multivariate - res = forward(inverse(td.transform), y) - return logpdf(td.dist, res[1]) + res[2] + x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) + return logpdf(td.dist, x) + logjac end function logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractMatrix{<:Real}) T = eltype(y) ϵ = _eps(T) - res = forward(inverse(td.transform), y) - return logpdf(td.dist, mappedarray(x->x+ϵ, res[1])) + res[2] + x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) + return logpdf(td.dist, mappedarray(x->x+ϵ, x)) + logjac end function _logpdf(td::MvTransformed, y::AbstractVector{<:Real}) - res = forward(inverse(td.transform), y) - return logpdf(td.dist, res[1]) + res[2] + x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) + return logpdf(td.dist, x) + logjac end function _logpdf(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) T = eltype(y) ϵ = _eps(T) - res = forward(inverse(td.transform), y) - return logpdf(td.dist, mappedarray(x->x+ϵ, res[1])) + res[2] + x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) + return logpdf(td.dist, mappedarray(x->x+ϵ, x)) + logjac end # TODO: should eventually drop using `logpdf_with_trans` and replace with -# res = 
forward(inverse(td.transform), y) -# logpdf(td.dist, res[1]) .- res[2] +# x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) +# logpdf(td.dist, x) .- logjac function _logpdf(td::MatrixTransformed, y::AbstractMatrix{<:Real}) return logpdf_with_trans(td.dist, inverse(td.transform)(y), true) end @@ -163,34 +163,34 @@ Makes use of the `forward` method to potentially re-use computation and returns a tuple `(logpdf, logabsdetjac)`. """ function logpdf_with_jac(td::UnivariateTransformed, y::Real) - res = forward(inverse(td.transform), y) - return (logpdf(td.dist, res[1]) + res[2], res[2]) + x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) + return (logpdf(td.dist, x) + logjac, logjac) end # TODO: implement more efficiently for flows in the case of `Matrix` function logpdf_with_jac(td::MvTransformed, y::AbstractVector{<:Real}) - res = forward(inverse(td.transform), y) - return (logpdf(td.dist, res[1]) + res[2], res[2]) + x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) + return (logpdf(td.dist, x) + logjac, logjac) end function logpdf_with_jac(td::MvTransformed, y::AbstractMatrix{<:Real}) - res = forward(inverse(td.transform), y) - return (logpdf(td.dist, res[1]) + res[2], res[2]) + x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) + return (logpdf(td.dist, x) + logjac, logjac) end function logpdf_with_jac(td::MvTransformed{<:Dirichlet}, y::AbstractVector{<:Real}) T = eltype(y) ϵ = _eps(T) - res = forward(inverse(td.transform), y) - lp = logpdf(td.dist, mappedarray(x->x+ϵ, res[1])) + res[2] - return (lp, res[2]) + x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) + lp = logpdf(td.dist, mappedarray(x->x+ϵ, x)) + logjac + return (lp, logjac) end # TODO: should eventually drop using `logpdf_with_trans` function logpdf_with_jac(td::MatrixTransformed, y::AbstractMatrix{<:Real}) - res = forward(inverse(td.transform), y) - return (logpdf_with_trans(td.dist, res[1], true), res[2]) + x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) + return (logpdf_with_trans(td.dist, x, true), logjac) end """ From 3815505bdaeb08941f37bcb2512af251dcea53f7 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 11 Dec 2021 01:57:40 +0100 Subject: [PATCH 08/29] Fixes regarding with_logabsdet_jacobian and inverse --- src/bijectors/composed.jl | 41 ++++++++++++---------- src/bijectors/exp_log.jl | 4 +-- src/bijectors/leaky_relu.jl | 8 ++--- src/bijectors/named_bijector.jl | 20 +++++------ src/bijectors/normalise.jl | 6 ++-- src/bijectors/permute.jl | 2 +- src/bijectors/planar_layer.jl | 2 +- src/bijectors/radial_layer.jl | 2 +- src/bijectors/rational_quadratic_spline.jl | 2 +- src/bijectors/shift.jl | 2 +- src/bijectors/stacked.jl | 20 +++++------ src/interface.jl | 8 ++--- test/interface.jl | 10 +++--- 13 files changed, 66 insertions(+), 61 deletions(-) diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index e6516181..29360690 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -153,7 +153,7 @@ end ∘(::Identity{N}, b::Bijector{N}) where {N} = b ∘(b::Bijector{N}, ::Identity{N}) where {N} = b -inverse(ct::Composed) = Composed(reverse(map(inv, ct.ts))) +InverseFunctions.inverse(ct::Composed) = Composed(reverse(map(inv, ct.ts))) # # TODO: should arrays also be using recursive implementation instead? 
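The `logpdf` and `logpdf_with_jac` methods rewritten in the patch above all follow the same change-of-variables pattern; a self-contained sketch of that pattern (using `Beta(2, 2)` purely as an example distribution):

```julia
using Bijectors, Distributions
using ChangesOfVariables: with_logabsdet_jacobian
using InverseFunctions: inverse

td = transformed(Beta(2, 2))   # Beta pushed forward through bijector(Beta(2, 2))
y = rand(td)                   # a draw in the unconstrained space
x, logjac = with_logabsdet_jacobian(inverse(td.transform), y)
logpdf(td, y) ≈ logpdf(td.dist, x) + logjac
```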
function (cb::Composed{<:AbstractArray{<:Bijector}})(x) @@ -189,24 +189,27 @@ end N = length(T.parameters) expr = Expr(:block) - push!(expr.args, :((y, logjac) = forward(cb.ts[1], x))) - + sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) + push!(expr.args, :(($sym_y, $sym_ladj) = with_logabsdet_jacobian(cb.ts[1], x))) + sym_last_y, sym_last_ladj = sym_y, sym_ladj for i = 2:N - 1 - temp = gensym(:res_logjac) - push!(expr.args, :(y, $temp = with_logabsdet_jacobian(cb.ts[$i], y))) - push!(expr.args, :(logjac += $temp)) + sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) + push!(expr.args, :(($sym_y, $sym_tmp_ladj) = with_logabsdet_jacobian(cb.ts[$i], $sym_last_y))) + push!(expr.args, :($sym_ladj = $sym_tmp_ladj + $sym_last_ladj)) + sym_last_y, sym_last_ladj = sym_y, sym_ladj end # don't need to evaluate the last bijector, only it's `logabsdetjac` - push!(expr.args, :(logjac += logabsdetjac(cb.ts[$N], y))) - - push!(expr.args, :(return logjac)) + sym_ladj, sym_tmp_ladj = gensym(:lady), gensym(:tmp_lady) + push!(expr.args, :($sym_tmp_ladj = logabsdetjac(cb.ts[$N], $sym_last_y))) + push!(expr.args, :($sym_ladj = $sym_tmp_ladj + $sym_last_ladj)) + push!(expr.args, :(return $sym_ladj)) return expr end -function forward(cb::Composed, x) - rv, logjac = forward(cb.ts[1], x) +function ChangesOfVariables.with_logabsdet_jacobian(cb::Composed, x) + rv, logjac = with_logabsdet_jacobian(cb.ts[1], x) for t in cb.ts[2:end] rv, res_logjac = with_logabsdet_jacobian(t, rv) @@ -215,16 +218,18 @@ function forward(cb::Composed, x) return (rv, logjac) end - -@generated function forward(cb::Composed{T}, x) where {T<:Tuple} +@generated function ChangesOfVariables.with_logabsdet_jacobian(cb::Composed{T}, x) where {T<:Tuple} expr = Expr(:block) - push!(expr.args, :((y, logjac) = forward(cb.ts[1], x))) + sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) + push!(expr.args, :(($sym_y, $sym_ladj) = with_logabsdet_jacobian(cb.ts[1], x))) + sym_last_y, sym_last_ladj = sym_y, sym_ladj for i = 2:length(T.parameters) - temp = gensym(:res_logjac) - push!(expr.args, :(y, $temp = with_logabsdet_jacobian(cb.ts[$i], y))) - push!(expr.args, :(logjac += $temp)) + sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) + push!(expr.args, :(($sym_y, $sym_tmp_ladj) = with_logabsdet_jacobian(cb.ts[$i], $sym_last_y))) + push!(expr.args, :($sym_ladj = $sym_tmp_ladj + $sym_last_ladj)) + sym_last_y, sym_last_ladj = sym_y, sym_ladj end - push!(expr.args, :(return (y, logjac))) + push!(expr.args, :(return ($sym_y, $sym_ladj))) return expr end diff --git a/src/bijectors/exp_log.jl b/src/bijectors/exp_log.jl index 0f5f4683..aacaa901 100644 --- a/src/bijectors/exp_log.jl +++ b/src/bijectors/exp_log.jl @@ -27,8 +27,8 @@ Log() = Log{0}() (b::Exp{2})(y::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, y) (b::Log{2})(x::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, x) -inverse(b::Exp{N}) where {N} = Log{N}() -inverse(b::Log{N}) where {N} = Exp{N}() +InverseFunctions.inverse(b::Exp{N}) where {N} = Log{N}() +InverseFunctions.inverse(b::Log{N}) where {N} = Exp{N}() logabsdetjac(b::Exp{0}, x::Real) = x logabsdetjac(b::Exp{0}, x::AbstractVector) = x diff --git a/src/bijectors/leaky_relu.jl b/src/bijectors/leaky_relu.jl index c91e0faf..8cf769a7 100644 --- a/src/bijectors/leaky_relu.jl +++ b/src/bijectors/leaky_relu.jl @@ -31,7 +31,7 @@ function (b::LeakyReLU{<:Any, 0})(x::Real) end (b::LeakyReLU{<:Any, 0})(x::AbstractVector{<:Real}) = map(b, x) 
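The `@generated` methods for `Composed{<:Tuple}` above are hard to read because of the `gensym`d variable names; for a composition of three bijectors, the emitted `with_logabsdet_jacobian` body is roughly equivalent to the following hand-written sketch (illustrative names, not the actual generated symbols):

```julia
using ChangesOfVariables: with_logabsdet_jacobian

# Hypothetical unrolled equivalent for a three-element composition `cb`:
function with_logabsdet_jacobian_unrolled(cb, x)
    y1, ladj1 = with_logabsdet_jacobian(cb.ts[1], x)
    y2, ladj2 = with_logabsdet_jacobian(cb.ts[2], y1)
    y3, ladj3 = with_logabsdet_jacobian(cb.ts[3], y2)
    return (y3, ladj1 + ladj2 + ladj3)
end
```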
-function inverse(b::LeakyReLU{<:Any,N}) where N +function InverseFunctions.inverse(b::LeakyReLU{<:Any,N}) where N invα = inv.(b.α) return LeakyReLU{typeof(invα),N}(invα) end @@ -47,14 +47,14 @@ logabsdetjac(b::LeakyReLU{<:Real, 0}, x::AbstractVector{<:Real}) = map(x -> loga # We implement `with_logabsdet_jacobian` by hand since we can re-use the computation of # the Jacobian of the transformation. This will lead to faster sampling # when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. -function with_logabsdet_jacobian(b::LeakyReLU{<:Any, 0}, x::Real) +function ChangesOfVariables.with_logabsdet_jacobian(b::LeakyReLU{<:Any, 0}, x::Real) mask = x < zero(x) J = mask * b.α + !mask * one(x) return (J * x, log(abs(J))) end # Batched version -function with_logabsdet_jacobian(b::LeakyReLU{<:Any, 0}, x::AbstractVector) +function ChangesOfVariables.with_logabsdet_jacobian(b::LeakyReLU{<:Any, 0}, x::AbstractVector) J = let T = eltype(x), z = zero(T), o = one(T) @. (x < z) * b.α + (x > z) * o end @@ -84,7 +84,7 @@ end # We implement `forward` by hand since we can re-use the computation of # the Jacobian of the transformation. This will lead to faster sampling # when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. -function with_logabsdet_jacobian(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) +function ChangesOfVariables.with_logabsdet_jacobian(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) # Is really diagonal of jacobian J = let T = eltype(x), z = zero(T), o = one(T) @. (x < z) * b.α + (x > z) * o diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index 45af6b95..4ab972d5 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -1,6 +1,6 @@ abstract type AbstractNamedBijector <: AbstractBijector end -with_logabsdet_jacobian(b::AbstractNamedBijector, x) = (b(x), logabsdetjac(b, x)) +ChangesOfVariables.with_logabsdet_jacobian(b::AbstractNamedBijector, x) = (b(x), logabsdetjac(b, x)) ####################### ### `NamedBijector` ### @@ -55,7 +55,7 @@ names_to_bijectors(b::NamedBijector) = b.bs return :($(exprs...), ) end -@generated function inverse(b::NamedBijector{names}) where {names} +@generated function InverseFunctions.inverse(b::NamedBijector{names}) where {names} return :(NamedBijector(($([:($n = inverse(b.bs.$n)) for n in names]...), ))) end @@ -78,8 +78,8 @@ See also: [`Inverse`](@ref) struct NamedInverse{B<:AbstractNamedBijector} <: AbstractNamedBijector orig::B end -inverse(nb::AbstractNamedBijector) = NamedInverse(nb) -inverse(ni::NamedInverse) = ni.orig +InverseFunctions.inverse(nb::AbstractNamedBijector) = NamedInverse(nb) +InverseFunctions.inverse(ni::NamedInverse) = ni.orig logabsdetjac(ni::NamedInverse, y::NamedTuple) = -logabsdetjac(inverse(ni), ni(y)) @@ -107,7 +107,7 @@ composel(bs::AbstractNamedBijector...) = NamedComposition(bs) composer(bs::AbstractNamedBijector...) 
= NamedComposition(reverse(bs)) ∘(b1::AbstractNamedBijector, b2::AbstractNamedBijector) = composel(b2, b1) -inverse(ct::NamedComposition) = NamedComposition(reverse(map(inv, ct.bs))) +InverseFunctions.inverse(ct::NamedComposition) = NamedComposition(reverse(map(inv, ct.bs))) function (cb::NamedComposition{<:AbstractArray{<:AbstractNamedBijector}})(x) @assert length(cb.bs) > 0 @@ -135,7 +135,7 @@ end N = length(T.parameters) expr = Expr(:block) - push!(expr.args, :((y, logjac) = forward(cb.bs[1], x))) + push!(expr.args, :((y, logjac) = with_logabsdet_jacobian(cb.bs[1], x))) for i = 2:N - 1 temp = gensym(:res_logjac) @@ -151,8 +151,8 @@ end end -function forward(cb::NamedComposition, x) - rv, logjac = forward(cb.bs[1], x) +function ChangesOfVariables.with_logabsdet_jacobian(cb::NamedComposition, x) + rv, logjac = with_logabsdet_jacobian(cb.bs[1], x) for t in cb.bs[2:end] rv, res_logjac = with_logabsdet_jacobian(t, rv) @@ -162,9 +162,9 @@ function forward(cb::NamedComposition, x) end -@generated function forward(cb::NamedComposition{T}, x) where {T<:Tuple} +@generated function ChangesOfVariables.with_logabsdet_jacobian(cb::NamedComposition{T}, x) where {T<:Tuple} expr = Expr(:block) - push!(expr.args, :((y, logjac) = forward(cb.bs[1], x))) + push!(expr.args, :((y, logjac) = with_logabsdet_jacobian(cb.bs[1], x))) for i = 2:length(T.parameters) temp = gensym(:res_logjac) push!(expr.args, :(y, $temp = with_logabsdet_jacobian(cb.bs[$i], y))) diff --git a/src/bijectors/normalise.jl b/src/bijectors/normalise.jl index fcbaff91..1bd7851b 100644 --- a/src/bijectors/normalise.jl +++ b/src/bijectors/normalise.jl @@ -48,7 +48,7 @@ function Functors.functor(::Type{<:InvertibleBatchNorm}, x) return (b = x.b, logs = x.logs), reconstruct_invertiblebatchnorm end -function with_logabsdet_jacobian(bn::InvertibleBatchNorm, x) +function ChangesOfVariables.with_logabsdet_jacobian(bn::InvertibleBatchNorm, x) dims = ndims(x) size(x, dims - 1) == length(bn.b) || error("InvertibleBatchNorm expected $(length(bn.b)) channels, got $(size(x, dims - 1))") @@ -83,8 +83,8 @@ logabsdetjac(bn::InvertibleBatchNorm, x) = last(with_logabsdet_jacobian(bn, x)) (bn::InvertibleBatchNorm)(x) = first(with_logabsdet_jacobian(bn, x)) -function forward(invbn::Inverse{<:InvertibleBatchNorm}, y) - @assert !istraining() "`forward(::Inverse{InvertibleBatchNorm})` is only available in test mode." +function ChangesOfVariables.with_logabsdet_jacobian(invbn::Inverse{<:InvertibleBatchNorm}, y) + @assert !istraining() "`with_logabsdet_jacobian(::Inverse{InvertibleBatchNorm})` is only available in test mode." dims = ndims(y) as = ntuple(i -> i == ndims(y) - 1 ? size(y, i) : 1, dims) bn = inverse(invbn) diff --git a/src/bijectors/permute.jl b/src/bijectors/permute.jl index 5ba9e5cf..b8a7c30f 100644 --- a/src/bijectors/permute.jl +++ b/src/bijectors/permute.jl @@ -151,7 +151,7 @@ end @inline (b::Permute)(x::AbstractVecOrMat) = b.A * x -@inline inverse(b::Permute) = Permute(transpose(b.A)) +@inline InverseFunctions.inverse(b::Permute) = Permute(transpose(b.A)) logabsdetjac(b::Permute, x::AbstractVector) = zero(eltype(x)) logabsdetjac(b::Permute, x::AbstractMatrix) = zero(eltype(x), size(x, 2)) diff --git a/src/bijectors/planar_layer.jl b/src/bijectors/planar_layer.jl index 1f6c022e..8b24fb0f 100644 --- a/src/bijectors/planar_layer.jl +++ b/src/bijectors/planar_layer.jl @@ -94,7 +94,7 @@ we get \\log |det ∂f(z)/∂z| = \\log(1 + sech²(wᵀz + b) wᵀû). 
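A quick numerical sanity check of the log-determinant expression derived in the comment above, comparing the closed form implemented below against brute-force automatic differentiation (a sketch; the layer parameters are randomly initialised):

```julia
using Bijectors, ForwardDiff, LinearAlgebra
using ChangesOfVariables: with_logabsdet_jacobian

flow = Bijectors.PlanarLayer(2)   # random w, u, b for a 2-dimensional input
z = randn(2)
_, ladj = with_logabsdet_jacobian(flow, z)
first(logabsdet(ForwardDiff.jacobian(flow, z))) ≈ sum(ladj)
```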
``` =# -function forward(flow::PlanarLayer, z::AbstractVecOrMat{<:Real}) +function ChangesOfVariables.with_logabsdet_jacobian(flow::PlanarLayer, z::AbstractVecOrMat{<:Real}) transformed, wT_û, wT_z = _transform(flow, z) # Compute ``\\log |det ∂f(z)/∂z|`` (see above). diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index f2d656e7..809c8fd6 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -49,7 +49,7 @@ end (b::RadialLayer)(z::AbstractMatrix{<:Real}) = _transform(b, z).transformed (b::RadialLayer)(z::AbstractVector{<:Real}) = vec(_transform(b, z).transformed) -function forward(flow::RadialLayer, z::AbstractVecOrMat) +function ChangesOfVariables.with_logabsdet_jacobian(flow::RadialLayer, z::AbstractVecOrMat) transformed, α, β_hat, r = _transform(flow, z) # Compute log_det_jacobian d = size(flow.z_0, 1) diff --git a/src/bijectors/rational_quadratic_spline.jl b/src/bijectors/rational_quadratic_spline.jl index ef34c436..f9d3e59b 100644 --- a/src/bijectors/rational_quadratic_spline.jl +++ b/src/bijectors/rational_quadratic_spline.jl @@ -379,6 +379,6 @@ function rqs_forward( return (y, logjac) end -function with_logabsdet_jacobian(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real) +function ChangesOfVariables.with_logabsdet_jacobian(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real) return rqs_forward(b.widths, b.heights, b.derivatives, x) end diff --git a/src/bijectors/shift.jl b/src/bijectors/shift.jl index e4e9960c..ad7816f2 100644 --- a/src/bijectors/shift.jl +++ b/src/bijectors/shift.jl @@ -24,7 +24,7 @@ up1(b::Shift{T, N}) where {T, N} = Shift{T, N + 1}(b.a) (b::Shift)(x) = b.a .+ x (b::Shift{<:Any, 2})(x::AbstractArray{<:AbstractMatrix}) = map(b, x) -inverse(b::Shift{T, N}) where {T, N} = Shift{T, N}(-b.a) +InverseFunctions.inverse(b::Shift{T, N}) where {T, N} = Shift{T, N}(-b.a) # FIXME: implement custom adjoint to ensure we don't get tracking logabsdetjac(b::Shift{T, N}, x) where {T, N} = _logabsdetjac_shift(b.a, x, Val(N)) diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 9a49c34d..31a52d77 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -51,10 +51,10 @@ isclosedform(b::Stacked) = all(isclosedform, b.bs) stack(bs::Bijector{0}...) = Stacked(bs) # For some reason `inverse.(sb.bs)` was unstable... This works though. 
-inverse(sb::Stacked) = Stacked(map(inverse, sb.bs), sb.ranges) +InverseFunctions.inverse(sb::Stacked) = Stacked(map(inverse, sb.bs), sb.ranges) # map is not type stable for many stacked bijectors as a large tuple # hence the generated function -@generated function inverse(sb::Stacked{A}) where {A <: Tuple} +@generated function InverseFunctions.inverse(sb::Stacked{A}) where {A <: Tuple} exprs = [] for i = 1:length(A.parameters) push!(exprs, :(inverse(sb.bs[$i]))) @@ -132,23 +132,23 @@ end # Generates something similar to: # # quote -# (y_1, _logjac) = forward(b.bs[1], x[b.ranges[1]]) +# (y_1, _logjac) = with_logabsdet_jacobian(b.bs[1], x[b.ranges[1]]) # logjac = sum(_logjac) -# (y_2, _logjac) = forward(b.bs[2], x[b.ranges[2]]) +# (y_2, _logjac) = with_logabsdet_jacobian(b.bs[2], x[b.ranges[2]]) # logjac += sum(_logjac) # return (vcat(y_1, y_2), logjac) # end -@generated function forward(b::Stacked{<:Tuple{Vararg{<:Any, N}}, <:Tuple{Vararg{<:Any, N}}}, x::AbstractVector) where {N} +@generated function ChangesOfVariables.with_logabsdet_jacobian(b::Stacked{<:Tuple{Vararg{<:Any, N}}, <:Tuple{Vararg{<:Any, N}}}, x::AbstractVector) where {N} expr = Expr(:block) y_names = [] - push!(expr.args, :((y_1, _logjac) = forward(b.bs[1], x[b.ranges[1]]))) + push!(expr.args, :((y_1, _logjac) = with_logabsdet_jacobian(b.bs[1], x[b.ranges[1]]))) # TODO: drop the `sum` when we have dimensionality push!(expr.args, :(logjac = sum(_logjac))) push!(y_names, :y_1) for i = 2:N y_name = Symbol("y_$i") - push!(expr.args, :(($y_name, _logjac) = forward(b.bs[$i], x[b.ranges[$i]]))) + push!(expr.args, :(($y_name, _logjac) = with_logabsdet_jacobian(b.bs[$i], x[b.ranges[$i]]))) # TODO: drop the `sum` when we have dimensionality push!(expr.args, :(logjac += sum(_logjac))) @@ -160,12 +160,12 @@ end return expr end -function forward(sb::Stacked, x::AbstractVector) +function ChangesOfVariables.with_logabsdet_jacobian(sb::Stacked, x::AbstractVector) N = length(sb.bs) - yinit, linit = forward(sb.bs[1], x[sb.ranges[1]]) + yinit, linit = with_logabsdet_jacobian(sb.bs[1], x[sb.ranges[1]]) logjac = sum(linit) ys = mapvcat(drop(sb.bs, 1), drop(sb.ranges, 1)) do b, r - y, l = forward(b, x[r]) + y, l = with_logabsdet_jacobian(b, x[r]) logjac += sum(l) y end diff --git a/src/interface.jl b/src/interface.jl index 276ce631..80c2a140 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -72,8 +72,8 @@ Functors.@functor Inverse up1(b::Inverse) = Inverse(up1(b.orig)) -inverse(b::Bijector) = Inverse(b) -inverse(ib::Inverse{<:Bijector}) = ib.orig +InverseFunctions.inverse(b::Bijector) = Inverse(b) +InverseFunctions.inverse(ib::Inverse{<:Bijector}) = ib.orig Base.:(==)(b1::Inverse{<:Bijector}, b2::Inverse{<:Bijector}) = b1.orig == b2.orig """ @@ -99,7 +99,7 @@ in the computation of the forward pass and the computation of the `logabsdetjac`. `forward` allows the user to take advantange of such efficiencies, if they exist. 
""" -with_logabsdet_jacobian(b::Bijector, x) = (b(x), logabsdetjac(b, x)) +ChangesOfVariables.with_logabsdet_jacobian(b::Bijector, x) = (b(x), logabsdetjac(b, x)) """ logabsdetjacinv(b::Bijector, y) @@ -114,7 +114,7 @@ logabsdetjacinv(b::Bijector, y) = logabsdetjac(inverse(b), y) struct Identity{N} <: Bijector{N} end (::Identity)(x) = copy(x) -inverse(b::Identity) = b +InverseFunctions.inverse(b::Identity) = b up1(::Identity{N}) where {N} = Identity{N + 1}() logabsdetjac(::Identity{0}, x::Real) = zero(eltype(x)) diff --git a/test/interface.jl b/test/interface.jl index 36a039fc..77e66c06 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -187,11 +187,11 @@ end xs_ = @inferred ib(ys) @inferred(ib(param(ys))) - result = @inferred forward(b, x) - results = @inferred forward(b, xs) + result = @inferred with_logabsdet_jacobian(b, x) + results = @inferred with_logabsdet_jacobian(b, xs) - iresult = @inferred forward(ib, y) - iresults = @inferred forward(ib, ys) + iresult = @inferred with_logabsdet_jacobian(ib, y) + iresults = @inferred with_logabsdet_jacobian(ib, ys) # Sizes @test size(y) == size(x) @@ -748,7 +748,7 @@ end x = [.5, 1.] @test sb(x) == x @test logabsdetjac(sb, x) == 0 - @test forward(sb, x) == (x, zero(eltype(x))) + @test with_logabsdet_jacobian(sb, x) == (x, zero(eltype(x))) end end From 2b935603f189077d8294c127fc5b8d366b23994a Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 11 Dec 2021 20:30:28 +0100 Subject: [PATCH 09/29] Fix with_logabsdet_jacobian for NamedComposition --- src/bijectors/named_bijector.jl | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index 4ab972d5..7eb71c14 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -164,13 +164,17 @@ end @generated function ChangesOfVariables.with_logabsdet_jacobian(cb::NamedComposition{T}, x) where {T<:Tuple} expr = Expr(:block) - push!(expr.args, :((y, logjac) = with_logabsdet_jacobian(cb.bs[1], x))) + + sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) + push!(expr.args, :(($sym_y, $sym_ladj) = with_logabsdet_jacobian(cb.bs[1], x))) + sym_last_y, sym_last_ladj = sym_y, sym_ladj for i = 2:length(T.parameters) - temp = gensym(:res_logjac) - push!(expr.args, :(y, $temp = with_logabsdet_jacobian(cb.bs[$i], y))) - push!(expr.args, :(logjac += $temp)) + sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) + push!(expr.args, :(($sym_y, $sym_tmp_ladj) = with_logabsdet_jacobian(cb.bs[$i], $sym_last_y))) + push!(expr.args, :($sym_ladj = $sym_tmp_ladj + $sym_last_ladj)) + sym_last_y, sym_last_ladj = sym_y, sym_ladj end - push!(expr.args, :(return (y, logjac))) + push!(expr.args, :(return ($sym_y, $sym_ladj))) return expr end From 20e50d45aa6343ab7fd6a949ade4732fb73ee72c Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 11 Dec 2021 21:14:27 +0100 Subject: [PATCH 10/29] Fix deprecation of inv --- src/Bijectors.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index b3a5f710..38ad4dbd 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -255,7 +255,7 @@ include("chainrules.jl") Base.@deprecate forward(b::AbstractBijector, x) NamedTuple{(:rv,:logabsdetjac)}(with_logabsdet_jacobian(b, x)) @noinline function Base.inv(b::AbstractBijector) - Base.depwarn("`Base.inv(b::AbstractBijector)` is deprecated, use `InverseFunctions.inverse(b)` instead.", :(Base.inv)) + 
Base.depwarn("`Base.inv(b::AbstractBijector)` is deprecated, use `InverseFunctions.inverse(b)` instead.", :inv) inverse(b) end From fc990bb416f6c3f1f3d395dedf84509c0f396c8d Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 11 Dec 2021 21:15:04 +0100 Subject: [PATCH 11/29] Use inverse instead of inv for Composed --- src/bijectors/composed.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index 29360690..02dd5b6f 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -153,7 +153,7 @@ end ∘(::Identity{N}, b::Bijector{N}) where {N} = b ∘(b::Bijector{N}, ::Identity{N}) where {N} = b -InverseFunctions.inverse(ct::Composed) = Composed(reverse(map(inv, ct.ts))) +InverseFunctions.inverse(ct::Composed) = Composed(reverse(map(inverse, ct.ts))) # # TODO: should arrays also be using recursive implementation instead? function (cb::Composed{<:AbstractArray{<:Bijector}})(x) From 625fbb7e952bd7697e3a256f32a23e45b1a41fab Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 11 Dec 2021 21:17:58 +0100 Subject: [PATCH 12/29] Use with_logabsdet_jacobian instead of forward --- src/transformed_distribution.jl | 2 +- test/bijectors/coupling.jl | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index 10b9ff9d..a22c6400 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -232,7 +232,7 @@ function _forward(d::Distribution, x) end function _forward(td::Transformed, x) - y, logjac = forward(td.transform, x) + y, logjac = with_logabsdet_jacobian(td.transform, x) return ( x = x, y = y, diff --git a/test/bijectors/coupling.jl b/test/bijectors/coupling.jl index 298eab47..38ebd763 100644 --- a/test/bijectors/coupling.jl +++ b/test/bijectors/coupling.jl @@ -44,9 +44,9 @@ using Bijectors: # logabsdetjac @test logabsdetjac(cl1, x) == logabsdetjac(b, x[1:1]) - # forward - @test forward(cl1, x) == (cl1(x), logabsdetjac(cl1, x)) - @test forward(icl1, cl1(x)) == (x, - logabsdetjac(cl1, x)) + # with_logabsdet_jacobian + @test with_logabsdet_jacobian(cl1, x) == (cl1(x), logabsdetjac(cl1, x)) + @test with_logabsdet_jacobian(icl1, cl1(x)) == (x, - logabsdetjac(cl1, x)) end @testset "Classic" begin From 8273628bd5ee55ab2ff9cf8b527e32716f185f30 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 11 Dec 2021 22:56:01 +0100 Subject: [PATCH 13/29] Workaround for intermittent failures in Dirichlet test --- test/transform.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/test/transform.jl b/test/transform.jl index 6d1fddeb..e119c1b2 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -145,7 +145,15 @@ let ϵ = eps(Float64) single_sample_tests(dist) # This should fail at the minute. Not sure what the correct way to test this is. 
- x = rand(dist) + + # Workaround for intermittent test failures, result of `logpdf_with_trans(dist, x, true)` + # is incorrect for `x == [0.9999999999999998, 0.0]`: + x = if params(dist) == params(Dirichlet([1000 * one(Float64), eps(Float64)])) + [1.0, 0.0] + else + rand(dist) + end + logpdf_turing = logpdf_with_trans(dist, x, true) J = jacobian(x->link(dist, x, Val(false)), x) @test logpdf(dist, x .+ ϵ) - _logabsdet(J) ≈ logpdf_turing From 167a26e3158282d2aab4a95e5dd697d474e4f38c Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 11 Dec 2021 23:24:27 +0100 Subject: [PATCH 14/29] Use with_logabsdet_jacobian instead of forward --- test/bijectors/permute.jl | 10 +++++----- test/bijectors/utils.jl | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/test/bijectors/permute.jl b/test/bijectors/permute.jl index 6602bb03..49e90a82 100644 --- a/test/bijectors/permute.jl +++ b/test/bijectors/permute.jl @@ -50,16 +50,16 @@ using Bijectors: Permute @test logabsdetjac(b3, x) == 0.0 @test logabsdetjac(b4, x) == 0.0 - # forward - y, logjac = forward(b1, x) + # with_logabsdet_jacobian + y, logjac = with_logabsdet_jacobian(b1, x) @test (y == b1(x)) & (logjac == 0.0) - y, logjac = forward(b2, x) + y, logjac = with_logabsdet_jacobian(b2, x) @test (y == b2(x)) & (logjac == 0.0) - y, logjac = forward(b3, x) + y, logjac = with_logabsdet_jacobian(b3, x) @test (y == b3(x)) & (logjac == 0.0) - y, logjac = forward(b4, x) + y, logjac = with_logabsdet_jacobian(b4, x) @test (y == b4(x)) & (logjac == 0.0) end diff --git a/test/bijectors/utils.jl b/test/bijectors/utils.jl index ce2d03c3..85b7a1d9 100644 --- a/test/bijectors/utils.jl +++ b/test/bijectors/utils.jl @@ -10,11 +10,11 @@ function test_bijector_reals( y = @inferred b(x_true) logjac = @inferred logabsdetjac(b, x_true) ilogjac = @inferred logabsdetjac(ib, y_true) - res = @inferred forward(b, x_true) + res = @inferred with_logabsdet_jacobian(b, x_true) # If `isequal` is false, then we use the computed `y`, # but if it's true, we use the true `y`. - ires = isequal ? @inferred(forward(inverse(b), y_true)) : @inferred(forward(inverse(b), y)) + ires = isequal ? @inferred(with_logabsdet_jacobian(inverse(b), y_true)) : @inferred(with_logabsdet_jacobian(inverse(b), y)) # Always want the following to hold @test ires[1] ≈ x_true atol=tol @@ -46,10 +46,10 @@ function test_bijector_arrays( ib = @inferred inverse(b) ys = @inferred b(xs_true) logjacs = @inferred logabsdetjac(b, xs_true) - res = @inferred forward(b, xs_true) + res = @inferred with_logabsdet_jacobian(b, xs_true) # If `isequal` is false, then we use the computed `y`, # but if it's true, we use the true `y`. - ires = isequal ? @inferred(forward(inverse(b), ys_true)) : @inferred(forward(inverse(b), ys)) + ires = isequal ? 
@inferred(with_logabsdet_jacobian(inverse(b), ys_true)) : @inferred(with_logabsdet_jacobian(inverse(b), ys)) # always want the following to hold @test ys isa typeof(ys_true) From 8a0c6581558f35c68963dd2159426fe74e35106b Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 11 Dec 2021 23:37:53 +0100 Subject: [PATCH 15/29] Use with_logabsdet_jacobian instead of forward --- src/transformed_distribution.jl | 4 ++-- test/bijectors/leaky_relu.jl | 8 ++++---- test/interface.jl | 28 ++++++++++++++-------------- test/norm_flows.jl | 10 +++++----- 4 files changed, 25 insertions(+), 25 deletions(-) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index a22c6400..f51013c3 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -218,7 +218,7 @@ end const GLOBAL_RNG = Distributions.GLOBAL_RNG function _forward(d::UnivariateDistribution, x) - y, logjac = forward(Identity{0}(), x) + y, logjac = with_logabsdet_jacobian(Identity{0}(), x) return (x = x, y = y, logabsdetjac = logjac, logpdf = logpdf.(d, x)) end @@ -227,7 +227,7 @@ function forward(rng::AbstractRNG, d::Distribution, num_samples::Int) return _forward(d, rand(rng, d, num_samples)) end function _forward(d::Distribution, x) - y, logjac = forward(Identity{length(size(d))}(), x) + y, logjac = with_logabsdet_jacobian(Identity{length(size(d))}(), x) return (x = x, y = y, logabsdetjac = logjac, logpdf = logpdf(d, x)) end diff --git a/test/bijectors/leaky_relu.jl b/test/bijectors/leaky_relu.jl index d06a112a..e110a046 100644 --- a/test/bijectors/leaky_relu.jl +++ b/test/bijectors/leaky_relu.jl @@ -40,11 +40,11 @@ true_logabsdetjac(b::Bijector{1}, xs::AbstractMatrix) = mapreduce(z -> true_loga @test logabsdetjac(b, -Float32.(xs)) == true_logabsdetjac(b, -Float32.(xs)) # Forward - f = forward(b, xs) + f = with_logabsdet_jacobian(b, xs) @test f[2] ≈ logabsdetjac(b, xs) @test f[1] ≈ b(xs) - f = forward(b, Float32.(xs)) + f = with_logabsdet_jacobian(b, Float32.(xs)) @test f[2] == logabsdetjac(b, Float32.(xs)) @test f[1] ≈ b(Float32.(xs)) end @@ -66,11 +66,11 @@ end @test logabsdetjac(b, -Float32.(xs)) == true_logabsdetjac(b, -Float32.(xs)) # Forward - f = forward(b, xs) + f = with_logabsdet_jacobian(b, xs) @test f[2] ≈ logabsdetjac(b, xs) @test f[1] ≈ b(xs) - f = forward(b, Float32.(xs)) + f = with_logabsdet_jacobian(b, Float32.(xs)) @test f[2] == logabsdetjac(b, Float32.(xs)) @test f[1] ≈ b(Float32.(xs)) diff --git a/test/interface.jl b/test/interface.jl index 77e66c06..17733a36 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -510,10 +510,10 @@ end b = @inferred Bijectors.composel(td.transform, Bijectors.Identity{0}()) ib = @inferred inverse(b) - @test forward(b, x) == forward(td.transform, x) - @test forward(ib, y) == forward(inverse(td.transform), y) + @test with_logabsdet_jacobian(b, x) == with_logabsdet_jacobian(td.transform, x) + @test with_logabsdet_jacobian(ib, y) == with_logabsdet_jacobian(inverse(td.transform), y) - @test forward(b, x) == forward(Bijectors.composer(b.ts...), x) + @test with_logabsdet_jacobian(b, x) == with_logabsdet_jacobian(Bijectors.composer(b.ts...), x) # inverse works fine for composition cb = @inferred b ∘ ib @@ -547,10 +547,10 @@ end x = rand(d) cb_t = b⁻¹ ∘ b⁻¹ - f_t = forward(cb_t, x) + f_t = with_logabsdet_jacobian(cb_t, x) cb_a = Composed([b⁻¹, b⁻¹]) - f_a = forward(cb_a, x) + f_a = with_logabsdet_jacobian(cb_a, x) @test f_t == f_a @@ -572,7 +572,7 @@ end y = b(x) sb1 = @inferred stack(b, b, inverse(b), inverse(b)) # <= Tuple - res1 = forward(sb1, 
[x, x, y, y]) + res1 = with_logabsdet_jacobian(sb1, [x, x, y, y]) @test sb1(param([x, x, y, y])) isa TrackedArray @test sb1([x, x, y, y]) ≈ res1[1] @@ -580,7 +580,7 @@ end @test res1[2] ≈ 0 atol=1e-6 sb2 = Stacked([b, b, inverse(b), inverse(b)]) # <= Array - res2 = forward(sb2, [x, x, y, y]) + res2 = with_logabsdet_jacobian(sb2, [x, x, y, y]) @test sb2(param([x, x, y, y])) isa TrackedArray @test sb2([x, x, y, y]) ≈ res2[1] @@ -592,7 +592,7 @@ end y = b(x) sb1 = stack(b, b, inverse(b), inverse(b)) # <= Tuple - res1 = forward(sb1, [x, x, y, y]) + res1 = with_logabsdet_jacobian(sb1, [x, x, y, y]) @test sb1(param([x, x, y, y])) isa TrackedArray @test sb1([x, x, y, y]) == res1[1] @@ -600,7 +600,7 @@ end @test res1[2] ≈ 0.0 atol=1e-12 sb2 = Stacked([b, b, inverse(b), inverse(b)]) # <= Array - res2 = forward(sb2, [x, x, y, y]) + res2 = with_logabsdet_jacobian(sb2, [x, x, y, y]) @test sb2(param([x, x, y, y])) isa TrackedArray @test sb2([x, x, y, y]) == res2[1] @@ -610,7 +610,7 @@ end # value-test x = ones(3) sb = @inferred stack(Bijectors.Exp(), Bijectors.Log(), Bijectors.Shift(5.0)) - res = forward(sb, x) + res = with_logabsdet_jacobian(sb, x) @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), log(x[2]), x[3] + 5.0] @test res[1] == [exp(x[1]), log(x[2]), x[3] + 5.0] @@ -621,7 +621,7 @@ end # TODO: change when we have dimensionality in the type sb = @inferred Stacked((Bijectors.Exp(), Bijectors.SimplexBijector()), (1:1, 2:3)) x = ones(3) ./ 3.0 - res = @inferred forward(sb, x) + res = @inferred with_logabsdet_jacobian(sb, x) @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] @@ -634,7 +634,7 @@ end # Array-version sb = Stacked([Bijectors.Exp(), Bijectors.SimplexBijector()], [1:1, 2:3]) x = ones(3) ./ 3.0 - res = forward(sb, x) + res = with_logabsdet_jacobian(sb, x) @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] @@ -648,7 +648,7 @@ end # Tuple, Array sb = Stacked([Bijectors.Exp(), Bijectors.SimplexBijector()], (1:1, 2:3)) x = ones(3) ./ 3.0 - res = forward(sb, x) + res = with_logabsdet_jacobian(sb, x) @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] @@ -661,7 +661,7 @@ end # Array, Tuple sb = Stacked((Bijectors.Exp(), Bijectors.SimplexBijector()), [1:1, 2:3]) x = ones(3) ./ 3.0 - res = forward(sb, x) + res = with_logabsdet_jacobian(sb, x) @test sb(param(x)) isa TrackedArray @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] 
diff --git a/test/norm_flows.jl b/test/norm_flows.jl index 38ea7e5f..fdf79676 100644 --- a/test/norm_flows.jl +++ b/test/norm_flows.jl @@ -11,10 +11,10 @@ seed!(1) @test inverse(inverse(bn)) == bn @test inverse(bn)(bn(x)) ≈ x @test (inverse(bn) ∘ bn)(x) ≈ x - @test_throws ErrorException forward(bn, randn(10,2)) + @test_throws ErrorException with_logabsdet_jacobian(bn, randn(10,2)) @test logabsdetjac(inverse(bn), bn(x)) ≈ - logabsdetjac(bn, x) - y, ladj = forward(bn, x) + y, ladj = with_logabsdet_jacobian(bn, x) @test log(abs(det(ForwardDiff.jacobian(bn, x)))) ≈ sum(ladj) @test log(abs(det(ForwardDiff.jacobian(inverse(bn), y)))) ≈ sum(logabsdetjac(inverse(bn), y)) @@ -26,7 +26,7 @@ end flow = PlanarLayer(2) z = randn(2, 20) forward_diff = log(abs(det(ForwardDiff.jacobian(t -> flow(t), z)))) - our_method = sum(forward(flow, z)[2]) + our_method = sum(with_logabsdet_jacobian(flow, z)[2]) @test our_method ≈ forward_diff @test inverse(flow)(flow(z)) ≈ z @@ -74,7 +74,7 @@ end flow = RadialLayer(2) z = randn(2, 20) forward_diff = log(abs(det(ForwardDiff.jacobian(t -> flow(t), z)))) - our_method = sum(forward(flow, z)[2]) + our_method = sum(with_logabsdet_jacobian(flow, z)[2]) @test our_method ≈ forward_diff @test inverse(flow)(flow(z)) ≈ z rtol=0.2 @@ -102,7 +102,7 @@ end x = rand(d) y = flow.transform(x) - res = forward(flow.transform, x) + res = with_logabsdet_jacobian(flow.transform, x) lp = logpdf_forward(flow, x, res[2]) @test res[1] ≈ y From 34f7fc697259f7d967bc70bc1a0c5c2e082d5b81 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 12 Dec 2021 15:49:36 +0100 Subject: [PATCH 16/29] Add rrules for combine with PartitionMask Zygote-generated pullback for `combine(m::PartitionMask, x_1, x_2, x_3)` fails with `no method matching zero(::Type{Nothing})`. --- src/bijectors/coupling.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/bijectors/coupling.jl b/src/bijectors/coupling.jl index 03b00ba5..8ed89639 100644 --- a/src/bijectors/coupling.jl +++ b/src/bijectors/coupling.jl @@ -118,6 +118,20 @@ Combines `x_1`, `x_2`, and `x_3` into a single vector. """ @inline combine(m::PartitionMask, x_1, x_2, x_3) = m.A_1 * x_1 .+ m.A_2 * x_2 .+ m.A_3 * x_3 +function ChainRulesCore.rrule(::typeof(combine), m::PartitionMask, x_1, x_2, x_3) + prj = map(ChainRulesCore.ProjectTo, (x_1, x_2, x_3)) + + function _transform_ordered_adjoint(ΔΩ) + Δ = ChainRulesCore.unthunk(ΔΩ) + dx_1, dx_2, dx_3 = partition(m, Δ) + + return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), prj[1](dx_1), prj[2](dx_2), prj[3](dx_3) + end + + return combine(m, x_1, x_2, x_3), _transform_ordered_adjoint +end + + """ partition(m::PartitionMask, x) @@ -195,6 +209,7 @@ function couple(cl::Coupling, x::AbstractVector) return b end +#!!!!!!!!!!!!! 
function (cl::Coupling)(x::AbstractVector) # partition vector using `cl.mask::PartitionMask` x_1, x_2, x_3 = partition(cl.mask, x) From 2f1c36dfbe88d9a6ac906e2060e8a43364845013 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 12 Dec 2021 16:37:15 +0100 Subject: [PATCH 17/29] Use inv instead of inverse for numbers --- src/compat/reversediff.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl index 5c11a4db..116d8531 100644 --- a/src/compat/reversediff.jl +++ b/src/compat/reversediff.jl @@ -60,7 +60,7 @@ function _logabsdetjac_scale(a::TrackedReal, x::Real, ::Val{0}) return track(_logabsdetjac_scale, a, value(x), Val(0)) end @grad function _logabsdetjac_scale(a::Real, x::Real, v::Val{0}) - return _logabsdetjac_scale(value(a), value(x), Val(0)), Δ -> (inverse(value(a)) .* Δ, nothing, nothing) + return _logabsdetjac_scale(value(a), value(x), Val(0)), Δ -> (inv(value(a)) .* Δ, nothing, nothing) end # Need to treat `AbstractVector` and `AbstractMatrix` separately due to ambiguity errors function _logabsdetjac_scale(a::TrackedReal, x::AbstractVector, ::Val{0}) @@ -68,7 +68,7 @@ function _logabsdetjac_scale(a::TrackedReal, x::AbstractVector, ::Val{0}) end @grad function _logabsdetjac_scale(a::Real, x::AbstractVector, v::Val{0}) da = value(a) - J = fill(inverse.(da), length(x)) + J = fill(inv.(da), length(x)) return _logabsdetjac_scale(da, value(x), Val(0)), Δ -> (transpose(J) * Δ, nothing, nothing) end function _logabsdetjac_scale(a::TrackedReal, x::AbstractMatrix, ::Val{0}) From 012d90a8686a40a8678530011cc7afc7c5352f9f Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 12 Dec 2021 19:26:06 +0100 Subject: [PATCH 18/29] Apply suggestions from code review Co-authored-by: David Widmann --- src/bijectors/coupling.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/bijectors/coupling.jl b/src/bijectors/coupling.jl index 8ed89639..50c7abd8 100644 --- a/src/bijectors/coupling.jl +++ b/src/bijectors/coupling.jl @@ -119,16 +119,17 @@ Combines `x_1`, `x_2`, and `x_3` into a single vector. @inline combine(m::PartitionMask, x_1, x_2, x_3) = m.A_1 * x_1 .+ m.A_2 * x_2 .+ m.A_3 * x_3 function ChainRulesCore.rrule(::typeof(combine), m::PartitionMask, x_1, x_2, x_3) - prj = map(ChainRulesCore.ProjectTo, (x_1, x_2, x_3)) + proj_x_1 = ChainRulesCore.ProjectTo(x_1) + proj_x_2 = ChainRulesCore.ProjectTo(x_2) + proj_x_3 = ChainRulesCore.ProjectTo(x_3) - function _transform_ordered_adjoint(ΔΩ) + function combine_pullback(ΔΩ) Δ = ChainRulesCore.unthunk(ΔΩ) dx_1, dx_2, dx_3 = partition(m, Δ) - - return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), prj[1](dx_1), prj[2](dx_2), prj[3](dx_3) + return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), proj_x_1(dx_1), proj_x_2(dx_2), proj_x_3(dx_3) end - return combine(m, x_1, x_2, x_3), _transform_ordered_adjoint + return combine(m, x_1, x_2, x_3), combine_pullback end @@ -209,7 +210,6 @@ function couple(cl::Coupling, x::AbstractVector) return b end -#!!!!!!!!!!!!! function (cl::Coupling)(x::AbstractVector) # partition vector using `cl.mask::PartitionMask` x_1, x_2, x_3 = partition(cl.mask, x) From 3fc7a436559ac69707f5074a238f0799b81969e8 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 12 Dec 2021 19:26:30 +0100 Subject: [PATCH 19/29] Whitespace fix. 
Co-authored-by: David Widmann --- test/runtests.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 4f44833e..a4aaa465 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,7 +13,6 @@ using ReverseDiff using Tracker using Zygote - using Random, LinearAlgebra, Test using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, From 045412fe56ab128e2687401ffeff2596578de2d6 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 12 Dec 2021 19:46:59 +0100 Subject: [PATCH 20/29] Move combine rrule and add test --- src/bijectors/coupling.jl | 15 --------------- src/chainrules.jl | 14 ++++++++++++++ test/ad/chainrules.jl | 2 ++ 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/bijectors/coupling.jl b/src/bijectors/coupling.jl index 50c7abd8..03b00ba5 100644 --- a/src/bijectors/coupling.jl +++ b/src/bijectors/coupling.jl @@ -118,21 +118,6 @@ Combines `x_1`, `x_2`, and `x_3` into a single vector. """ @inline combine(m::PartitionMask, x_1, x_2, x_3) = m.A_1 * x_1 .+ m.A_2 * x_2 .+ m.A_3 * x_3 -function ChainRulesCore.rrule(::typeof(combine), m::PartitionMask, x_1, x_2, x_3) - proj_x_1 = ChainRulesCore.ProjectTo(x_1) - proj_x_2 = ChainRulesCore.ProjectTo(x_2) - proj_x_3 = ChainRulesCore.ProjectTo(x_3) - - function combine_pullback(ΔΩ) - Δ = ChainRulesCore.unthunk(ΔΩ) - dx_1, dx_2, dx_3 = partition(m, Δ) - return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), proj_x_1(dx_1), proj_x_2(dx_2), proj_x_3(dx_3) - end - - return combine(m, x_1, x_2, x_3), combine_pullback -end - - """ partition(m::PartitionMask, x) diff --git a/src/chainrules.jl b/src/chainrules.jl index 283fa96b..a7071649 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -7,6 +7,20 @@ ChainRulesCore.@scalar_rule( (x, - tanh(Ω + b) * x, x - 1), ) +function ChainRulesCore.rrule(::typeof(combine), m::PartitionMask, x_1, x_2, x_3) + proj_x_1 = ChainRulesCore.ProjectTo(x_1) + proj_x_2 = ChainRulesCore.ProjectTo(x_2) + proj_x_3 = ChainRulesCore.ProjectTo(x_3) + + function combine_pullback(ΔΩ) + Δ = ChainRulesCore.unthunk(ΔΩ) + dx_1, dx_2, dx_3 = partition(m, Δ) + return ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), proj_x_1(dx_1), proj_x_2(dx_2), proj_x_3(dx_3) + end + + return combine(m, x_1, x_2, x_3), combine_pullback +end + # `OrderedBijector` function ChainRulesCore.rrule(::typeof(_transform_ordered), y::AbstractVector) # ensures that we remain in the primal's subspace diff --git a/test/ad/chainrules.jl b/test/ad/chainrules.jl index b8c1f1be..b0e4dc2e 100644 --- a/test/ad/chainrules.jl +++ b/test/ad/chainrules.jl @@ -5,6 +5,8 @@ test_frule(Bijectors.find_alpha, x, y, z) test_rrule(Bijectors.find_alpha, x, y, z) + test_rrule(Bijectors.combine, Bijectors.PartitionMask(3, [1], [2]) ⊢ ChainRulesTestUtils.NoTangent(), [1.0], [2.0], [3.0]) + # ordered bijector b = Bijectors.OrderedBijector() test_rrule(Bijectors._transform_ordered, randn(5)) From 7e96b8d87cfe1c3e4182ba12d662ac9fe9b28bde Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 12 Dec 2021 22:00:06 +0100 Subject: [PATCH 21/29] Apply suggestions from code review Co-authored-by: David Widmann --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b23a60e7..14fba13a 100644 --- a/README.md +++ b/README.md @@ -562,8 +562,9 @@ julia> with_logabsdet_jacobian(b, 0.6) # defaults to `(b(x), logabsdetja For further efficiency, one could manually implement `with_logabsdet_jacobian(b::Logit, x)`: ```julia -julia> import 
Bijectors: forward, Logit -julia> import ChangesOfVariables: with_logabsdet_jacobian +julia> using Bijectors: Logit + +julia> import Bijectors: with_logabsdet_jacobian julia> function with_logabsdet_jacobian(b::Logit{<:Real}, x) totally_worth_saving = @. (x - b.a) / (b.b - b.a) # spoiler: it's probably not From 874b5ecc0fec12e602f68fd923df6c6a64cbaf35 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 12 Dec 2021 22:02:00 +0100 Subject: [PATCH 22/29] Use @test_deprecated Co-authored-by: David Widmann --- test/interface.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/interface.jl b/test/interface.jl index 17733a36..3b6a50ec 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -858,8 +858,6 @@ end b = Bijectors.Exp() x = 0.3 - @test let r = forward(b, x) - (r.rv, r.logabsdetjac) == with_logabsdet_jacobian(b, x) - end + @test @test_deprecated(forward(b, x)) == NamedTuple{(:rv, :logabsdetjac)}(with_logabsdet_jacobian(b, x)) @test inv(b) == inverse(b) end From d9c8562fb9c23892bfc49412375eea07f706fde1 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 12 Dec 2021 22:02:07 +0100 Subject: [PATCH 23/29] Use @test_deprecated Co-authored-by: David Widmann --- test/interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/interface.jl b/test/interface.jl index 3b6a50ec..c2dc2ddf 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -859,5 +859,5 @@ end x = 0.3 @test @test_deprecated(forward(b, x)) == NamedTuple{(:rv, :logabsdetjac)}(with_logabsdet_jacobian(b, x)) - @test inv(b) == inverse(b) + @test @test_deprecated(inv(b)) == inverse(b) end From 5cad1e49d7d2f95dbbbe13886b3a0433f6a66b7d Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 12 Dec 2021 22:31:22 +0100 Subject: [PATCH 24/29] Use inverse instead of inv --- test/interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/interface.jl b/test/interface.jl index c2dc2ddf..eb7f9341 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -702,7 +702,7 @@ end # Stacked{<:Array} bs = bijector.(dists) # constrained-to-unconstrained bijectors for dists - ibs = inv.(bs) # invert, so we get unconstrained-to-constrained + ibs = inverse.(bs) # invert, so we get unconstrained-to-constrained sb = Stacked(ibs, ranges) # => Stacked <: Bijector x = rand(d) From 6ed6fa9bba457cb4b876c435f817497144b69ddf Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 12 Dec 2021 22:57:41 +0100 Subject: [PATCH 25/29] Use test_inverse and test_with_logabsdet_jacobian --- test/interface.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/interface.jl b/test/interface.jl index eb7f9341..6a0132ca 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -854,6 +854,15 @@ end end end +@testset "test_inverse and test_with_logabsdet_jacobian" begin + b = Bijectors.Scale{Float64,0}(4.2) + x = 0.3 + + InverseFunctions.test_inverse(b, x) + ChangesOfVariables.test_with_logabsdet_jacobian(b, x, (f::Bijectors.Scale, x) -> f.a) +end + + @testset "deprecations" begin b = Bijectors.Exp() x = 0.3 From 4fadffc1687b0e5ae0d953b0dfac506c96385015 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 12 Dec 2021 23:18:43 +0100 Subject: [PATCH 26/29] Use inverse instead of inv --- src/bijectors/named_bijector.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index 7eb71c14..7e175394 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -107,7 +107,7 @@ 
composel(bs::AbstractNamedBijector...) = NamedComposition(bs) composer(bs::AbstractNamedBijector...) = NamedComposition(reverse(bs)) ∘(b1::AbstractNamedBijector, b2::AbstractNamedBijector) = composel(b2, b1) -InverseFunctions.inverse(ct::NamedComposition) = NamedComposition(reverse(map(inv, ct.bs))) +InverseFunctions.inverse(ct::NamedComposition) = NamedComposition(reverse(map(inverse, ct.bs))) function (cb::NamedComposition{<:AbstractArray{<:AbstractNamedBijector}})(x) @assert length(cb.bs) > 0 From 5f4d982e756d6fac20779f47ffa49456e8da7908 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 12 Dec 2021 22:32:15 +0100 Subject: [PATCH 27/29] Increase version number to v0.9.12 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1f296958..3afc20f0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.9.11" +version = "0.9.12" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" From fb54734a6004afb143e4fee3f9b855af47621ff5 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Mon, 13 Dec 2021 01:39:29 +0100 Subject: [PATCH 28/29] Reexport with_logabsdet_jacobian and inverse --- src/Bijectors.jl | 8 +++++--- src/bijectors/composed.jl | 6 +++--- src/bijectors/exp_log.jl | 4 ++-- src/bijectors/leaky_relu.jl | 8 ++++---- src/bijectors/named_bijector.jl | 14 +++++++------- src/bijectors/normalise.jl | 4 ++-- src/bijectors/permute.jl | 2 +- src/bijectors/planar_layer.jl | 2 +- src/bijectors/radial_layer.jl | 2 +- src/bijectors/rational_quadratic_spline.jl | 2 +- src/bijectors/shift.jl | 2 +- src/bijectors/stacked.jl | 8 ++++---- src/interface.jl | 8 ++++---- test/interface.jl | 4 ++-- test/runtests.jl | 5 +++-- 15 files changed, 41 insertions(+), 38 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 38ad4dbd..c6fa9a76 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -35,8 +35,8 @@ using MappedArrays using Base.Iterators: drop using LinearAlgebra: AbstractTriangular -using ChangesOfVariables: ChangesOfVariables, with_logabsdet_jacobian -using InverseFunctions: InverseFunctions, inverse +import ChangesOfVariables: with_logabsdet_jacobian +import InverseFunctions: inverse import ChainRulesCore import Functors @@ -54,6 +54,8 @@ export TransformDistribution, logpdf_with_trans, isclosedform, transform, + with_logabsdet_jacobian, + inverse, forward, logabsdetjac, logabsdetjacinv, @@ -255,7 +257,7 @@ include("chainrules.jl") Base.@deprecate forward(b::AbstractBijector, x) NamedTuple{(:rv,:logabsdetjac)}(with_logabsdet_jacobian(b, x)) @noinline function Base.inv(b::AbstractBijector) - Base.depwarn("`Base.inv(b::AbstractBijector)` is deprecated, use `InverseFunctions.inverse(b)` instead.", :inv) + Base.depwarn("`Base.inv(b::AbstractBijector)` is deprecated, use `inverse(b)` instead.", :inv) inverse(b) end diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index 02dd5b6f..3221f94b 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -153,7 +153,7 @@ end ∘(::Identity{N}, b::Bijector{N}) where {N} = b ∘(b::Bijector{N}, ::Identity{N}) where {N} = b -InverseFunctions.inverse(ct::Composed) = Composed(reverse(map(inverse, ct.ts))) +inverse(ct::Composed) = Composed(reverse(map(inverse, ct.ts))) # # TODO: should arrays also be using recursive implementation instead? 
function (cb::Composed{<:AbstractArray{<:Bijector}})(x) @@ -208,7 +208,7 @@ end end -function ChangesOfVariables.with_logabsdet_jacobian(cb::Composed, x) +function with_logabsdet_jacobian(cb::Composed, x) rv, logjac = with_logabsdet_jacobian(cb.ts[1], x) for t in cb.ts[2:end] @@ -218,7 +218,7 @@ function ChangesOfVariables.with_logabsdet_jacobian(cb::Composed, x) return (rv, logjac) end -@generated function ChangesOfVariables.with_logabsdet_jacobian(cb::Composed{T}, x) where {T<:Tuple} +@generated function with_logabsdet_jacobian(cb::Composed{T}, x) where {T<:Tuple} expr = Expr(:block) sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) push!(expr.args, :(($sym_y, $sym_ladj) = with_logabsdet_jacobian(cb.ts[1], x))) diff --git a/src/bijectors/exp_log.jl b/src/bijectors/exp_log.jl index aacaa901..0f5f4683 100644 --- a/src/bijectors/exp_log.jl +++ b/src/bijectors/exp_log.jl @@ -27,8 +27,8 @@ Log() = Log{0}() (b::Exp{2})(y::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, y) (b::Log{2})(x::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, x) -InverseFunctions.inverse(b::Exp{N}) where {N} = Log{N}() -InverseFunctions.inverse(b::Log{N}) where {N} = Exp{N}() +inverse(b::Exp{N}) where {N} = Log{N}() +inverse(b::Log{N}) where {N} = Exp{N}() logabsdetjac(b::Exp{0}, x::Real) = x logabsdetjac(b::Exp{0}, x::AbstractVector) = x diff --git a/src/bijectors/leaky_relu.jl b/src/bijectors/leaky_relu.jl index 8cf769a7..c91e0faf 100644 --- a/src/bijectors/leaky_relu.jl +++ b/src/bijectors/leaky_relu.jl @@ -31,7 +31,7 @@ function (b::LeakyReLU{<:Any, 0})(x::Real) end (b::LeakyReLU{<:Any, 0})(x::AbstractVector{<:Real}) = map(b, x) -function InverseFunctions.inverse(b::LeakyReLU{<:Any,N}) where N +function inverse(b::LeakyReLU{<:Any,N}) where N invα = inv.(b.α) return LeakyReLU{typeof(invα),N}(invα) end @@ -47,14 +47,14 @@ logabsdetjac(b::LeakyReLU{<:Real, 0}, x::AbstractVector{<:Real}) = map(x -> loga # We implement `with_logabsdet_jacobian` by hand since we can re-use the computation of # the Jacobian of the transformation. This will lead to faster sampling # when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. -function ChangesOfVariables.with_logabsdet_jacobian(b::LeakyReLU{<:Any, 0}, x::Real) +function with_logabsdet_jacobian(b::LeakyReLU{<:Any, 0}, x::Real) mask = x < zero(x) J = mask * b.α + !mask * one(x) return (J * x, log(abs(J))) end # Batched version -function ChangesOfVariables.with_logabsdet_jacobian(b::LeakyReLU{<:Any, 0}, x::AbstractVector) +function with_logabsdet_jacobian(b::LeakyReLU{<:Any, 0}, x::AbstractVector) J = let T = eltype(x), z = zero(T), o = one(T) @. (x < z) * b.α + (x > z) * o end @@ -84,7 +84,7 @@ end # We implement `forward` by hand since we can re-use the computation of # the Jacobian of the transformation. This will lead to faster sampling # when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. -function ChangesOfVariables.with_logabsdet_jacobian(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) +function with_logabsdet_jacobian(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) # Is really diagonal of jacobian J = let T = eltype(x), z = zero(T), o = one(T) @. 
(x < z) * b.α + (x > z) * o diff --git a/src/bijectors/named_bijector.jl b/src/bijectors/named_bijector.jl index 7e175394..d4fb0557 100644 --- a/src/bijectors/named_bijector.jl +++ b/src/bijectors/named_bijector.jl @@ -1,6 +1,6 @@ abstract type AbstractNamedBijector <: AbstractBijector end -ChangesOfVariables.with_logabsdet_jacobian(b::AbstractNamedBijector, x) = (b(x), logabsdetjac(b, x)) +with_logabsdet_jacobian(b::AbstractNamedBijector, x) = (b(x), logabsdetjac(b, x)) ####################### ### `NamedBijector` ### @@ -55,7 +55,7 @@ names_to_bijectors(b::NamedBijector) = b.bs return :($(exprs...), ) end -@generated function InverseFunctions.inverse(b::NamedBijector{names}) where {names} +@generated function inverse(b::NamedBijector{names}) where {names} return :(NamedBijector(($([:($n = inverse(b.bs.$n)) for n in names]...), ))) end @@ -78,8 +78,8 @@ See also: [`Inverse`](@ref) struct NamedInverse{B<:AbstractNamedBijector} <: AbstractNamedBijector orig::B end -InverseFunctions.inverse(nb::AbstractNamedBijector) = NamedInverse(nb) -InverseFunctions.inverse(ni::NamedInverse) = ni.orig +inverse(nb::AbstractNamedBijector) = NamedInverse(nb) +inverse(ni::NamedInverse) = ni.orig logabsdetjac(ni::NamedInverse, y::NamedTuple) = -logabsdetjac(inverse(ni), ni(y)) @@ -107,7 +107,7 @@ composel(bs::AbstractNamedBijector...) = NamedComposition(bs) composer(bs::AbstractNamedBijector...) = NamedComposition(reverse(bs)) ∘(b1::AbstractNamedBijector, b2::AbstractNamedBijector) = composel(b2, b1) -InverseFunctions.inverse(ct::NamedComposition) = NamedComposition(reverse(map(inverse, ct.bs))) +inverse(ct::NamedComposition) = NamedComposition(reverse(map(inverse, ct.bs))) function (cb::NamedComposition{<:AbstractArray{<:AbstractNamedBijector}})(x) @assert length(cb.bs) > 0 @@ -151,7 +151,7 @@ end end -function ChangesOfVariables.with_logabsdet_jacobian(cb::NamedComposition, x) +function with_logabsdet_jacobian(cb::NamedComposition, x) rv, logjac = with_logabsdet_jacobian(cb.bs[1], x) for t in cb.bs[2:end] @@ -162,7 +162,7 @@ function ChangesOfVariables.with_logabsdet_jacobian(cb::NamedComposition, x) end -@generated function ChangesOfVariables.with_logabsdet_jacobian(cb::NamedComposition{T}, x) where {T<:Tuple} +@generated function with_logabsdet_jacobian(cb::NamedComposition{T}, x) where {T<:Tuple} expr = Expr(:block) sym_y, sym_ladj, sym_tmp_ladj = gensym(:y), gensym(:lady), gensym(:tmp_lady) diff --git a/src/bijectors/normalise.jl b/src/bijectors/normalise.jl index 1bd7851b..c49863c3 100644 --- a/src/bijectors/normalise.jl +++ b/src/bijectors/normalise.jl @@ -48,7 +48,7 @@ function Functors.functor(::Type{<:InvertibleBatchNorm}, x) return (b = x.b, logs = x.logs), reconstruct_invertiblebatchnorm end -function ChangesOfVariables.with_logabsdet_jacobian(bn::InvertibleBatchNorm, x) +function with_logabsdet_jacobian(bn::InvertibleBatchNorm, x) dims = ndims(x) size(x, dims - 1) == length(bn.b) || error("InvertibleBatchNorm expected $(length(bn.b)) channels, got $(size(x, dims - 1))") @@ -83,7 +83,7 @@ logabsdetjac(bn::InvertibleBatchNorm, x) = last(with_logabsdet_jacobian(bn, x)) (bn::InvertibleBatchNorm)(x) = first(with_logabsdet_jacobian(bn, x)) -function ChangesOfVariables.with_logabsdet_jacobian(invbn::Inverse{<:InvertibleBatchNorm}, y) +function with_logabsdet_jacobian(invbn::Inverse{<:InvertibleBatchNorm}, y) @assert !istraining() "`with_logabsdet_jacobian(::Inverse{InvertibleBatchNorm})` is only available in test mode." dims = ndims(y) as = ntuple(i -> i == ndims(y) - 1 ? 
size(y, i) : 1, dims) diff --git a/src/bijectors/permute.jl b/src/bijectors/permute.jl index b8a7c30f..5ba9e5cf 100644 --- a/src/bijectors/permute.jl +++ b/src/bijectors/permute.jl @@ -151,7 +151,7 @@ end @inline (b::Permute)(x::AbstractVecOrMat) = b.A * x -@inline InverseFunctions.inverse(b::Permute) = Permute(transpose(b.A)) +@inline inverse(b::Permute) = Permute(transpose(b.A)) logabsdetjac(b::Permute, x::AbstractVector) = zero(eltype(x)) logabsdetjac(b::Permute, x::AbstractMatrix) = zero(eltype(x), size(x, 2)) diff --git a/src/bijectors/planar_layer.jl b/src/bijectors/planar_layer.jl index 8b24fb0f..a4f7be1e 100644 --- a/src/bijectors/planar_layer.jl +++ b/src/bijectors/planar_layer.jl @@ -94,7 +94,7 @@ we get \\log |det ∂f(z)/∂z| = \\log(1 + sech²(wᵀz + b) wᵀû). ``` =# -function ChangesOfVariables.with_logabsdet_jacobian(flow::PlanarLayer, z::AbstractVecOrMat{<:Real}) +function with_logabsdet_jacobian(flow::PlanarLayer, z::AbstractVecOrMat{<:Real}) transformed, wT_û, wT_z = _transform(flow, z) # Compute ``\\log |det ∂f(z)/∂z|`` (see above). diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index 809c8fd6..e486f800 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -49,7 +49,7 @@ end (b::RadialLayer)(z::AbstractMatrix{<:Real}) = _transform(b, z).transformed (b::RadialLayer)(z::AbstractVector{<:Real}) = vec(_transform(b, z).transformed) -function ChangesOfVariables.with_logabsdet_jacobian(flow::RadialLayer, z::AbstractVecOrMat) +function with_logabsdet_jacobian(flow::RadialLayer, z::AbstractVecOrMat) transformed, α, β_hat, r = _transform(flow, z) # Compute log_det_jacobian d = size(flow.z_0, 1) diff --git a/src/bijectors/rational_quadratic_spline.jl b/src/bijectors/rational_quadratic_spline.jl index f9d3e59b..ef34c436 100644 --- a/src/bijectors/rational_quadratic_spline.jl +++ b/src/bijectors/rational_quadratic_spline.jl @@ -379,6 +379,6 @@ function rqs_forward( return (y, logjac) end -function ChangesOfVariables.with_logabsdet_jacobian(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real) +function with_logabsdet_jacobian(b::RationalQuadraticSpline{<:AbstractVector, 0}, x::Real) return rqs_forward(b.widths, b.heights, b.derivatives, x) end diff --git a/src/bijectors/shift.jl b/src/bijectors/shift.jl index ad7816f2..e4e9960c 100644 --- a/src/bijectors/shift.jl +++ b/src/bijectors/shift.jl @@ -24,7 +24,7 @@ up1(b::Shift{T, N}) where {T, N} = Shift{T, N + 1}(b.a) (b::Shift)(x) = b.a .+ x (b::Shift{<:Any, 2})(x::AbstractArray{<:AbstractMatrix}) = map(b, x) -InverseFunctions.inverse(b::Shift{T, N}) where {T, N} = Shift{T, N}(-b.a) +inverse(b::Shift{T, N}) where {T, N} = Shift{T, N}(-b.a) # FIXME: implement custom adjoint to ensure we don't get tracking logabsdetjac(b::Shift{T, N}, x) where {T, N} = _logabsdetjac_shift(b.a, x, Val(N)) diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 31a52d77..7f5272b9 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -51,10 +51,10 @@ isclosedform(b::Stacked) = all(isclosedform, b.bs) stack(bs::Bijector{0}...) = Stacked(bs) # For some reason `inverse.(sb.bs)` was unstable... This works though. 
-InverseFunctions.inverse(sb::Stacked) = Stacked(map(inverse, sb.bs), sb.ranges) +inverse(sb::Stacked) = Stacked(map(inverse, sb.bs), sb.ranges) # map is not type stable for many stacked bijectors as a large tuple # hence the generated function -@generated function InverseFunctions.inverse(sb::Stacked{A}) where {A <: Tuple} +@generated function inverse(sb::Stacked{A}) where {A <: Tuple} exprs = [] for i = 1:length(A.parameters) push!(exprs, :(inverse(sb.bs[$i]))) @@ -138,7 +138,7 @@ end # logjac += sum(_logjac) # return (vcat(y_1, y_2), logjac) # end -@generated function ChangesOfVariables.with_logabsdet_jacobian(b::Stacked{<:Tuple{Vararg{<:Any, N}}, <:Tuple{Vararg{<:Any, N}}}, x::AbstractVector) where {N} +@generated function with_logabsdet_jacobian(b::Stacked{<:Tuple{Vararg{<:Any, N}}, <:Tuple{Vararg{<:Any, N}}}, x::AbstractVector) where {N} expr = Expr(:block) y_names = [] @@ -160,7 +160,7 @@ end return expr end -function ChangesOfVariables.with_logabsdet_jacobian(sb::Stacked, x::AbstractVector) +function with_logabsdet_jacobian(sb::Stacked, x::AbstractVector) N = length(sb.bs) yinit, linit = with_logabsdet_jacobian(sb.bs[1], x[sb.ranges[1]]) logjac = sum(linit) diff --git a/src/interface.jl b/src/interface.jl index 80c2a140..276ce631 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -72,8 +72,8 @@ Functors.@functor Inverse up1(b::Inverse) = Inverse(up1(b.orig)) -InverseFunctions.inverse(b::Bijector) = Inverse(b) -InverseFunctions.inverse(ib::Inverse{<:Bijector}) = ib.orig +inverse(b::Bijector) = Inverse(b) +inverse(ib::Inverse{<:Bijector}) = ib.orig Base.:(==)(b1::Inverse{<:Bijector}, b2::Inverse{<:Bijector}) = b1.orig == b2.orig """ @@ -99,7 +99,7 @@ in the computation of the forward pass and the computation of the `logabsdetjac`. `forward` allows the user to take advantange of such efficiencies, if they exist. 
""" -ChangesOfVariables.with_logabsdet_jacobian(b::Bijector, x) = (b(x), logabsdetjac(b, x)) +with_logabsdet_jacobian(b::Bijector, x) = (b(x), logabsdetjac(b, x)) """ logabsdetjacinv(b::Bijector, y) @@ -114,7 +114,7 @@ logabsdetjacinv(b::Bijector, y) = logabsdetjac(inverse(b), y) struct Identity{N} <: Bijector{N} end (::Identity)(x) = copy(x) -InverseFunctions.inverse(b::Identity) = b +inverse(b::Identity) = b up1(::Identity{N}) where {N} = Identity{N + 1}() logabsdetjac(::Identity{0}, x::Real) = zero(eltype(x)) diff --git a/test/interface.jl b/test/interface.jl index 6a0132ca..11fc27f6 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -858,8 +858,8 @@ end b = Bijectors.Scale{Float64,0}(4.2) x = 0.3 - InverseFunctions.test_inverse(b, x) - ChangesOfVariables.test_with_logabsdet_jacobian(b, x, (f::Bijectors.Scale, x) -> f.a) + test_inverse(b, x) + test_with_logabsdet_jacobian(b, x, (f::Bijectors.Scale, x) -> f.a) end diff --git a/test/runtests.jl b/test/runtests.jl index a4aaa465..65000d5f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,13 +1,11 @@ using Bijectors using ChainRulesTestUtils -using ChangesOfVariables using Combinatorics using DistributionsAD using FiniteDifferences using ForwardDiff using Functors -using InverseFunctions using LogExpFunctions using ReverseDiff using Tracker @@ -18,6 +16,9 @@ using Random, LinearAlgebra, Test using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector +using ChangesOfVariables: test_with_logabsdet_jacobian +using InverseFunctions: test_inverse + using DistributionsAD: TuringUniform, TuringMvNormal, TuringMvLogNormal, TuringPoissonBinomial From 4b683e5e2eed870033052d75b69fdbf3873c6a8f Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Wed, 15 Dec 2021 20:12:20 +0100 Subject: [PATCH 29/29] Increase package version to v0.10.0 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3afc20f0..15a1887d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.9.12" +version = "0.10.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"