From 5990fae9f176e84d83bd119fa4a6b0e68f028493 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 26 Aug 2022 14:27:00 +0200 Subject: [PATCH] Use LogDensityProblems instead of `gradient_logp` (#1877) * Use LogDensityProblems instead of `gradient_logp` * Update Project.toml * Update Project.toml * Import LogDensityProblems in submodule * Fix Gibbs sampling with DynamicHMC * Update ad.jl * Add LogDensityProblems test dependency * Qualify `LogDensityFunction` --- Project.toml | 8 +- docs/src/using-turing/autodiff.md | 2 +- src/Turing.jl | 9 ++ src/contrib/inference/dynamichmc.jl | 41 +++------ src/contrib/inference/sghmc.jl | 22 +++-- src/essential/Essential.jl | 15 +--- src/essential/ad.jl | 126 ++++++++-------------------- src/essential/compat/reversediff.jl | 80 ------------------ src/inference/Inference.jl | 4 +- src/inference/hmc.jl | 27 +++--- src/modes/ModeEstimation.jl | 13 ++- test/Project.toml | 4 +- test/essential/ad.jl | 28 ++++--- test/runtests.jl | 5 +- test/skipped/unit_test_helper.jl | 6 +- test/test_utils/ad_utils.jl | 8 +- 16 files changed, 128 insertions(+), 270 deletions(-) delete mode 100644 src/essential/compat/reversediff.jl diff --git a/Project.toml b/Project.toml index f390512ca..53be72a2a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.21.10" +version = "0.21.11" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -11,7 +11,6 @@ AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" @@ -20,6 +19,7 @@ EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -32,7 +32,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractMCMC = "4" @@ -43,7 +42,6 @@ AdvancedVI = "0.1" BangBang = "0.3" Bijectors = "0.8, 0.9, 0.10" DataStructures = "0.18" -DiffResults = "1" Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" @@ -51,6 +49,7 @@ DynamicPPL = "0.20" EllipticalSliceSampling = "0.5, 1" ForwardDiff = "0.10.3" Libtask = "0.6.7, 0.7" +LogDensityProblems = "0.12" MCMCChains = "5" NamedArrays = "0.9" Reexport = "0.2, 1" @@ -60,5 +59,4 @@ SpecialFunctions = "0.7.2, 0.8, 0.9, 0.10, 1, 2" StatsBase = "0.32, 0.33" StatsFuns = "0.8, 0.9, 1" Tracker = "0.2.3" -ZygoteRules = "0.2" julia = "1.6" diff --git a/docs/src/using-turing/autodiff.md b/docs/src/using-turing/autodiff.md index dd9e2976d..44eff6133 100644 --- a/docs/src/using-turing/autodiff.md +++ b/docs/src/using-turing/autodiff.md @@ -10,7 +10,7 @@ title: Automatic Differentiation Turing supports four packages of automatic differentiation (AD) in the back end during sampling. The default AD backend is [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) for forward-mode AD. Three reverse-mode AD backends are also supported, namely [Tracker](https://github.com/FluxML/Tracker.jl), [Zygote](https://github.com/FluxML/Zygote.jl) and [ReverseDiff](https://github.com/JuliaDiff/ReverseDiff.jl). `Zygote` and `ReverseDiff` are supported optionally if explicitly loaded by the user with `using Zygote` or `using ReverseDiff` next to `using Turing`. -To switch between the different AD backends, one can call function `Turing.setadbackend(backend_sym)`, where `backend_sym` can be `:forwarddiff` (`ForwardDiff`), `:tracker` (`Tracker`), `:zygote` (`Zygote`) or `:reversediff` (`ReverseDiff.jl`). When using `ReverseDiff`, to compile the tape only once and cache it for later use, the user needs to load [Memoization.jl](https://github.com/marius311/Memoization.jl) first with `using Memoization` then call `Turing.setrdcache(true)`. However, note that the use of caching in certain types of models can lead to incorrect results and/or errors. Models for which the compiled tape can be safely cached are models with fixed size loops and no run-time if statements. Compile-time if statements are fine. To empty the cache, you can call `Turing.emptyrdcache()`. +To switch between the different AD backends, one can call function `Turing.setadbackend(backend_sym)`, where `backend_sym` can be `:forwarddiff` (`ForwardDiff`), `:tracker` (`Tracker`), `:zygote` (`Zygote`) or `:reversediff` (`ReverseDiff.jl`). When using `ReverseDiff`, to compile the tape only once and cache it for later use, the user has to call `Turing.setrdcache(true)`. However, note that the use of caching in certain types of models can lead to incorrect results and/or errors. Models for which the compiled tape can be safely cached are models with fixed size loops and no run-time if statements. Compile-time if statements are fine. ## Compositional Sampling with Differing AD Modes diff --git a/src/Turing.jl b/src/Turing.jl index 06ab9644b..0d1f58992 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -9,6 +9,7 @@ using Tracker: Tracker import AdvancedVI import DynamicPPL: getspace, NoDist, NamedDist +import LogDensityProblems import Random const PROGRESS = Ref(true) @@ -37,12 +38,20 @@ function (f::LogDensityFunction)(θ::AbstractVector) return getlogp(last(DynamicPPL.evaluate!!(f.model, VarInfo(f.varinfo, f.sampler, θ), f.sampler, f.context))) end +# LogDensityProblems interface +LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector) = f(θ) +LogDensityProblems.dimension(f::LogDensityFunction) = length(f.varinfo[f.sampler]) +function LogDensityProblems.capabilities(::Type{<:LogDensityFunction}) + return LogDensityProblems.LogDensityOrder{0}() +end + # Standard tag: Improves stacktraces # Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/ struct TuringTag end # Allow Turing tag in gradient etc. calls of the log density function ForwardDiff.checktag(::Type{ForwardDiff.Tag{TuringTag, V}}, ::LogDensityFunction, ::AbstractArray{V}) where {V} = true +ForwardDiff.checktag(::Type{ForwardDiff.Tag{TuringTag, V}}, ::Base.Fix1{typeof(LogDensityProblems.logdensity),<:LogDensityFunction}, ::AbstractArray{V}) where {V} = true # Random probability measures. include("stdlib/distributions.jl") diff --git a/src/contrib/inference/dynamichmc.jl b/src/contrib/inference/dynamichmc.jl index 6a9207217..fd7c5a6e3 100644 --- a/src/contrib/inference/dynamichmc.jl +++ b/src/contrib/inference/dynamichmc.jl @@ -19,25 +19,6 @@ DynamicNUTS{AD}(space::Symbol...) where AD = DynamicNUTS{AD, space}() DynamicPPL.getspace(::DynamicNUTS{<:Any, space}) where {space} = space -# Only define traits for `DynamicNUTS` sampler to avoid type piracy and surprises -# TODO: Implement generally with `LogDensityProblems` -const DynamicHMCLogDensity{M<:Model,S<:Sampler{<:DynamicNUTS},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,S,DynamicPPL.DefaultContext} - -function DynamicHMC.dimension(ℓ::DynamicHMCLogDensity) - return length(ℓ.varinfo[ℓ.sampler]) -end - -function DynamicHMC.capabilities(::Type{<:DynamicHMCLogDensity}) - return DynamicHMC.LogDensityOrder{1}() -end - -function DynamicHMC.logdensity_and_gradient( - ℓ::DynamicHMCLogDensity, - x::AbstractVector, -) - return gradient_logp(x, ℓ.varinfo, ℓ.model, ℓ.sampler, ℓ.context) -end - """ DynamicNUTSState @@ -46,9 +27,10 @@ State of the [`DynamicNUTS`](@ref) sampler. # Fields $(TYPEDFIELDS) """ -struct DynamicNUTSState{V<:AbstractVarInfo,C,M,S} +struct DynamicNUTSState{L,V<:AbstractVarInfo,C,M,S} + logdensity::L vi::V - "Cache of sample, log density, and gradient of log density." + "Cache of sample, log density, and gradient of log density evaluation." cache::C metric::M stepsize::S @@ -61,10 +43,10 @@ function gibbs_state( state::DynamicNUTSState, varinfo::AbstractVarInfo, ) - # Update the previous evaluation. - ℓ = Turing.LogDensityFunction(varinfo, model, spl, DynamicPPL.DefaultContext()) + # Update the log density function and its cached evaluation. + ℓ = LogDensityProblems.ADgradient(Turing.LogDensityFunction(varinfo, model, spl, DynamicPPL.DefaultContext())) Q = DynamicHMC.evaluate_ℓ(ℓ, varinfo[spl]) - return DynamicNUTSState(varinfo, Q, state.metric, state.stepsize) + return DynamicNUTSState(ℓ, varinfo, Q, state.metric, state.stepsize) end DynamicPPL.initialsampler(::Sampler{<:DynamicNUTS}) = SampleFromUniform() @@ -82,10 +64,13 @@ function DynamicPPL.initialstep( model(rng, vi, spl) end + # Define log-density function. + ℓ = LogDensityProblems.ADgradient(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())) + # Perform initial step. results = DynamicHMC.mcmc_keep_warmup( rng, - Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()), + ℓ, 0; initialization = (q = vi[spl],), reporter = DynamicHMC.NoProgressReport(), @@ -99,7 +84,7 @@ function DynamicPPL.initialstep( # Create first sample and state. sample = Transition(vi) - state = DynamicNUTSState(vi, Q, steps.H.κ, steps.ϵ) + state = DynamicNUTSState(ℓ, vi, Q, steps.H.κ, steps.ϵ) return sample, state end @@ -113,7 +98,7 @@ function AbstractMCMC.step( ) # Compute next sample. vi = state.vi - ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()) + ℓ = state.logdensity steps = DynamicHMC.mcmc_steps( rng, DynamicHMC.NUTS(), @@ -129,7 +114,7 @@ function AbstractMCMC.step( # Create next sample and state. sample = Transition(vi) - newstate = DynamicNUTSState(vi, Q, state.metric, state.stepsize) + newstate = DynamicNUTSState(ℓ, vi, Q, state.metric, state.stepsize) return sample, newstate end diff --git a/src/contrib/inference/sghmc.jl b/src/contrib/inference/sghmc.jl index d026b3656..0cb42c8a4 100644 --- a/src/contrib/inference/sghmc.jl +++ b/src/contrib/inference/sghmc.jl @@ -41,7 +41,8 @@ function SGHMC{AD}( return SGHMC{AD,space,typeof(_learning_rate)}(_learning_rate, _momentum_decay) end -struct SGHMCState{V<:AbstractVarInfo, T<:AbstractVector{<:Real}} +struct SGHMCState{L,V<:AbstractVarInfo, T<:AbstractVector{<:Real}} + logdensity::L vi::V velocity::T end @@ -61,7 +62,8 @@ function DynamicPPL.initialstep( # Compute initial sample and state. sample = Transition(vi) - state = SGHMCState(vi, zero(vi[spl])) + ℓ = LogDensityProblems.ADgradient(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())) + state = SGHMCState(ℓ, vi, zero(vi[spl])) return sample, state end @@ -74,9 +76,10 @@ function AbstractMCMC.step( kwargs... ) # Compute gradient of log density. + ℓ = state.logdensity vi = state.vi θ = vi[spl] - _, grad = gradient_logp(θ, vi, model, spl) + grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ)) # Update latent variables and velocity according to # equation (15) of Chen et al. (2014) @@ -92,7 +95,7 @@ function AbstractMCMC.step( # Compute next sample and state. sample = Transition(vi) - newstate = SGHMCState(vi, newv) + newstate = SGHMCState(ℓ, vi, newv) return sample, newstate end @@ -191,7 +194,8 @@ metadata(t::SGLDTransition) = (lp = t.lp, SGLD_stepsize = t.stepsize) DynamicPPL.getlogp(t::SGLDTransition) = t.lp -struct SGLDState{V<:AbstractVarInfo} +struct SGLDState{L,V<:AbstractVarInfo} + logdensity::L vi::V step::Int end @@ -211,7 +215,8 @@ function DynamicPPL.initialstep( # Create first sample and state. sample = SGLDTransition(vi, zero(spl.alg.stepsize(0))) - state = SGLDState(vi, 1) + ℓ = LogDensityProblems.ADgradient(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())) + state = SGLDState(ℓ, vi, 1) return sample, state end @@ -224,9 +229,10 @@ function AbstractMCMC.step( kwargs... ) # Perform gradient step. + ℓ = state.logdensity vi = state.vi θ = vi[spl] - _, grad = gradient_logp(θ, vi, model, spl) + grad = last(LogDensityProblems.logdensity_and_gradient(ℓ, θ)) step = state.step stepsize = spl.alg.stepsize(step) θ .+= (stepsize / 2) .* grad .+ sqrt(stepsize) .* randn(rng, eltype(θ), length(θ)) @@ -237,7 +243,7 @@ function AbstractMCMC.step( # Compute next sample and state. sample = SGLDTransition(vi, stepsize) - newstate = SGLDState(vi, state.step + 1) + newstate = SGLDState(ℓ, vi, state.step + 1) return sample, newstate end diff --git a/src/essential/Essential.jl b/src/essential/Essential.jl index d18a6e7ab..df0a9b5ac 100644 --- a/src/essential/Essential.jl +++ b/src/essential/Essential.jl @@ -13,22 +13,13 @@ import Bijectors: link, invlink using AdvancedVI using StatsFuns: logsumexp, softmax @reexport using DynamicPPL -using Requires import AdvancedPS -import DiffResults -import ZygoteRules +import LogDensityProblems include("container.jl") include("ad.jl") -function __init__() - @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin - include("compat/reversediff.jl") - export ReverseDiffAD, getrdcache, setrdcache, emptyrdcache - end -end - export @model, @varname, generate_observe, @@ -53,11 +44,13 @@ export @model, ForwardDiffAD, TrackerAD, ZygoteAD, + ReverseDiffAD, value, - gradient_logp, CHUNKSIZE, ADBACKEND, setchunksize, + setrdcache, + getrdcache, verifygrad, @logprob_str, @prob_str diff --git a/src/essential/ad.jl b/src/essential/ad.jl index 0e4005395..b56ce0140 100644 --- a/src/essential/ad.jl +++ b/src/essential/ad.jl @@ -18,6 +18,9 @@ end function _setadbackend(::Val{:zygote}) ADBACKEND[] = :zygote end +function _setadbackend(::Val{:reversediff}) + ADBACKEND[] = :reversediff +end const ADSAFE = Ref(false) function setadsafe(switch::Bool) @@ -39,9 +42,7 @@ struct ForwardDiffAD{chunk,standardtag} <: ADBackend end # Use standard tag if not specified otherwise ForwardDiffAD{N}() where {N} = ForwardDiffAD{N,true}() -getchunksize(::Type{<:ForwardDiffAD{chunk}}) where chunk = chunk -getchunksize(::Type{<:Sampler{Talg}}) where Talg = getchunksize(Talg) -getchunksize(::Type{SampleFromPrior}) = CHUNKSIZE[] +getchunksize(::ForwardDiffAD{chunk}) where chunk = chunk standardtag(::ForwardDiffAD{<:Any,true}) = true standardtag(::ForwardDiffAD) = false @@ -49,12 +50,24 @@ standardtag(::ForwardDiffAD) = false struct TrackerAD <: ADBackend end struct ZygoteAD <: ADBackend end +struct ReverseDiffAD{cache} <: ADBackend end + +const RDCache = Ref(false) + +setrdcache(b::Bool) = setrdcache(Val(b)) +setrdcache(::Val{false}) = RDCache[] = false +setrdcache(::Val{true}) = RDCache[] = true + +getrdcache() = RDCache[] + ADBackend() = ADBackend(ADBACKEND[]) ADBackend(T::Symbol) = ADBackend(Val(T)) ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]} ADBackend(::Val{:tracker}) = TrackerAD ADBackend(::Val{:zygote}) = ZygoteAD +ADBackend(::Val{:reversediff}) = ReverseDiffAD{getrdcache()} + ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.") """ @@ -63,54 +76,15 @@ ADBackend(::Val) = error("The requested AD backend is not available. Make sure t Find the autodifferentiation backend of the algorithm `alg`. """ getADbackend(spl::Sampler) = getADbackend(spl.alg) -getADbackend(spl::SampleFromPrior) = ADBackend()() +getADbackend(::SampleFromPrior) = ADBackend()() -""" - gradient_logp( - θ::AbstractVector{<:Real}, - vi::VarInfo, - model::Model, - sampler::AbstractSampler, - ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext() - ) - -Computes the value of the log joint of `θ` and its gradient for the model -specified by `(vi, sampler, model)` using whichever automatic differentation -tool is currently active. -""" -function gradient_logp( - θ::AbstractVector{<:Real}, - vi::VarInfo, - model::Model, - sampler::AbstractSampler, - ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext() -) - return gradient_logp(getADbackend(sampler), θ, vi, model, sampler, ctx) +function LogDensityProblems.ADgradient(ℓ::Turing.LogDensityFunction) + return LogDensityProblems.ADgradient(getADbackend(ℓ.sampler), ℓ) end -""" -gradient_logp( - backend::ADBackend, - θ::AbstractVector{<:Real}, - vi::VarInfo, - model::Model, - sampler::AbstractSampler = SampleFromPrior(), - ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext() -) - -Compute the value of the log joint of `θ` and its gradient for the model -specified by `(vi, sampler, model)` using `backend` for AD, e.g. `ForwardDiffAD{N}()` uses `ForwardDiff.jl` with chunk size `N`, `TrackerAD()` uses `Tracker.jl` and `ZygoteAD()` uses `Zygote.jl`. -""" -function gradient_logp( - ad::ForwardDiffAD, - θ::AbstractVector{<:Real}, - vi::VarInfo, - model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext() -) - # Define log density function. - f = Turing.LogDensityFunction(vi, model, sampler, context) +function LogDensityProblems.ADgradient(ad::ForwardDiffAD, ℓ::Turing.LogDensityFunction) + θ = ℓ.varinfo[ℓ.sampler] + f = Base.Fix1(LogDensityProblems.logdensity, ℓ) # Define configuration for ForwardDiff. tag = if standardtag(ad) @@ -118,58 +92,30 @@ function gradient_logp( else ForwardDiff.Tag(f, eltype(θ)) end - chunk_size = getchunksize(typeof(ad)) + chunk_size = getchunksize(ad) config = if chunk_size == 0 ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(θ), tag) else ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size), tag) end - # Obtain both value and gradient of the log density function. - out = DiffResults.GradientResult(θ) - ForwardDiff.gradient!(out, f, θ, config) - logp = DiffResults.value(out) - ∂logp∂θ = DiffResults.gradient(out) + return LogDensityProblems.ADgradient(Val(:ForwardDiff), ℓ; gradientconfig=config) +end - return logp, ∂logp∂θ +function LogDensityProblems.ADgradient(::TrackerAD, ℓ::Turing.LogDensityFunction) + return LogDensityProblems.ADgradient(Val(:Tracker), ℓ) end -function gradient_logp( - ::TrackerAD, - θ::AbstractVector{<:Real}, - vi::VarInfo, - model::Model, - sampler::AbstractSampler = SampleFromPrior(), - context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext() -) - # Define log density function. - f = Turing.LogDensityFunction(vi, model, sampler, context) - - # Compute forward pass and pullback. - l_tracked, ȳ = Tracker.forward(f, θ) - - # Remove tracking info. - l::typeof(getlogp(vi)) = Tracker.data(l_tracked) - ∂l∂θ::typeof(θ) = Tracker.data(only(ȳ(1))) - - return l, ∂l∂θ + +function LogDensityProblems.ADgradient(::ZygoteAD, ℓ::Turing.LogDensityFunction) + return LogDensityProblems.ADgradient(Val(:Zygote), ℓ) end -function gradient_logp( - backend::ZygoteAD, - θ::AbstractVector{<:Real}, - vi::VarInfo, - model::Model, - sampler::AbstractSampler = SampleFromPrior(), - context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext() -) - # Define log density function. - f = Turing.LogDensityFunction(vi, model, sampler, context) - - # Compute forward pass and pullback. - l::typeof(getlogp(vi)), ȳ = ZygoteRules.pullback(f, θ) - ∂l∂θ::typeof(θ) = only(ȳ(1)) - - return l, ∂l∂θ +for cache in (:true, :false) + @eval begin + function LogDensityProblems.ADgradient(::ReverseDiffAD{$cache}, ℓ::Turing.LogDensityFunction) + return LogDensityProblems.ADgradient(Val(:ReverseDiff), ℓ; compile=Val($cache)) + end + end end function verifygrad(grad::AbstractVector{<:Real}) diff --git a/src/essential/compat/reversediff.jl b/src/essential/compat/reversediff.jl deleted file mode 100644 index cc077c5e0..000000000 --- a/src/essential/compat/reversediff.jl +++ /dev/null @@ -1,80 +0,0 @@ -using .ReverseDiff: compile, GradientTape - -struct ReverseDiffAD{cache} <: ADBackend end -const RDCache = Ref(false) -setrdcache(b::Bool) = setrdcache(Val(b)) -setrdcache(::Val{false}) = RDCache[] = false -setrdcache(::Val) = throw("Memoization.jl is not loaded. Please load it before setting the cache to true.") -function emptyrdcache end - -getrdcache() = RDCache[] -ADBackend(::Val{:reversediff}) = ReverseDiffAD{getrdcache()} -function _setadbackend(::Val{:reversediff}) - ADBACKEND[] = :reversediff -end - -function gradient_logp( - backend::ReverseDiffAD{false}, - θ::AbstractVector{<:Real}, - vi::VarInfo, - model::Model, - sampler::AbstractSampler = SampleFromPrior(), - context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext() -) - # Define log density function. - f = Turing.LogDensityFunction(vi, model, sampler, context) - - # Obtain both value and gradient of the log density function. - tp, result = taperesult(f, θ) - ReverseDiff.gradient!(result, tp, θ) - logp = DiffResults.value(result) - ∂logp∂θ = DiffResults.gradient(result) - - return logp, ∂logp∂θ -end - -tape(f, x) = GradientTape(f, x) -taperesult(f, x) = (tape(f, x), DiffResults.GradientResult(x)) - -@require Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" @eval begin - setrdcache(::Val{true}) = RDCache[] = true - function emptyrdcache() - Memoization.empty_cache!(memoized_taperesult) - return - end - - function gradient_logp( - backend::ReverseDiffAD{true}, - θ::AbstractVector{<:Real}, - vi::VarInfo, - model::Model, - sampler::AbstractSampler = SampleFromPrior(), - context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext() - ) - # Define log density function. - f = Turing.LogDensityFunction(vi, model, sampler, context) - - # Obtain both value and gradient of the log density function. - ctp, result = memoized_taperesult(f, θ) - ReverseDiff.gradient!(result, ctp, θ) - logp = DiffResults.value(result) - ∂logp∂θ = DiffResults.gradient(result) - - return logp, ∂logp∂θ - end - - # This makes sure we generate a single tape per Turing model and sampler - struct RDTapeKey{F, Tx} - f::F - x::Tx - end - function Memoization._get!(f, d::Dict, keys::Tuple{Tuple{RDTapeKey}, Any}) - key = keys[1][1] - return Memoization._get!(f, d, (key.f, typeof(key.x), size(key.x), Threads.threadid())) - end - memoized_taperesult(f, x) = memoized_taperesult(RDTapeKey(f, x)) - Memoization.@memoize Dict function memoized_taperesult(k::RDTapeKey) - return compiledtape(k.f, k.x), DiffResults.GradientResult(k.x) - end - compiledtape(f, x) = compile(GradientTape(f, x)) -end diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 8dfe8306a..2024cff2c 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -28,8 +28,9 @@ import AdvancedHMC; const AHMC = AdvancedHMC import AdvancedMH; const AMH = AdvancedMH import AdvancedPS import BangBang -import ..Essential: getchunksize, getADbackend +import ..Essential: getADbackend import EllipticalSliceSampling +import LogDensityProblems import Random import MCMCChains import StatsBase: predict @@ -76,7 +77,6 @@ abstract type Hamiltonian{AD} <: InferenceAlgorithm end abstract type StaticHamiltonian{AD} <: Hamiltonian{AD} end abstract type AdaptiveHamiltonian{AD} <: Hamiltonian{AD} end -getchunksize(::Type{<:Hamiltonian{AD}}) where AD = getchunksize(AD) getADbackend(::Hamiltonian{AD}) where AD = AD() # Algorithm for sampling from the prior diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index 34274a32e..c9583754d 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -159,8 +159,11 @@ function DynamicPPL.initialstep( # Create a Hamiltonian. metricT = getmetricT(spl.alg) metric = metricT(length(theta)) - ∂logπ∂θ = gen_∂logπ∂θ(vi, spl, model) - logπ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()) + ℓ = LogDensityProblems.ADgradient( + Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()) + ) + logπ = Base.Fix1(LogDensityProblems.logdensity, ℓ) + ∂logπ∂θ(x) = LogDensityProblems.logdensity_and_gradient(ℓ, x) hamiltonian = AHMC.Hamiltonian(metric, logπ, ∂logπ∂θ) # Compute phase point z. @@ -262,8 +265,11 @@ end function get_hamiltonian(model, spl, vi, state, n) metric = gen_metric(n, spl, state) - ℓπ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()) - ∂ℓπ∂θ = gen_∂logπ∂θ(vi, spl, model) + ℓ = LogDensityProblems.ADgradient( + Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()) + ) + ℓπ = Base.Fix1(LogDensityProblems.logdensity, ℓ) + ∂ℓπ∂θ = Base.Fix1(LogDensityProblems.logdensity_and_gradient, ℓ) return AHMC.Hamiltonian(metric, ℓπ, ∂ℓπ∂θ) end @@ -422,19 +428,6 @@ end getstepsize(sampler::Sampler{<:Hamiltonian}, state) = sampler.alg.ϵ getstepsize(sampler::Sampler{<:AdaptiveHamiltonian}, state) = AHMC.getϵ(state.adaptor) -""" - gen_∂logπ∂θ(vi, spl::Sampler, model) - -Generate a function that takes a vector of reals `θ` and compute the logpdf and -gradient at `θ` for the model specified by `(vi, spl, model)`. -""" -function gen_∂logπ∂θ(vi, spl::Sampler, model) - function ∂logπ∂θ(x) - return gradient_logp(x, vi, model, spl) - end - return ∂logπ∂θ -end - gen_metric(dim::Int, spl::Sampler{<:Hamiltonian}, state) = AHMC.UnitEuclideanMetric(dim) function gen_metric(dim::Int, spl::Sampler{<:AdaptiveHamiltonian}, state) return AHMC.renew(state.hamiltonian.metric, AHMC.getM⁻¹(state.adaptor.pc)) diff --git a/src/modes/ModeEstimation.jl b/src/modes/ModeEstimation.jl index 7efb3be7e..184da87f1 100644 --- a/src/modes/ModeEstimation.jl +++ b/src/modes/ModeEstimation.jl @@ -10,6 +10,8 @@ using DynamicPPL: Model, AbstractContext, VarInfo, VarName, _getindex, getsym, getfield, setorder!, get_and_set_val!, istrans +import LogDensityProblems + export constrained_space, MAP, MLE, @@ -107,14 +109,9 @@ end function (f::OptimLogDensity)(F, G, z) if G !== nothing # Calculate negative log joint and its gradient. - sampler = f.sampler - neglogp, ∇neglogp = Turing.gradient_logp( - z, - DynamicPPL.VarInfo(f.varinfo, sampler, z), - f.model, - sampler, - f.context, - ) + # TODO: Make OptimLogDensity already an LogDensityProblems.ADgradient? Allow to specify AD? + ℓ = LogDensityProblems.ADgradient(f) + neglogp, ∇neglogp = LogDensityProblems.logdensity_and_gradient(ℓ, z) # Save the gradient to the pre-allocated array. copyto!(G, ∇neglogp) diff --git a/test/Project.toml b/test/Project.toml index 4f7a48ad2..38818547e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -11,8 +11,8 @@ DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" @@ -42,8 +42,8 @@ DynamicHMC = "2.1.6, 3.0" DynamicPPL = "0.20" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12" +LogDensityProblems = "0.12" MCMCChains = "5" -Memoization = "0.1.4" NamedArrays = "0.9.4" Optim = "0.22, 1.0" Optimization = "3.5" diff --git a/test/essential/ad.jl b/test/essential/ad.jl index c36006c94..1620b50e4 100644 --- a/test/essential/ad.jl +++ b/test/essential/ad.jl @@ -27,11 +27,19 @@ _x = [_m, _s] grad_FWAD = sort(g(_x)) + ℓ = Turing.LogDensityFunction(vi, ad_test_f, SampleFromPrior(), DynamicPPL.DefaultContext()) x = map(x->Float64(x), vi[SampleFromPrior()]) - ∇E1 = gradient_logp(TrackerAD(), x, vi, ad_test_f)[2] + + trackerℓ = LogDensityProblems.ADgradient(TrackerAD(), ℓ) + @test trackerℓ isa LogDensityProblems.TrackerGradientLogDensity + @test trackerℓ.ℓ === ℓ + ∇E1 = LogDensityProblems.logdensity_and_gradient(trackerℓ, x)[2] @test sort(∇E1) ≈ grad_FWAD atol=1e-9 - ∇E2 = gradient_logp(ZygoteAD(), x, vi, ad_test_f)[2] + zygoteℓ = LogDensityProblems.ADgradient(ZygoteAD(), ℓ) + @test zygoteℓ isa LogDensityProblems.ZygoteGradientLogDensity + @test zygoteℓ.ℓ === ℓ + ∇E2 = LogDensityProblems.logdensity_and_gradient(zygoteℓ, x)[2] @test sort(∇E2) ≈ grad_FWAD atol=1e-9 end @turing_testset "general AD tests" begin @@ -71,19 +79,13 @@ Turing.setadbackend(:tracker) sample(dir(), HMC(0.01, 1), 1000); Turing.setadbackend(:zygote) - sample(dir(), HMC(0.01, 1), 1000); + sample(dir(), HMC(0.01, 1), 1000) Turing.setadbackend(:reversediff) Turing.setrdcache(false) - sample(dir(), HMC(0.01, 1), 1000); + sample(dir(), HMC(0.01, 1), 1000) Turing.setrdcache(true) - sample(dir(), HMC(0.01, 1), 1000); - caches = Memoization.find_caches(Turing.Essential.memoized_taperesult) - @test length(caches) == 1 - @test !isempty(first(values(caches))) - Turing.emptyrdcache() - caches = Memoization.find_caches(Turing.Essential.memoized_taperesult) - @test length(caches) == 1 - @test isempty(first(values(caches))) + sample(dir(), HMC(0.01, 1), 1000) + Turing.setrdcache(false) end # FIXME: For some reasons PDMatDistribution AD tests fail with ReverseDiff @testset "PDMatDistribution AD" begin @@ -163,7 +165,7 @@ @test mean(Array(chn[:sigma])) ≈ std(data) atol=0.5 end - Turing.emptyrdcache() + Turing.setrdcache(false) end @testset "chunksize" begin diff --git a/test/runtests.jl b/test/runtests.jl index ae73d2be3..ed6f19470 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,7 +7,6 @@ using DistributionsAD using FiniteDifferences using ForwardDiff using MCMCChains -using Memoization using NamedArrays using Optim using Optimization @@ -38,10 +37,12 @@ using MCMCChains: Chains using StatsFuns: binomlogpdf, logistic, logsumexp using TimerOutputs: TimerOutputs, @timeit using Turing: BinomialLogit, ForwardDiffAD, Sampler, SampleFromPrior, NUTS, TrackerAD, - Variational, ZygoteAD, getspace, gradient_logp + Variational, ZygoteAD, getspace using Turing.Essential: TuringDenseMvNormal, TuringDiagMvNormal using Turing.Variational: TruncatedADAGrad, DecayedADAGrad, AdvancedVI +import LogDensityProblems + setprogress!(false) include(pkgdir(Turing)*"/test/test_utils/AllUtils.jl") diff --git a/test/skipped/unit_test_helper.jl b/test/skipped/unit_test_helper.jl index 0a5c52789..b66002ff6 100644 --- a/test/skipped/unit_test_helper.jl +++ b/test/skipped/unit_test_helper.jl @@ -8,9 +8,13 @@ function test_grad(turing_model, grad_f; trans=Dict()) end d = length(vi.vals) @testset "Gradient using random inputs" begin + ℓ = LogDensityProblems.ADgradient( + TrackerAD(), + Turing.LogDensityFunction(vi, model_f, SampleFromPrior(), DynamicPPL.DefaultContext()), + ) for _ = 1:10000 theta = rand(d) - @test Turing.gradient_logp(TrackerAD(), theta, vi, model_f) == grad_f(theta)[2] + @test LogDensityProblems.logdensity_and_gradient(ℓ, theta) == grad_f(theta)[2] end end end diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl index 55fe375c4..7b323c749 100644 --- a/test/test_utils/ad_utils.jl +++ b/test/test_utils/ad_utils.jl @@ -90,10 +90,14 @@ function test_model_ad(model, f, syms::Vector{Symbol}) # Call ForwardDiff's AD directly. grad_FWAD = sort(ForwardDiff.gradient(f, x)) - # Compare with `gradient_logp`. + # Compare with `logdensity_and_gradient`. z = vi[SampleFromPrior()] for chunksize in (0, 1, 10), standardtag in (true, false, 0, 3) - l, ∇E = gradient_logp(ForwardDiffAD{chunksize, standardtag}(), z, vi, model) + ℓ = LogDensityProblems.ADgradient( + ForwardDiffAD{chunksize, standardtag}(), + Turing.LogDensityFunction(vi, model, SampleFromPrior(), DynamicPPL.DefaultContext()), + ) + l, ∇E = LogDensityProblems.logdensity_and_gradient(ℓ, z) # Compare result @test l ≈ logp