@@ -75,6 +75,20 @@ function AbstractMCMC.sample(
7575 initial_state= DynamicPPL. loadstate (resume_from),
7676 kwargs... ,
7777)
78+ # LDF needs to be set with SamplingContext, or else samplers cannot
79+ # overload the tilde-pipeline.
80+ if ! (ldf. context isa SamplingContext)
81+ ldf = LogDensityFunction (
82+ ldf. model, ldf. varinfo, SamplingContext (rng, spl); adtype= ldf. adtype
83+ )
84+ end
85+ # Note that, in particular, sampling can mutate the variables in the LDF's
86+ # varinfo (because it ultimately ends up calling `evaluate!!(ldf.model,
87+ # ldf.varinfo)`. Furthermore, the first call to `AbstractMCMC.step` assumes
88+ # that the parameters in the LDF are the initial parameters. So, we need to
89+ # deepcopy the LDF here to ensure that sample(rng, ldf, ...) is
90+ # reproducible.
91+ ldf = deepcopy (ldf)
7892 # TODO : Right now, only generic checks are run. We could in principle
7993 # specialise this to check for e.g. discrete variables with HMC
8094 check_model && DynamicPPL. check_model (ldf. model; error_on_failure= true )
@@ -140,6 +154,20 @@ function AbstractMCMC.sample(
140154 initial_state= DynamicPPL. loadstate (resume_from),
141155 kwargs... ,
142156)
157+ # LDF needs to be set with SamplingContext, or else samplers cannot
158+ # overload the tilde-pipeline.
159+ if ! (ldf. context isa SamplingContext)
160+ ldf = LogDensityFunction (
161+ ldf. model, ldf. varinfo, SamplingContext (rng, spl); adtype= ldf. adtype
162+ )
163+ end
164+ # Note that, in particular, sampling can mutate the variables in the LDF's
165+ # varinfo (because it ultimately ends up calling `evaluate!!(ldf.model,
166+ # ldf.varinfo)`. Furthermore, the first call to `AbstractMCMC.step` assumes
167+ # that the parameters in the LDF are the initial parameters. So, we need to
168+ # deepcopy the LDF here to ensure that sample(rng, ldf, ...) is
169+ # reproducible.
170+ ldf = deepcopy (ldf)
143171 # TODO : Right now, only generic checks are run. We could in principle
144172 # specialise this to check for e.g. discrete variables with HMC
145173 check_model && DynamicPPL. check_model (ldf. model; error_on_failure= true )
0 commit comments