Skip to content

Commit 92cbb20

Browse files
committed
Implement ParamsWithStats for FastLDF
1 parent 535ce4f commit 92cbb20

File tree

4 files changed

+132
-25
lines changed

4 files changed

+132
-25
lines changed

src/DynamicPPL.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,14 +202,14 @@ include("logdensityfunction.jl")
202202
include("model_utils.jl")
203203
include("extract_priors.jl")
204204
include("values_as_in_model.jl")
205+
include("experimental.jl")
205206
include("chains.jl")
206207
include("bijector.jl")
207208

208209
include("debug_utils.jl")
209210
using .DebugUtils
210211
include("test_utils.jl")
211212

212-
include("experimental.jl")
213213
include("deprecated.jl")
214214

215215
if isdefined(Base.Experimental, :register_error_hint)

src/chains.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,60 @@ function ParamsWithStats(
133133
end
134134
return ParamsWithStats(params, stats)
135135
end
136+
137+
"""
138+
ParamsWithStats(
139+
param_vector::AbstractVector,
140+
ldf::DynamicPPL.Experimental.FastLDF,
141+
stats::NamedTuple=NamedTuple();
142+
include_colon_eq::Bool=true,
143+
include_log_probs::Bool=true,
144+
)
145+
146+
Generate a `ParamsWithStats` by re-evaluating the given `ldf` with the provided
147+
`param_vector`.
148+
149+
This method is intended to replace the old method of obtaining parameters and statistics
150+
via `unflatten` plus re-evaluation. It is faster for two reasons:
151+
152+
1. It does not rely on `deepcopy`-ing the VarInfo object (this used to be mandatory as
153+
otherwise re-evaluation would mutate the VarInfo, rendering it unusable for subsequent
154+
MCMC iterations).
155+
2. The re-evaluation is faster as it uses `OnlyAccsVarInfo`.
156+
"""
157+
function ParamsWithStats(
158+
param_vector::AbstractVector,
159+
ldf::DynamicPPL.Experimental.FastLDF,
160+
stats::NamedTuple=NamedTuple();
161+
include_colon_eq::Bool=true,
162+
include_log_probs::Bool=true,
163+
)
164+
strategy = InitFromParams(
165+
VectorWithRanges(ldf._iden_varname_ranges, ldf._varname_ranges, param_vector),
166+
nothing,
167+
)
168+
accs = if include_log_probs
169+
(
170+
DynamicPPL.LogPriorAccumulator(),
171+
DynamicPPL.LogLikelihoodAccumulator(),
172+
DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),
173+
)
174+
else
175+
(DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),)
176+
end
177+
_, vi = DynamicPPL.Experimental.fast_evaluate!!(
178+
ldf.model, strategy, AccumulatorTuple(accs)
179+
)
180+
params = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
181+
if include_log_probs
182+
stats = merge(
183+
stats,
184+
(
185+
logprior=DynamicPPL.getlogprior(vi),
186+
loglikelihood=DynamicPPL.getloglikelihood(vi),
187+
lp=DynamicPPL.getlogjoint(vi),
188+
),
189+
)
190+
end
191+
return ParamsWithStats(params, stats)
192+
end

src/fasteval.jl

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using DynamicPPL:
33
AccumulatorTuple,
44
InitContext,
55
InitFromParams,
6+
AbstractInitStrategy,
67
LogJacobianAccumulator,
78
LogLikelihoodAccumulator,
89
LogPriorAccumulator,
@@ -28,6 +29,49 @@ using LogDensityProblems: LogDensityProblems
2829
import DifferentiationInterface as DI
2930
using Random: Random
3031

32+
"""
33+
DynamicPPL.Experimental.fast_evaluate!!(
34+
[rng::Random.AbstractRNG,]
35+
model::Model,
36+
strategy::AbstractInitStrategy,
37+
accs::AccumulatorTuple, params::AbstractVector{<:Real}
38+
)
39+
40+
Evaluate a model using parameters obtained via `strategy`, and only computing the results in
41+
the provided accumulators.
42+
"""
43+
@inline function fast_evaluate!!(
44+
rng::Random.AbstractRNG,
45+
model::Model,
46+
strategy::AbstractInitStrategy,
47+
accs::AccumulatorTuple,
48+
)
49+
ctx = InitContext(rng, strategy)
50+
model = DynamicPPL.setleafcontext(model, ctx)
51+
# Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
52+
# which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
53+
# directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
54+
# here.
55+
# TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
56+
# it _should_ do, but this is wrong regardless.
57+
# https://github.com/TuringLang/DynamicPPL.jl/issues/1086
58+
vi = if Threads.nthreads() > 1
59+
param_eltype = DynamicPPL.get_param_eltype(strategy)
60+
accs = map(accs) do acc
61+
DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc)
62+
end
63+
ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
64+
else
65+
OnlyAccsVarInfo(accs)
66+
end
67+
return DynamicPPL._evaluate!!(model, vi)
68+
end
69+
@inline function fast_evaluate!!(
70+
model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple
71+
)
72+
return fast_evaluate!!(Random.default_rng(), model, strategy, accs)
73+
end
74+
3175
"""
3276
FastLDF(
3377
model::Model,
@@ -211,31 +255,11 @@ struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
211255
varname_ranges::Dict{VarName,RangeAndLinked}
212256
end
213257
function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
214-
ctx = InitContext(
215-
Random.default_rng(),
216-
InitFromParams(
217-
VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
218-
),
258+
strategy = InitFromParams(
259+
VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
219260
)
220-
model = DynamicPPL.setleafcontext(f.model, ctx)
221261
accs = fast_ldf_accs(f.getlogdensity)
222-
# Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
223-
# which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
224-
# directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
225-
# here.
226-
# TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
227-
# it _should_ do, but this is wrong regardless.
228-
# https://github.com/TuringLang/DynamicPPL.jl/issues/1086
229-
vi = if Threads.nthreads() > 1
230-
accs = map(
231-
acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc),
232-
accs,
233-
)
234-
ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
235-
else
236-
OnlyAccsVarInfo(accs)
237-
end
238-
_, vi = DynamicPPL._evaluate!!(model, vi)
262+
_, vi = fast_evaluate!!(f.model, strategy, accs)
239263
return f.getlogdensity(vi)
240264
end
241265

test/chains.jl

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using DynamicPPL
44
using Distributions
55
using Test
66

7-
@testset "ParamsWithStats" begin
7+
@testset "ParamsWithStats from VarInfo" begin
88
@model function f(z)
99
x ~ Normal()
1010
y := x + 1
@@ -66,4 +66,30 @@ using Test
6666
end
6767
end
6868

69+
@testset "ParamsWithStats from FastLDF" begin
70+
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
71+
unlinked_vi = VarInfo(m)
72+
@testset "$islinked" for islinked in (false, true)
73+
vi = if islinked
74+
DynamicPPL.link!!(unlinked_vi, m)
75+
else
76+
unlinked_vi
77+
end
78+
params = [x for x in vi[:]]
79+
80+
# Get the ParamsWithStats using FastLDF
81+
fldf = DynamicPPL.Experimental.FastLDF(m, getlogjoint, vi)
82+
ps = ParamsWithStats(params, fldf)
83+
84+
# Check that length of parameters is as expected
85+
@test length(ps.params) == length(keys(vi))
86+
87+
# Iterate over all variables to check that their values match
88+
for vn in keys(vi)
89+
@test ps.params[vn] == vi[vn]
90+
end
91+
end
92+
end
93+
end
94+
6995
end # module

0 commit comments

Comments
 (0)