@@ -135,23 +135,6 @@ function SMC(threshold::Real)
135135 return SMC (AdvancedPS. resample_systematic, threshold)
136136end
137137
138- struct SMCTransition{T,F<: AbstractFloat } <: AbstractTransition
139- " The parameters for any given sample."
140- θ:: T
141- " The joint log probability of the sample (NOTE: does not work, always set to zero)."
142- lp:: F
143- " The weight of the particle the sample was retrieved from."
144- weight:: F
145- end
146-
147- function SMCTransition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , weight)
148- theta = getparams (model, vi)
149- lp = DynamicPPL. getlogjoint_internal (vi)
150- return SMCTransition (theta, lp, weight)
151- end
152-
153- getstats_with_lp (t:: SMCTransition ) = (lp= t. lp, weight= t. weight)
154-
155138struct SMCState{P,F<: AbstractFloat }
156139 particles:: P
157140 particleindex:: Int
@@ -228,7 +211,8 @@ function DynamicPPL.initialstep(
228211 weight = AdvancedPS. getweight (particles, 1 )
229212
230213 # Compute the first transition and the first state.
231- transition = SMCTransition (model, particle. model. f. varinfo, weight)
214+ stats = (; weight= weight, logevidence= logevidence)
215+ transition = Transition (model, particle. model. f. varinfo, stats)
232216 state = SMCState (particles, 2 , logevidence)
233217
234218 return transition, state
@@ -246,7 +230,8 @@ function AbstractMCMC.step(
246230 weight = AdvancedPS. getweight (particles, index)
247231
248232 # Compute the transition and the next state.
249- transition = SMCTransition (model, particle. model. f. varinfo, weight)
233+ stats = (; weight= weight, logevidence= state. average_logevidence)
234+ transition = Transition (model, particle. model. f. varinfo, stats)
250235 nextstate = SMCState (state. particles, index + 1 , state. average_logevidence)
251236
252237 return transition, nextstate
@@ -300,32 +285,28 @@ Equivalent to [`PG`](@ref).
300285"""
301286const CSMC = PG # type alias of PG as Conditional SMC
302287
303- struct PGTransition{T,F<: AbstractFloat } <: AbstractTransition
304- " The parameters for any given sample."
305- θ:: T
306- " The joint log probability of the sample (NOTE: does not work, always set to zero)."
307- lp:: F
308- " The log evidence of the sample."
309- logevidence:: F
310- end
311-
312288struct PGState
313289 vi:: AbstractVarInfo
314290 rng:: Random.AbstractRNG
315291end
316292
317293get_varinfo (state:: PGState ) = state. vi
318294
319- function PGTransition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , logevidence)
320- theta = getparams (model, vi)
321- lp = DynamicPPL. getlogjoint_internal (vi)
322- return PGTransition (theta, lp, logevidence)
323- end
324-
325- getstats_with_lp (t:: PGTransition ) = (lp= t. lp, logevidence= t. logevidence)
326-
327- function getlogevidence (samples, sampler:: Sampler{<:PG} , state:: PGState )
328- return mean (x. logevidence for x in samples)
295+ function getlogevidence (
296+ transitions:: AbstractVector{<:Turing.Inference.Transition} ,
297+ sampler:: Sampler{<:PG} ,
298+ state:: PGState ,
299+ )
300+ logevidences = map (transitions) do t
301+ if haskey (t. stat, :logevidence )
302+ return t. stat. logevidence
303+ else
304+ # This should not really happen, but if it does we can handle it
305+ # gracefully
306+ return missing
307+ end
308+ end
309+ return mean (logevidences)
329310end
330311
331312function DynamicPPL. initialstep (
@@ -357,7 +338,7 @@ function DynamicPPL.initialstep(
357338
358339 # Compute the first transition.
359340 _vi = reference. model. f. varinfo
360- transition = PGTransition (model, _vi, logevidence)
341+ transition = Transition (model, _vi, (; logevidence= logevidence) )
361342
362343 return transition, PGState (_vi, reference. rng)
363344end
@@ -397,7 +378,7 @@ function AbstractMCMC.step(
397378
398379 # Compute the transition.
399380 _vi = newreference. model. f. varinfo
400- transition = PGTransition (model, _vi, logevidence)
381+ transition = Transition (model, _vi, (; logevidence= logevidence) )
401382
402383 return transition, PGState (_vi, newreference. rng)
403384end
0 commit comments