From fa32c88c390efbd75e12289a80f4c7e215b6c453 Mon Sep 17 00:00:00 2001
From: Penelope Yong
Date: Thu, 13 Nov 2025 16:46:06 +0000
Subject: [PATCH 1/3] Implement `ParamsWithStats` for `FastLDF`

---
 src/DynamicPPL.jl |  2 +-
 src/chains.jl     | 57 ++++++++++++++++++++++++++++++++++++++
 src/fasteval.jl   | 70 +++++++++++++++++++++++++++++++----------------
 test/chains.jl    | 28 ++++++++++++++++++-
 4 files changed, 132 insertions(+), 25 deletions(-)

diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl
index e9b902363..6d3900e91 100644
--- a/src/DynamicPPL.jl
+++ b/src/DynamicPPL.jl
@@ -202,6 +202,7 @@ include("logdensityfunction.jl")
 include("model_utils.jl")
 include("extract_priors.jl")
 include("values_as_in_model.jl")
+include("experimental.jl")
 include("chains.jl")
 include("bijector.jl")
 
@@ -209,7 +210,6 @@ include("debug_utils.jl")
 using .DebugUtils
 include("test_utils.jl")
 
-include("experimental.jl")
 include("deprecated.jl")
 
 if isdefined(Base.Experimental, :register_error_hint)
diff --git a/src/chains.jl b/src/chains.jl
index 2b5976b9b..892423822 100644
--- a/src/chains.jl
+++ b/src/chains.jl
@@ -133,3 +133,60 @@ function ParamsWithStats(
     end
     return ParamsWithStats(params, stats)
 end
+
+"""
+    ParamsWithStats(
+        param_vector::AbstractVector,
+        ldf::DynamicPPL.Experimental.FastLDF,
+        stats::NamedTuple=NamedTuple();
+        include_colon_eq::Bool=true,
+        include_log_probs::Bool=true,
+    )
+
+Generate a `ParamsWithStats` by re-evaluating the given `ldf` with the provided
+`param_vector`.
+
+This method is intended to replace the old method of obtaining parameters and statistics
+via `unflatten` plus re-evaluation. It is faster for two reasons:
+
+1. It does not rely on `deepcopy`-ing the VarInfo object (this used to be mandatory as
+   otherwise re-evaluation would mutate the VarInfo, rendering it unusable for subsequent
+   MCMC iterations).
+2. The re-evaluation is faster as it uses `OnlyAccsVarInfo`.
+"""
+function ParamsWithStats(
+    param_vector::AbstractVector,
+    ldf::DynamicPPL.Experimental.FastLDF,
+    stats::NamedTuple=NamedTuple();
+    include_colon_eq::Bool=true,
+    include_log_probs::Bool=true,
+)
+    strategy = InitFromParams(
+        VectorWithRanges(ldf._iden_varname_ranges, ldf._varname_ranges, param_vector),
+        nothing,
+    )
+    accs = if include_log_probs
+        (
+            DynamicPPL.LogPriorAccumulator(),
+            DynamicPPL.LogLikelihoodAccumulator(),
+            DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),
+        )
+    else
+        (DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),)
+    end
+    _, vi = DynamicPPL.Experimental.fast_evaluate!!(
+        ldf.model, strategy, AccumulatorTuple(accs)
+    )
+    params = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
+    if include_log_probs
+        stats = merge(
+            stats,
+            (
+                logprior=DynamicPPL.getlogprior(vi),
+                loglikelihood=DynamicPPL.getloglikelihood(vi),
+                lp=DynamicPPL.getlogjoint(vi),
+            ),
+        )
+    end
+    return ParamsWithStats(params, stats)
+end
diff --git a/src/fasteval.jl b/src/fasteval.jl
index 4f402f4a8..8d165a4ef 100644
--- a/src/fasteval.jl
+++ b/src/fasteval.jl
@@ -3,6 +3,7 @@ using DynamicPPL:
     AccumulatorTuple,
     InitContext,
     InitFromParams,
+    AbstractInitStrategy,
     LogJacobianAccumulator,
     LogLikelihoodAccumulator,
     LogPriorAccumulator,
