@@ -3,6 +3,7 @@ using DynamicPPL:
33 AccumulatorTuple,
44 InitContext,
55 InitFromParams,
6+ AbstractInitStrategy,
67 LogJacobianAccumulator,
78 LogLikelihoodAccumulator,
89 LogPriorAccumulator,
@@ -28,6 +29,49 @@ using LogDensityProblems: LogDensityProblems
2829import DifferentiationInterface as DI
2930using Random: Random
3031
32+ """
33+ DynamicPPL.Experimental.fast_evaluate!!(
34+ [rng::Random.AbstractRNG,]
35+ model::Model,
36+ strategy::AbstractInitStrategy,
37+ accs::AccumulatorTuple, params::AbstractVector{<:Real}
38+ )
39+
40+ Evaluate a model using parameters obtained via `strategy`, and only computing the results in
41+ the provided accumulators.
42+ """
43+ @inline function fast_evaluate!! (
44+ rng:: Random.AbstractRNG ,
45+ model:: Model ,
46+ strategy:: AbstractInitStrategy ,
47+ accs:: AccumulatorTuple ,
48+ )
49+ ctx = InitContext (rng, strategy)
50+ model = DynamicPPL. setleafcontext (model, ctx)
51+ # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
52+ # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
53+ # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
54+ # here.
55+ # TODO (penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
56+ # it _should_ do, but this is wrong regardless.
57+ # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
58+ vi = if Threads. nthreads () > 1
59+ param_eltype = DynamicPPL. get_param_eltype (strategy)
60+ accs = map (accs) do acc
61+ DynamicPPL. convert_eltype (float_type_with_fallback (param_eltype), acc)
62+ end
63+ ThreadSafeVarInfo (OnlyAccsVarInfo (accs))
64+ else
65+ OnlyAccsVarInfo (accs)
66+ end
67+ return DynamicPPL. _evaluate!! (model, vi)
68+ end
69+ @inline function fast_evaluate!! (
70+ model:: Model , strategy:: AbstractInitStrategy , accs:: AccumulatorTuple
71+ )
72+ return fast_evaluate!! (Random. default_rng (), model, strategy, accs)
73+ end
74+
3175"""
3276 FastLDF(
3377 model::Model,
@@ -211,31 +255,11 @@ struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
211255 varname_ranges:: Dict{VarName,RangeAndLinked}
212256end
213257function (f:: FastLogDensityAt )(params:: AbstractVector{<:Real} )
214- ctx = InitContext (
215- Random. default_rng (),
216- InitFromParams (
217- VectorWithRanges (f. iden_varname_ranges, f. varname_ranges, params), nothing
218- ),
258+ strategy = InitFromParams (
259+ VectorWithRanges (f. iden_varname_ranges, f. varname_ranges, params), nothing
219260 )
220- model = DynamicPPL. setleafcontext (f. model, ctx)
221261 accs = fast_ldf_accs (f. getlogdensity)
222- # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
223- # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
224- # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
225- # here.
226- # TODO (penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
227- # it _should_ do, but this is wrong regardless.
228- # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
229- vi = if Threads. nthreads () > 1
230- accs = map (
231- acc -> DynamicPPL. convert_eltype (float_type_with_fallback (eltype (params)), acc),
232- accs,
233- )
234- ThreadSafeVarInfo (OnlyAccsVarInfo (accs))
235- else
236- OnlyAccsVarInfo (accs)
237- end
238- _, vi = DynamicPPL. _evaluate!! (model, vi)
262+ _, vi = fast_evaluate!! (f. model, strategy, accs)
239263 return f. getlogdensity (vi)
240264end
241265
0 commit comments