From 0632e26b4d2d00848d33a0cf034b4f770c2ff6f3 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jun 2024 21:56:38 +0100 Subject: [PATCH 1/6] update interface for objective initialization --- src/AdvancedVI.jl | 7 ++++--- src/optimize.jl | 2 +- src/utils.jl | 11 +++++++---- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 7c7a1fc8..08b1b71d 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -78,7 +78,7 @@ If the estimator is stateful, it can implement `init` to initialize the state. abstract type AbstractVariationalObjective end """ - init(rng, obj, λ, restructure) + init(rng, obj, prob, params, restructure) Initialize a state of the variational objective `obj` given the initial variational parameters `λ`. This function needs to be implemented only if `obj` is stateful. @@ -86,13 +86,14 @@ This function needs to be implemented only if `obj` is stateful. # Arguments - `rng::Random.AbstractRNG`: Random number generator. - `obj::AbstractVariationalObjective`: Variational objective. -- `λ`: Initial variational parameters. +- `params`: Initial variational parameters. - `restructure`: Function that reconstructs the variational approximation from `λ`. """ init( ::Random.AbstractRNG, ::AbstractVariationalObjective, - ::AbstractVector, + ::Any + ::Any, ::Any ) = nothing diff --git a/src/optimize.jl b/src/optimize.jl index 4eb6644a..e5fe374d 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -66,7 +66,7 @@ function optimize( ) params, restructure = Optimisers.destructure(deepcopy(q_init)) opt_st = maybe_init_optimizer(state_init, optimizer, params) - obj_st = maybe_init_objective(state_init, rng, objective, params, restructure) + obj_st = maybe_init_objective(state_init, rng, objective, problem, params, restructure) grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params)) stats = NamedTuple[] diff --git a/src/utils.jl b/src/utils.jl index 98b79b2d..a8039b22 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -6,19 +6,22 @@ end function maybe_init_optimizer( state_init::NamedTuple, optimizer ::Optimisers.AbstractRule, - params ::AbstractVector + params ) - haskey(state_init, :optimizer) ? state_init.optimizer : Optimisers.setup(optimizer, params) + haskey(state_init, :optimizer) ? + state_init.optimizer : Optimisers.setup(optimizer, params) end function maybe_init_objective( state_init::NamedTuple, rng ::Random.AbstractRNG, objective ::AbstractVariationalObjective, - params ::AbstractVector, + problem, + params, restructure ) - haskey(state_init, :objective) ? state_init.objective : init(rng, objective, params, restructure) + haskey(state_init, :objective) ? + state_init.objective : init(rng, objective, params, problem, restructure) end eachsample(samples::AbstractMatrix) = eachcol(samples) From e9b960924d614d54f1b8b6009c2bf5c3fc325ee4 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jun 2024 22:39:21 +0100 Subject: [PATCH 2/6] improve `RepGradELBO` to not redefine integrand --- src/AdvancedVI.jl | 7 ++-- src/objectives/elbo/repgradelbo.jl | 52 ++++++++++++++++++------------ 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 08b1b71d..632bb182 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -37,6 +37,9 @@ Evaluate the value and gradient of a function `f` at `θ` using the automatic di """ function value_and_gradient! end +stop_gradient(x) = x + + # Update for gradient descent step """ update_variational_params!(family_type, opt_st, params, restructure, grad) @@ -92,9 +95,9 @@ This function needs to be implemented only if `obj` is stateful. init( ::Random.AbstractRNG, ::AbstractVariationalObjective, - ::Any ::Any, - ::Any + ::Any, + ::Any, ) = nothing """ diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl index 2d95d076..d860a600 100644 --- a/src/objectives/elbo/repgradelbo.jl +++ b/src/objectives/elbo/repgradelbo.jl @@ -46,24 +46,18 @@ function Base.show(io::IO, obj::RepGradELBO) print(io, ")") end -function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop) - q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop) - estimate_entropy(entropy_estimator, samples, q_maybe_stop) -end - function estimate_energy_with_samples(prob, samples) mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) end """ - reparam_with_entropy(rng, q, q_stop, n_samples, ent_est) + reparam_with_entropy(rng, q, n_samples, ent_est) Draw `n_samples` from `q` and compute its entropy. # Arguments - `rng::Random.AbstractRNG`: Random number generator. - `q`: Variational approximation. -- `q_stop`: `q` but with its gradient stopped. - `n_samples::Int`: Number of Monte Carlo samples - `ent_est`: The entropy estimation strategy. (See `estimate_entropy`.) @@ -72,10 +66,10 @@ Draw `n_samples` from `q` and compute its entropy. - `entropy`: An estimate (or exact value) of the differential entropy of `q`. """ function reparam_with_entropy( - rng::Random.AbstractRNG, q, q_stop, n_samples::Int, ent_est::AbstractEntropyEstimator + rng::Random.AbstractRNG, q, n_samples::Int, ent_est::AbstractEntropyEstimator ) samples = rand(rng, q, n_samples) - entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop) + entropy = estimate_entropy(ent_est, samples, q) samples, entropy end @@ -86,7 +80,7 @@ function estimate_objective( prob; n_samples::Int = obj.n_samples ) - samples, entropy = reparam_with_entropy(rng, q, q, n_samples, obj.entropy) + samples, entropy = reparam_with_entropy(rng, q, n_samples, obj.entropy) energy = estimate_energy_with_samples(prob, samples) energy + entropy end @@ -94,6 +88,31 @@ end estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) = estimate_objective(Random.default_rng(), obj, q, prob; n_samples) +function init(rng::Random.AbstractRNG, obj::RepGradELBO, problem, params, restructure) + function obj_integrand(params′) + q = restructure(params′) + samples, entropy = reparam_with_entropy(rng, q, obj.n_samples, obj.entropy) + energy = estimate_energy_with_samples(problem, samples) + elbo = energy + entropy + -elbo + end + obj_integrand +end + +function estimate_objective_restructure( + rng::Random.AbstractRNG, + obj::RepGradELBO, + params, + restructure, + prob; + n_samples::Int = obj.n_samples +) + q = restructure(params) + samples, entropy = reparam_with_entropy(rng, q, n_samples, obj.entropy) + energy = estimate_energy_with_samples(prob, samples) + energy + entropy +end + function estimate_gradient!( rng ::Random.AbstractRNG, obj ::RepGradELBO, @@ -104,18 +123,11 @@ function estimate_gradient!( restructure, state, ) - q_stop = restructure(λ) - function f(λ′) - q = restructure(λ′) - samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy) - energy = estimate_energy_with_samples(prob, samples) - elbo = energy + entropy - -elbo - end - value_and_gradient!(adtype, f, λ, out) + obj_integrand = state + value_and_gradient!(adtype, obj_integrand, λ, out) nelbo = DiffResults.value(out) stat = (elbo=-nelbo,) - out, nothing, stat + out, state, stat end From bcbe11d251a762c3d1c360bb1c5e1dc158922115 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jun 2024 22:58:01 +0100 Subject: [PATCH 3/6] fix bug --- src/utils.jl | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index a8039b22..3ae59a78 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -8,8 +8,11 @@ function maybe_init_optimizer( optimizer ::Optimisers.AbstractRule, params ) - haskey(state_init, :optimizer) ? - state_init.optimizer : Optimisers.setup(optimizer, params) + if haskey(state_init, :optimizer) + state_init.optimizer + else + Optimisers.setup(optimizer, params) + end end function maybe_init_objective( @@ -20,8 +23,11 @@ function maybe_init_objective( params, restructure ) - haskey(state_init, :objective) ? - state_init.objective : init(rng, objective, params, problem, restructure) + if haskey(state_init, :objective) + state_init.objective + else + init(rng, objective, problem, params, restructure) + end end eachsample(samples::AbstractMatrix) = eachcol(samples) From 00493693f22b44a5f01f472e7636ab0e03ea7c31 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Thu, 13 Jun 2024 23:10:33 +0100 Subject: [PATCH 4/6] fix Bijectors ext to match new `reparam_with_entropy` interface --- ext/AdvancedVIBijectorsExt.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl index 4a88d6fb..950b5f15 100644 --- a/ext/AdvancedVIBijectorsExt.jl +++ b/ext/AdvancedVIBijectorsExt.jl @@ -38,17 +38,15 @@ end function AdvancedVI.reparam_with_entropy( rng ::Random.AbstractRNG, q ::Bijectors.TransformedDistribution, - q_stop ::Bijectors.TransformedDistribution, n_samples::Int, ent_est ::AdvancedVI.AbstractEntropyEstimator ) - transform = q.transform - q_unconst = q.dist - q_unconst_stop = q_stop.dist + transform = q.transform + q_unconst = q.dist # Draw samples and compute entropy of the uncontrained distribution unconstr_samples, unconst_entropy = AdvancedVI.reparam_with_entropy( - rng, q_unconst, q_unconst_stop, n_samples, ent_est + rng, q_unconst, n_samples, ent_est ) # Apply bijector to samples while estimating its jacobian From ef3c312f004ecf5eb14aa82878aac88037b88bc9 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Sat, 15 Jun 2024 01:04:20 +0100 Subject: [PATCH 5/6] revert removal of q_stop, add auxiliary argto `value_and_gradient!` --- ext/AdvancedVIBijectorsExt.jl | 8 +++-- ext/AdvancedVIForwardDiffExt.jl | 23 +++++++++--- ext/AdvancedVIReverseDiffExt.jl | 19 ++++++++-- ext/AdvancedVIZygoteExt.jl | 23 +++++++++--- src/AdvancedVI.jl | 22 +++++++++--- src/objectives/elbo/repgradelbo.jl | 58 ++++++++++++++---------------- test/Project.toml | 3 +- test/interface/repgradelbo.jl | 29 +++++++++++++++ 8 files changed, 132 insertions(+), 53 deletions(-) diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl index 950b5f15..a227fdf2 100644 --- a/ext/AdvancedVIBijectorsExt.jl +++ b/ext/AdvancedVIBijectorsExt.jl @@ -38,15 +38,17 @@ end function AdvancedVI.reparam_with_entropy( rng ::Random.AbstractRNG, q ::Bijectors.TransformedDistribution, + q_stop ::Bijectors.TransformedDistribution, n_samples::Int, ent_est ::AdvancedVI.AbstractEntropyEstimator ) - transform = q.transform - q_unconst = q.dist + transform = q.transform + q_unconst = q.dist + q_unconst_stop = q_stop.dist # Draw samples and compute entropy of the uncontrained distribution unconstr_samples, unconst_entropy = AdvancedVI.reparam_with_entropy( - rng, q_unconst, n_samples, ent_est + rng, q_unconst, q_unconst_stop, n_samples, ent_est ) # Apply bijector to samples while estimating its jacobian diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl index 5949bdf8..a8afd031 100644 --- a/ext/AdvancedVIForwardDiffExt.jl +++ b/ext/AdvancedVIForwardDiffExt.jl @@ -14,16 +14,29 @@ end getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult -) where {T<:Real} + ad ::ADTypes.AutoForwardDiff, + f, + x ::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult +) chunk_size = getchunksize(ad) config = if isnothing(chunk_size) - ForwardDiff.GradientConfig(f, θ) + ForwardDiff.GradientConfig(f, x) else - ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size)) + ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk(length(x), chunk_size)) end - ForwardDiff.gradient!(out, f, θ, config) + ForwardDiff.gradient!(out, f, x, config) return out end +function AdvancedVI.value_and_gradient!( + ad ::ADTypes.AutoForwardDiff, + f, + x ::AbstractVector, + aux, + out::DiffResults.MutableDiffResult +) + AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) +end + end diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl index 520cd9ff..392f5cea 100644 --- a/ext/AdvancedVIReverseDiffExt.jl +++ b/ext/AdvancedVIReverseDiffExt.jl @@ -13,11 +13,24 @@ end # ReverseDiff without compiled tape function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult + ad::ADTypes.AutoReverseDiff, + f, + x::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult ) - tp = ReverseDiff.GradientTape(f, θ) - ReverseDiff.gradient!(out, tp, θ) + tp = ReverseDiff.GradientTape(f, x) + ReverseDiff.gradient!(out, tp, x) return out end +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoReverseDiff, + f, + x::AbstractVector{<:Real}, + aux, + out::DiffResults.MutableDiffResult +) + AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) +end + end diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl index 7b8f8817..806c08e4 100644 --- a/ext/AdvancedVIZygoteExt.jl +++ b/ext/AdvancedVIZygoteExt.jl @@ -4,21 +4,36 @@ module AdvancedVIZygoteExt if isdefined(Base, :get_extension) using AdvancedVI using AdvancedVI: ADTypes, DiffResults + using ChainRulesCore using Zygote else using ..AdvancedVI using ..AdvancedVI: ADTypes, DiffResults + using ..ChainRulesCore using ..Zygote end function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoZygote, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult + ::ADTypes.AutoZygote, + f, + x::AbstractVector{<:Real}, + out::DiffResults.MutableDiffResult ) - y, back = Zygote.pullback(f, θ) - ∇θ = back(one(y)) + y, back = Zygote.pullback(f, x) + ∇x = back(one(y)) DiffResults.value!(out, y) - DiffResults.gradient!(out, only(∇θ)) + DiffResults.gradient!(out, only(∇x)) return out end +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoZygote, + f, + x::AbstractVector{<:Real}, + aux, + out::DiffResults.MutableDiffResult +) + AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) +end + end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 632bb182..7a09030b 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -25,20 +25,34 @@ using StatsBase # derivatives """ - value_and_gradient!(ad, f, θ, out) + value_and_gradient!(ad, f, x, out) + value_and_gradient!(ad, f, x, aux, out) -Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad` and store the result in `out`. +Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`. +`f` may receive auxiliary input as `f(x,aux)`. # Arguments - `ad::ADTypes.AbstractADType`: Automatic differentiation backend. - `f`: Function subject to differentiation. -- `θ`: The point to evaluate the gradient. +- `x`: The point to evaluate the gradient. +- `aux`: Auxiliary input passed to `f`. - `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value. """ function value_and_gradient! end -stop_gradient(x) = x +""" + stop_gradient(x) + +Stop the gradient from propagating to `x` if the selected ad backend supports it. +Otherwise, it is equivalent to `identity`. + +# Arguments +- `x`: Input +# Returns +- `x`: Same value as the input. +""" +function stop_gradient end # Update for gradient descent step """ diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl index d860a600..27a937e8 100644 --- a/src/objectives/elbo/repgradelbo.jl +++ b/src/objectives/elbo/repgradelbo.jl @@ -46,6 +46,11 @@ function Base.show(io::IO, obj::RepGradELBO) print(io, ")") end +function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop) + q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop) + estimate_entropy(entropy_estimator, samples, q_maybe_stop) +end + function estimate_energy_with_samples(prob, samples) mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) end @@ -66,10 +71,14 @@ Draw `n_samples` from `q` and compute its entropy. - `entropy`: An estimate (or exact value) of the differential entropy of `q`. """ function reparam_with_entropy( - rng::Random.AbstractRNG, q, n_samples::Int, ent_est::AbstractEntropyEstimator + rng ::Random.AbstractRNG, + q, + q_stop, + n_samples::Int, + ent_est ::AbstractEntropyEstimator ) samples = rand(rng, q, n_samples) - entropy = estimate_entropy(ent_est, samples, q) + entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop) samples, entropy end @@ -80,7 +89,7 @@ function estimate_objective( prob; n_samples::Int = obj.n_samples ) - samples, entropy = reparam_with_entropy(rng, q, n_samples, obj.entropy) + samples, entropy = reparam_with_entropy(rng, q, q, n_samples, obj.entropy) energy = estimate_energy_with_samples(prob, samples) energy + entropy end @@ -88,29 +97,13 @@ end estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) = estimate_objective(Random.default_rng(), obj, q, prob; n_samples) -function init(rng::Random.AbstractRNG, obj::RepGradELBO, problem, params, restructure) - function obj_integrand(params′) - q = restructure(params′) - samples, entropy = reparam_with_entropy(rng, q, obj.n_samples, obj.entropy) - energy = estimate_energy_with_samples(problem, samples) - elbo = energy + entropy - -elbo - end - obj_integrand -end - -function estimate_objective_restructure( - rng::Random.AbstractRNG, - obj::RepGradELBO, - params, - restructure, - prob; - n_samples::Int = obj.n_samples -) - q = restructure(params) - samples, entropy = reparam_with_entropy(rng, q, n_samples, obj.entropy) - energy = estimate_energy_with_samples(prob, samples) - energy + entropy +function estimate_repgradelbo_ad_forward(params′, aux) + @unpack rng, obj, problem, restructure, q_stop = aux + q = restructure(params′) + samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy) + energy = estimate_energy_with_samples(problem, samples) + elbo = energy + entropy + -elbo end function estimate_gradient!( @@ -119,15 +112,16 @@ function estimate_gradient!( adtype::ADTypes.AbstractADType, out ::DiffResults.MutableDiffResult, prob, - λ, + params, restructure, state, ) - obj_integrand = state - value_and_gradient!(adtype, obj_integrand, λ, out) - + q_stop = restructure(params) + aux = (rng=rng, obj=obj, problem=prob, restructure=restructure, q_stop=q_stop) + value_and_gradient!( + adtype, estimate_repgradelbo_ad_forward, params, aux, out + ) nelbo = DiffResults.value(out) stat = (elbo=-nelbo,) - - out, state, stat + out, nothing, stat end diff --git a/test/Project.toml b/test/Project.toml index f3acedea..a0dba17f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,9 +1,9 @@ [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" @@ -26,7 +26,6 @@ ADTypes = "0.2.1, 1" Bijectors = "0.13" Distributions = "0.25.100" DistributionsAD = "0.6.45" -Enzyme = "0.12" FillArrays = "1.6.1" ForwardDiff = "0.10.36" Functors = "0.4.5" diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index 61ff0111..ac9bfeca 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -26,3 +26,32 @@ using Test @test elbo ≈ elbo_ref rtol=0.1 end end + +@testset "interface RepGradELBO STL variance reduction" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = normal_meanfield(rng, Float64) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + @testset for ad in [ + ADTypes.AutoForwardDiff(), + ADTypes.AutoReverseDiff(), + ADTypes.AutoZygote() + ] + q_true = MeanFieldGaussian( + Vector{eltype(μ_true)}(μ_true), + Diagonal(Vector{eltype(L_true)}(diag(L_true))) + ) + params, re = Optimisers.destructure(q_true) + obj = RepGradELBO(10; entropy=StickingTheLandingEntropy()) + out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) + + aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true) + AdvancedVI.value_and_gradient!( + ad, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out + ) + grad = DiffResults.gradient(out) + @test norm(grad) ≈ 0 atol=1e-5 + end +end From c9c5c5d9c12464053330e0ae8f4b5f1ce2692047 Mon Sep 17 00:00:00 2001 From: Ray Kim Date: Mon, 24 Jun 2024 23:01:17 +0100 Subject: [PATCH 6/6] add Tapir support, restructure for caching round Tapir `rrule` --- Project.toml | 3 ++ ext/AdvancedVIForwardDiffExt.jl | 16 +++++---- ext/AdvancedVIReverseDiffExt.jl | 14 ++++---- ext/AdvancedVIZygoteExt.jl | 14 ++++---- src/AdvancedVI.jl | 33 ++++++++++++------- src/objectives/elbo/repgradelbo.jl | 19 +++++++++-- src/optimize.jl | 4 ++- src/utils.jl | 3 +- test/Project.toml | 2 ++ test/inference/repgradelbo_distributionsad.jl | 1 + test/inference/repgradelbo_locationscale.jl | 1 + .../repgradelbo_locationscale_bijectors.jl | 3 +- test/interface/ad.jl | 9 +++-- test/interface/repgradelbo.jl | 14 ++++---- test/runtests.jl | 2 +- 15 files changed, 92 insertions(+), 46 deletions(-) diff --git a/Project.toml b/Project.toml index 9d30e740..515310bb 100644 --- a/Project.toml +++ b/Project.toml @@ -25,6 +25,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] @@ -32,6 +33,7 @@ AdvancedVIBijectorsExt = "Bijectors" AdvancedVIEnzymeExt = "Enzyme" AdvancedVIForwardDiffExt = "ForwardDiff" AdvancedVIReverseDiffExt = "ReverseDiff" +AdvancedVITapirExt = "Tapir" AdvancedVIZygoteExt = "Zygote" [compat] @@ -55,6 +57,7 @@ Requires = "1.0" ReverseDiff = "1.15.1" SimpleUnPack = "1.1.0" StatsBase = "0.32, 0.33, 0.34" +Tapir = "0.2.23" Zygote = "0.6.63" julia = "1.6" diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl index a8afd031..bf0da458 100644 --- a/ext/AdvancedVIForwardDiffExt.jl +++ b/ext/AdvancedVIForwardDiffExt.jl @@ -14,10 +14,11 @@ end getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize function AdvancedVI.value_and_gradient!( - ad ::ADTypes.AutoForwardDiff, + ad ::ADTypes.AutoForwardDiff, + ::Any, f, - x ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult + x ::AbstractVector, + out ::DiffResults.MutableDiffResult ) chunk_size = getchunksize(ad) config = if isnothing(chunk_size) @@ -30,13 +31,14 @@ function AdvancedVI.value_and_gradient!( end function AdvancedVI.value_and_gradient!( - ad ::ADTypes.AutoForwardDiff, + ad ::ADTypes.AutoForwardDiff, + st_ad, f, - x ::AbstractVector, + x ::AbstractVector, aux, - out::DiffResults.MutableDiffResult + out ::DiffResults.MutableDiffResult ) - AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) + AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out) end end diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl index 392f5cea..cde187b5 100644 --- a/ext/AdvancedVIReverseDiffExt.jl +++ b/ext/AdvancedVIReverseDiffExt.jl @@ -13,9 +13,10 @@ end # ReverseDiff without compiled tape function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoReverseDiff, + ::ADTypes.AutoReverseDiff, + ::Any, f, - x::AbstractVector{<:Real}, + x ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult ) tp = ReverseDiff.GradientTape(f, x) @@ -24,13 +25,14 @@ function AdvancedVI.value_and_gradient!( end function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoReverseDiff, + ad ::ADTypes.AutoReverseDiff, + st_ad, f, - x::AbstractVector{<:Real}, + x ::AbstractVector{<:Real}, aux, - out::DiffResults.MutableDiffResult + out ::DiffResults.MutableDiffResult ) - AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) + AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out) end end diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl index 806c08e4..f60b31fb 100644 --- a/ext/AdvancedVIZygoteExt.jl +++ b/ext/AdvancedVIZygoteExt.jl @@ -14,9 +14,10 @@ else end function AdvancedVI.value_and_gradient!( - ::ADTypes.AutoZygote, + ::ADTypes.AutoZygote, + ::Any, f, - x::AbstractVector{<:Real}, + x ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult ) y, back = Zygote.pullback(f, x) @@ -27,13 +28,14 @@ function AdvancedVI.value_and_gradient!( end function AdvancedVI.value_and_gradient!( - ad::ADTypes.AutoZygote, + ad ::ADTypes.AutoZygote, + st_ad, f, - x::AbstractVector{<:Real}, + x ::AbstractVector{<:Real}, aux, - out::DiffResults.MutableDiffResult + out ::DiffResults.MutableDiffResult ) - AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out) + AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out) end end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 7a09030b..8423a312 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -25,14 +25,15 @@ using StatsBase # derivatives """ - value_and_gradient!(ad, f, x, out) - value_and_gradient!(ad, f, x, aux, out) + value_and_gradient!(adtype, ad_st, f, x, out) + value_and_gradient!(adtype, ad_st, f, x, aux, out) -Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`. +Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation (AD) backend `ad` and store the result in `out`. `f` may receive auxiliary input as `f(x,aux)`. # Arguments -- `ad::ADTypes.AbstractADType`: Automatic differentiation backend. +- `adtype::ADTypes.AbstractADType`: AD backend. +- `ad_st`: State used by the AD backend. (This will often be pre-compiled tapes/caches.) - `f`: Function subject to differentiation. - `x`: The point to evaluate the gradient. - `aux`: Auxiliary input passed to `f`. @@ -41,18 +42,22 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif function value_and_gradient! end """ - stop_gradient(x) + init_adbackend(adtype, f, x) + init_adbackend(adtype, f, x, aux) -Stop the gradient from propagating to `x` if the selected ad backend supports it. -Otherwise, it is equivalent to `identity`. +Initialize the AD backend and setup states necessary. # Arguments -- `x`: Input +- `ad::ADTypes.AbstractADType`: Automatic differentiation backend. +- `f`: Function subject to differentiation. +- `x`: The point to evaluate the gradient. +- `aux`: Auxiliary input passed to `f`. # Returns -- `x`: Same value as the input. +- `ad_st`: State of the AD backend. (This will often be pre-compiled tapes/caches.) """ -function stop_gradient end +init_adbackend(::ADTypes.AbstractADType, ::Any, ::Any) = nothing +init_adbackend(::ADTypes.AbstractADType, ::Any, ::Any, ::Any) = nothing # Update for gradient descent step """ @@ -95,14 +100,17 @@ If the estimator is stateful, it can implement `init` to initialize the state. abstract type AbstractVariationalObjective end """ - init(rng, obj, prob, params, restructure) + init(rng, obj, adtype, prob, params, restructure) -Initialize a state of the variational objective `obj` given the initial variational parameters `λ`. +Initialize a state of the variational objective `obj`. This function needs to be implemented only if `obj` is stateful. +The state of the AD backend `adtype` shall also be initialized here. # Arguments - `rng::Random.AbstractRNG`: Random number generator. - `obj::AbstractVariationalObjective`: Variational objective. +- `adtype::ADTypes.ADType`:Automatic differentiation backend. +- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. - `params`: Initial variational parameters. - `restructure`: Function that reconstructs the variational approximation from `λ`. """ @@ -112,6 +120,7 @@ init( ::Any, ::Any, ::Any, + ::Any, ) = nothing """ diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl index 27a937e8..e0f5de40 100644 --- a/src/objectives/elbo/repgradelbo.jl +++ b/src/objectives/elbo/repgradelbo.jl @@ -106,6 +106,20 @@ function estimate_repgradelbo_ad_forward(params′, aux) -elbo end +function init( + rng ::Random.AbstractRNG, + obj ::RepGradELBO, + adtype ::ADTypes.AbstractADType, + prob, + params, + restructure, +) + q_stop = restructure(params) + aux = (rng=rng, obj=obj, problem=prob, restructure=restructure, q_stop=q_stop) + ad_st = init_adbackend(adtype, estimate_repgradelbo_ad_forward, params, aux) + (ad_st=ad_st,) +end + function estimate_gradient!( rng ::Random.AbstractRNG, obj ::RepGradELBO, @@ -117,11 +131,12 @@ function estimate_gradient!( state, ) q_stop = restructure(params) + ad_st = state.ad_st aux = (rng=rng, obj=obj, problem=prob, restructure=restructure, q_stop=q_stop) value_and_gradient!( - adtype, estimate_repgradelbo_ad_forward, params, aux, out + adtype, ad_st, estimate_repgradelbo_ad_forward, params, aux, out ) nelbo = DiffResults.value(out) stat = (elbo=-nelbo,) - out, nothing, stat + out, state, stat end diff --git a/src/optimize.jl b/src/optimize.jl index e5fe374d..659f3d16 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -66,7 +66,9 @@ function optimize( ) params, restructure = Optimisers.destructure(deepcopy(q_init)) opt_st = maybe_init_optimizer(state_init, optimizer, params) - obj_st = maybe_init_objective(state_init, rng, objective, problem, params, restructure) + obj_st = maybe_init_objective( + state_init, rng, adtype, objective, problem, params, restructure + ) grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params)) stats = NamedTuple[] diff --git a/src/utils.jl b/src/utils.jl index 3ae59a78..fbfdc330 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -18,6 +18,7 @@ end function maybe_init_objective( state_init::NamedTuple, rng ::Random.AbstractRNG, + adtype ::ADTypes.AbstractADType, objective ::AbstractVariationalObjective, problem, params, @@ -26,7 +27,7 @@ function maybe_init_objective( if haskey(state_init, :objective) state_init.objective else - init(rng, objective, problem, params, restructure) + init(rng, objective, adtype, problem, params, restructure) end end diff --git a/test/Project.toml b/test/Project.toml index a0dba17f..b25ddfd5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -17,6 +17,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -38,6 +39,7 @@ ReverseDiff = "1.15.1" SimpleUnPack = "1.1.0" StableRNGs = "1.0.0" Statistics = "1" +Tapir = "0.2.23" Test = "1" Tracker = "0.2.20" Zygote = "0.6.63" diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index 38dbc6e3..201f69b2 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -14,6 +14,7 @@ :ForwarDiff => AutoForwardDiff(), #:ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), + :Tapir => AutoTapir(), #:Enzyme => AutoEnzyme(), ) diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index 73846675..b46d4d4d 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -14,6 +14,7 @@ (adbackname, adtype) in Dict( :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), + :Tapir => AutoTapir(safe_mode=false), :Zygote => AutoZygote(), #:Enzyme => AutoEnzyme(), ) diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index 1f62df0f..5f04c25a 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -13,7 +13,8 @@ (adbackname, adtype) in Dict( :ForwarDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), - #:Zygote => AutoZygote(), + :Zygote => AutoZygote(), + :Tapir => AutoTapir(safe_mode=false), #:Enzyme => AutoEnzyme(), ) diff --git a/test/interface/ad.jl b/test/interface/ad.jl index be4ca34e..55e7cdbd 100644 --- a/test/interface/ad.jl +++ b/test/interface/ad.jl @@ -2,18 +2,21 @@ using Test @testset "ad" begin - @testset "$(adname)" for (adname, adsymbol) ∈ Dict( + @testset "$(adname)" for (adname, adtype) ∈ Dict( :ForwardDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), + :Tapir => AutoTapir(), # :Enzyme => AutoEnzyme() # Currently not tested against ) D = 10 A = randn(D, D) λ = randn(D) - grad_buf = DiffResults.GradientResult(λ) f(λ′) = λ′'*A*λ′ / 2 - AdvancedVI.value_and_gradient!(adsymbol, f, λ, grad_buf) + + ad_st = AdvancedVI.init_adbackend(adtype, f, λ) + grad_buf = DiffResults.GradientResult(λ) + AdvancedVI.value_and_gradient!(adtype, ad_st, f, λ, grad_buf) ∇ = DiffResults.gradient(grad_buf) f = DiffResults.value(grad_buf) @test ∇ ≈ (A + A')*λ/2 diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index ac9bfeca..58e440ce 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -34,7 +34,7 @@ end modelstats = normal_meanfield(rng, Float64) @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats - @testset for ad in [ + @testset for adtype in [ ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoZygote() @@ -44,12 +44,14 @@ end Diagonal(Vector{eltype(L_true)}(diag(L_true))) ) params, re = Optimisers.destructure(q_true) - obj = RepGradELBO(10; entropy=StickingTheLandingEntropy()) - out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) - - aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true) + obj = RepGradELBO(10; entropy=StickingTheLandingEntropy()) + out = DiffResults.DiffResult(zero(eltype(params)), similar(params)) + aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true) + ad_st = AdvancedVI.init_adbackend( + adtype, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux + ) AdvancedVI.value_and_gradient!( - ad, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out + adtype, ad_st, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out ) grad = DiffResults.gradient(out) @test norm(grad) ≈ 0 atol=1e-5 diff --git a/test/runtests.jl b/test/runtests.jl index 3bd13144..c85840b5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,7 +18,7 @@ using DistributionsAD using LogDensityProblems using Optimisers using ADTypes -using ForwardDiff, ReverseDiff, Zygote +using ForwardDiff, ReverseDiff, Zygote, Tapir using AdvancedVI