@@ -28,6 +29,49 @@ using LogDensityProblems: LogDensityProblems
 import DifferentiationInterface as DI
 using Random: Random
 
+"""
+    DynamicPPL.Experimental.fast_evaluate!!(
+        [rng::Random.AbstractRNG,]
+        model::Model,
+        strategy::AbstractInitStrategy,
+        accs::AccumulatorTuple,
+    )
+
+Evaluate a model using parameters obtained via `strategy`, and only computing the results in
+the provided accumulators.
+"""
+@inline function fast_evaluate!!(
+    rng::Random.AbstractRNG,
+    model::Model,
+    strategy::AbstractInitStrategy,
+    accs::AccumulatorTuple,
+)
+    ctx = InitContext(rng, strategy)
+    model = DynamicPPL.setleafcontext(model, ctx)
+    # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
+    # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
+    # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
+    # here.
+    # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
+    # it _should_ do, but this is wrong regardless.
+    # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
+    vi = if Threads.nthreads() > 1
+        param_eltype = DynamicPPL.get_param_eltype(strategy)
+        accs = map(accs) do acc
+            DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc)
+        end
+        ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
+    else
+        OnlyAccsVarInfo(accs)
+    end
+    return DynamicPPL._evaluate!!(model, vi)
+end
+@inline function fast_evaluate!!(
+    model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple
+)
+    return fast_evaluate!!(Random.default_rng(), model, strategy, accs)
+end
+
 """
     FastLDF(
         model::Model,
@@ -213,31 +257,11 @@ struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
     varname_ranges::Dict{VarName,RangeAndLinked}
 end
 function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
-    ctx = InitContext(
-        Random.default_rng(),
-        InitFromParams(
-            VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
-        ),
+    strategy = InitFromParams(
+        VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
     )
-    model = DynamicPPL.setleafcontext(f.model, ctx)
     accs = fast_ldf_accs(f.getlogdensity)
-    # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
-    # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
-    # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
-    # here.
-    # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
-    # it _should_ do, but this is wrong regardless.
-    # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
-    vi = if Threads.nthreads() > 1
-        accs = map(
-            acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc),
-            accs,
-        )
-        ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
-    else
-        OnlyAccsVarInfo(accs)
-    end
-    _, vi = DynamicPPL._evaluate!!(model, vi)
+    _, vi = fast_evaluate!!(f.model, strategy, accs)
     return f.getlogdensity(vi)
 end
 
diff --git a/test/chains.jl b/test/chains.jl
index ab0ff4475..43b877d62 100644
--- a/test/chains.jl
+++ b/test/chains.jl
@@ -4,7 +4,7 @@ using DynamicPPL
 using Distributions
 using Test
 
-@testset "ParamsWithStats" begin
+@testset "ParamsWithStats from VarInfo" begin
     @model function f(z)
         x ~ Normal()
         y := x + 1
@@ -66,4 +66,30 @@ using Test
     end
 end
 
+@testset "ParamsWithStats from FastLDF" begin
+    @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
+        unlinked_vi = VarInfo(m)
+        @testset "$islinked" for islinked in (false, true)
+            vi = if islinked
+                DynamicPPL.link!!(unlinked_vi, m)
+            else
+                unlinked_vi
+            end
+            params = [x for x in vi[:]]
+
+            # Get the ParamsWithStats using FastLDF
+            fldf = DynamicPPL.Experimental.FastLDF(m, getlogjoint, vi)
+            ps = ParamsWithStats(params, fldf)
+
+            # Check that length of parameters is as expected
+            @test length(ps.params) == length(keys(vi))
+
+            # Iterate over all variables to check that their values match
+            for vn in keys(vi)
+                @test ps.params[vn] == vi[vn]
+            end
+        end
+    end
+end
+
 end # module

From 2bf5b185e9313b9dae5f64ba2adc2847470be672 Mon Sep 17 00:00:00 2001
From: Penelope Yong
Date: Thu, 13 Nov 2025 16:52:13 +0000
Subject: [PATCH 2/3] Add comments

---
 src/fasteval.jl | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/src/fasteval.jl b/src/fasteval.jl
index 8d165a4ef..722760fa1 100644
--- a/src/fasteval.jl
+++ b/src/fasteval.jl
@@ -39,8 +39,18 @@ using Random: Random
 
 Evaluate a model using parameters obtained via `strategy`, and only computing the results in
 the provided accumulators.
+
+It is assumed that the accumulators passed in have been initialised to appropriate values,
+as this function will not reset them. The default constructors for each accumulator will do
+this for you correctly.
+
+Returns a tuple of the model's return value, plus an `OnlyAccsVarInfo`. Note that the `accs`
+argument may be mutated (depending on how the accumulators are implemented); hence the `!!`
+in the function name.
 """
 @inline function fast_evaluate!!(
+    # Note that this `@inline` is mandatory for performance. If it's not inlined, it leads
+    # to extra allocations (even for trivial models) and much slower runtime.
     rng::Random.AbstractRNG,
     model::Model,
     strategy::AbstractInitStrategy,
@@ -69,6 +79,7 @@ end
 @inline function fast_evaluate!!(
     model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple
 )
+    # This `@inline` is also mandatory for performance
     return fast_evaluate!!(Random.default_rng(), model, strategy, accs)
 end
 

From 2fad97ba289bef0414cca9b80068d4bdf243fba2 Mon Sep 17 00:00:00 2001
From: Penelope Yong
Date: Thu, 13 Nov 2025 18:42:28 +0000
Subject: [PATCH 3/3] Implement `bundle_samples` for ParamsWithStats -> MCMCChains

---
 Project.toml                   |  2 +-
 ext/DynamicPPLMCMCChainsExt.jl | 39 ++++++++++++++++++++++++++++++++++
 2 files changed, 40 insertions(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 1b5e52492..3e99afd62 100644
--- a/Project.toml
+++ b/Project.toml
@@ -41,7 +41,7 @@ DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
 DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
 DynamicPPLForwardDiffExt = ["ForwardDiff"]
 DynamicPPLJETExt = ["JET"]
-DynamicPPLMCMCChainsExt = ["MCMCChains"]
+DynamicPPLMCMCChainsExt = ["MCMCChains", "Statistics"]
 DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"]
 DynamicPPLMooncakeExt = ["Mooncake"]
 
diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl
index d8c343917..a927e09d5 100644
--- a/ext/DynamicPPLMCMCChainsExt.jl
+++ b/ext/DynamicPPLMCMCChainsExt.jl
@@ -2,6 +2,7 @@ module DynamicPPLMCMCChainsExt
 
 using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC
 using MCMCChains: MCMCChains
+using Statistics: mean
 
 _has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names
 
@@ -140,6 +141,44 @@ function AbstractMCMC.to_samples(
     end
 end
 
+function AbstractMCMC.bundle_samples(
+    ts::Vector{<:DynamicPPL.ParamsWithStats},
+    model::DynamicPPL.Model,
+    spl::AbstractMCMC.AbstractSampler,
+    state,
+    chain_type::Type{MCMCChains.Chains};
+    save_state=false,
+    stats=missing,
+    sort_chain=false,
+    discard_initial=0,
+    thinning=1,
+    kwargs...,
+)
+    # Construct the 'bare' chain first
+    bare_chain = AbstractMCMC.from_samples(MCMCChains.Chains, reshape(ts, :, 1))
+
+    # Add additional MCMC-specific info
+    info = bare_chain.info
+    if save_state
+        info = merge(info, (model=model, sampler=spl, samplerstate=state))
+    end
+    if !ismissing(stats)
+        info = merge(info, (start_time=stats.start, stop_time=stats.stop))
+    end
+
+    # Reconstruct the chain with the extra information
+    # Yeah, this is quite ugly. Blame MCMCChains.
+    chain = MCMCChains.Chains(
+        bare_chain.value.data,
+        names(bare_chain),
+        bare_chain.name_map;
+        info=info,
+        start=discard_initial + 1,
+        thin=thinning,
+    )
+    return sort_chain ? sort(chain) : chain
+end
+
 """
     predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)
 