@@ -159,7 +159,14 @@ function DynamicPPL.initialstep(
159159 metricT = getmetricT (spl. alg)
160160 metric = metricT (length (theta))
161161 ℓ = LogDensityProblemsAD. ADgradient (
162- Turing. LogDensityFunction (vi, model, spl, DynamicPPL. DefaultContext ())
162+ Turing. LogDensityFunction (
163+ vi,
164+ model,
165+ # Use the leaf-context from the `model` in case the user has
166+ # contextualized the model with something like `PriorContext`
167+ # to sample from the prior.
168+ DynamicPPL. SamplingContext (rng, spl, DynamicPPL. leafcontext (model. context))
169+ )
163170 )
164171 logπ = Base. Fix1 (LogDensityProblems. logdensity, ℓ)
165172 ∂logπ∂θ (x) = LogDensityProblems. logdensity_and_gradient (ℓ, x)
265272function get_hamiltonian (model, spl, vi, state, n)
266273 metric = gen_metric (n, spl, state)
267274 ℓ = LogDensityProblemsAD. ADgradient (
268- Turing. LogDensityFunction (vi, model, spl, DynamicPPL. DefaultContext ())
275+ Turing. LogDensityFunction (
276+ vi,
277+ model,
278+ DynamicPPL. SamplingContext (spl, DynamicPPL. leafcontext (model. context))
279+ )
269280 )
270281 ℓπ = Base. Fix1 (LogDensityProblems. logdensity, ℓ)
271282 ∂ℓπ∂θ = Base. Fix1 (LogDensityProblems. logdensity_and_gradient, ℓ)
@@ -538,7 +549,12 @@ function HMCState(
538549
539550 # Get the initial log pdf and gradient functions.
540551 ∂logπ∂θ = gen_∂logπ∂θ (vi, spl, model)
541- logπ = Turing. LogDensityFunction (vi, model, spl, DynamicPPL. DefaultContext ())
552+ logπ = Turing. LogDensityFunction (
553+ vi,
554+ model,
555+ DynamicPPL. SamplingContext (rng, spl, DynamicPPL. leafcontext (model. context))
556+ )
557+
542558
543559 # Get the metric type.
544560 metricT = getmetricT (spl. alg)
0 commit comments