|
4 | 4 |
|
5 | 5 | ### AdvancedPS models and interface |
6 | 6 |
|
| 7 | +""" |
| 8 | + set_all_del!(vi::AbstractVarInfo) |
| 9 | +
|
| 10 | +Set the "del" flag for all variables in the VarInfo `vi`, thus marking them for |
| 11 | +resampling. |
| 12 | +""" |
| 13 | +function set_all_del!(vi::AbstractVarInfo) |
| 14 | + # TODO(penelopeysm): Instead of being a 'del' flag on the VarInfo, we |
| 15 | + # could either: |
| 16 | + # - keep a boolean 'resample' flag on the trace, or |
| 17 | + # - modify the model context appropriately. |
| 18 | + # However, this refactoring will have to wait until InitContext is |
| 19 | + # merged into DPPL. |
| 20 | + for vn in keys(vi) |
| 21 | + DynamicPPL.set_flag!(vi, vn, "del") |
| 22 | + end |
| 23 | + return nothing |
| 24 | +end |
| 25 | + |
| 26 | +""" |
| 27 | + unset_all_del!(vi::AbstractVarInfo) |
| 28 | +
|
| 29 | +Unset the "del" flag for all variables in the VarInfo `vi`, thus preventing |
| 30 | +them from being resampled. |
| 31 | +""" |
| 32 | +function unset_all_del!(vi::AbstractVarInfo) |
| 33 | + for vn in keys(vi) |
| 34 | + DynamicPPL.unset_flag!(vi, vn, "del") |
| 35 | + end |
| 36 | + return nothing |
| 37 | +end |
| 38 | + |
7 | 39 | struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <: |
8 | 40 | AdvancedPS.AbstractGenericModel |
9 | 41 | model::M |
|
33 | 65 | function AdvancedPS.advance!( |
34 | 66 | trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false |
35 | 67 | ) |
36 | | - # We want to increment num produce for the VarInfo stored in the trace. The trace is |
37 | | - # mutable, so we create a new model with the incremented VarInfo and set it in the trace |
38 | | - model = trace.model |
39 | | - model = Accessors.@set model.f.varinfo = DynamicPPL.increment_num_produce!!( |
40 | | - model.f.varinfo |
41 | | - ) |
42 | | - trace.model = model |
43 | 68 | # Make sure we load/reset the rng in the new replaying mechanism |
44 | 69 | isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng) |
45 | 70 | score = consume(trace.model.ctask) |
46 | 71 | return score |
47 | 72 | end |
48 | 73 |
|
49 | 74 | function AdvancedPS.delete_retained!(trace::TracedModel) |
50 | | - DynamicPPL.set_retained_vns_del!(trace.varinfo) |
| 75 | + # This method is called if, during a CSMC update, we perform a resampling |
| 76 | + # and choose the reference particle as the trajectory to carry on from. |
| 77 | + # In such a case, we need to ensure that when we continue sampling (i.e. |
| 78 | + # the next time we hit tilde_assume), we don't use the values in the |
| 79 | + # reference particle but rather sample new values. |
| 80 | + # |
| 81 | + # Here, we indiscriminately set the 'del' flag for all variables in the |
| 82 | + # VarInfo. This is slightly overkill: it is not necessary to set the 'del' |
| 83 | + # flag for variables that were already sampled. However, it allows us to |
| 84 | + # avoid keeping track of which variables were sampled, which leads to many |
| 85 | + # simplifications in the VarInfo data structure. |
| 86 | + set_all_del!(trace.varinfo) |
51 | 87 | return trace |
52 | 88 | end |
53 | 89 |
|
54 | 90 | function AdvancedPS.reset_model(trace::TracedModel) |
55 | | - return Accessors.@set trace.varinfo = DynamicPPL.reset_num_produce!!(trace.varinfo) |
| 91 | + return trace |
56 | 92 | end |
57 | 93 |
|
58 | 94 | function Libtask.TapedTask(taped_globals, model::TracedModel; kwargs...) |
@@ -176,8 +212,7 @@ function DynamicPPL.initialstep( |
176 | 212 | ) |
177 | 213 | # Reset the VarInfo. |
178 | 214 | vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) |
179 | | - vi = DynamicPPL.reset_num_produce!!(vi) |
180 | | - DynamicPPL.set_retained_vns_del!(vi) |
| 215 | + set_all_del!(vi) |
181 | 216 | vi = DynamicPPL.resetlogp!!(vi) |
182 | 217 | vi = DynamicPPL.empty!!(vi) |
183 | 218 |
|
@@ -307,8 +342,7 @@ function DynamicPPL.initialstep( |
307 | 342 | ) |
308 | 343 | vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) |
309 | 344 | # Reset the VarInfo before new sweep |
310 | | - vi = DynamicPPL.reset_num_produce!!(vi) |
311 | | - DynamicPPL.set_retained_vns_del!(vi) |
| 345 | + set_all_del!(vi) |
312 | 346 | vi = DynamicPPL.resetlogp!!(vi) |
313 | 347 |
|
314 | 348 | # Create a new set of particles |
@@ -339,14 +373,15 @@ function AbstractMCMC.step( |
339 | 373 | ) |
340 | 374 | # Reset the VarInfo before new sweep. |
341 | 375 | vi = state.vi |
342 | | - vi = DynamicPPL.reset_num_produce!!(vi) |
| 376 | + vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator()) |
343 | 377 | vi = DynamicPPL.resetlogp!!(vi) |
344 | 378 |
|
345 | 379 | # Create reference particle for which the samples will be retained. |
| 380 | + unset_all_del!(vi) |
346 | 381 | reference = AdvancedPS.forkr(AdvancedPS.Trace(model, spl, vi, state.rng)) |
347 | 382 |
|
348 | 383 | # For all other particles, do not retain the variables but resample them. |
349 | | - DynamicPPL.set_retained_vns_del!(vi) |
| 384 | + set_all_del!(vi) |
350 | 385 |
|
351 | 386 | # Create a new set of particles. |
352 | 387 | num_particles = spl.alg.nparticles |
@@ -451,12 +486,11 @@ function DynamicPPL.assume( |
451 | 486 | vi = push!!(vi, vn, r, dist) |
452 | 487 | elseif DynamicPPL.is_flagged(vi, vn, "del") |
453 | 488 | DynamicPPL.unset_flag!(vi, vn, "del") # Reference particle parent |
454 | | - r = rand(trng, dist) |
455 | | - vi[vn] = DynamicPPL.tovec(r) |
456 | 489 | # TODO(mhauru): |
457 | 490 | # The below is the only line that differs from assume called on SampleFromPrior. |
458 | | - # Could we just call assume on SampleFromPrior and then `setorder!!` after that? |
459 | | - vi = DynamicPPL.setorder!!(vi, vn, DynamicPPL.get_num_produce(vi)) |
| 491 | + # Could we just call assume on SampleFromPrior with a specific rng? |
| 492 | + r = rand(trng, dist) |
| 493 | + vi[vn] = DynamicPPL.tovec(r) |
460 | 494 | else |
461 | 495 | r = vi[vn] |
462 | 496 | end |
@@ -498,8 +532,6 @@ function AdvancedPS.Trace( |
498 | 532 | rng::AdvancedPS.TracedRNG, |
499 | 533 | ) |
500 | 534 | newvarinfo = deepcopy(varinfo) |
501 | | - newvarinfo = DynamicPPL.reset_num_produce!!(newvarinfo) |
502 | | - |
503 | 535 | tmodel = TracedModel(model, sampler, newvarinfo, rng) |
504 | 536 | newtrace = AdvancedPS.Trace(tmodel, rng) |
505 | 537 | return newtrace |
|
0 commit comments