@@ -79,6 +79,20 @@ function AbstractMCMC.sample(
7979 initial_state= DynamicPPL. loadstate (resume_from),
8080 kwargs... ,
8181)
82+ # LDF needs to be set with SamplingContext, or else samplers cannot
83+ # overload the tilde-pipeline.
84+ if ! (ldf. context isa SamplingContext)
85+ ldf = LogDensityFunction (
86+ ldf. model, ldf. varinfo, SamplingContext (rng, spl); adtype= ldf. adtype
87+ )
88+ end
89+ # Note that, in particular, sampling can mutate the variables in the LDF's
90+ # varinfo (because it ultimately ends up calling `evaluate!!(ldf.model,
91+ # ldf.varinfo)`. Furthermore, the first call to `AbstractMCMC.step` assumes
92+ # that the parameters in the LDF are the initial parameters. So, we need to
93+ # deepcopy the LDF here to ensure that sample(rng, ldf, ...) is
94+ # reproducible.
95+ ldf = deepcopy (ldf)
8296 # TODO : Right now, only generic checks are run. We could in principle
8397 # specialise this to check for e.g. discrete variables with HMC
8498 check_model && DynamicPPL. check_model (ldf. model; error_on_failure= true )
@@ -144,6 +158,20 @@ function AbstractMCMC.sample(
144158 initial_state= DynamicPPL. loadstate (resume_from),
145159 kwargs... ,
146160)
161+ # LDF needs to be set with SamplingContext, or else samplers cannot
162+ # overload the tilde-pipeline.
163+ if ! (ldf. context isa SamplingContext)
164+ ldf = LogDensityFunction (
165+ ldf. model, ldf. varinfo, SamplingContext (rng, spl); adtype= ldf. adtype
166+ )
167+ end
168+ # Note that, in particular, sampling can mutate the variables in the LDF's
169+ # varinfo (because it ultimately ends up calling `evaluate!!(ldf.model,
170+ # ldf.varinfo)`. Furthermore, the first call to `AbstractMCMC.step` assumes
171+ # that the parameters in the LDF are the initial parameters. So, we need to
172+ # deepcopy the LDF here to ensure that sample(rng, ldf, ...) is
173+ # reproducible.
174+ ldf = deepcopy (ldf)
147175 # TODO : Right now, only generic checks are run. We could in principle
148176 # specialise this to check for e.g. discrete variables with HMC
149177 check_model && DynamicPPL. check_model (ldf. model; error_on_failure= true )
0 commit comments