@@ -25,9 +25,8 @@ function TracedModel(
2525 " Sampling with `$(sampler. alg) ` does not support models with keyword arguments. See issue #2007 for more details." ,
2626 )
2727 end
28- return TracedModel {AbstractSampler,AbstractVarInfo,Model,Tuple} (
29- model, sampler, varinfo, (model. f, args... )
30- )
28+ evaluator = (model. f, args... )
29+ return TracedModel (model, sampler, varinfo, evaluator)
3130end
3231
3332function AdvancedPS. advance! (
@@ -59,20 +58,10 @@ function AdvancedPS.reset_logprob!(trace::TracedModel)
5958 return trace
6059end
6160
62- function AdvancedPS. update_rng! (
63- trace:: AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}
64- )
65- # Extract the `args`.
66- args = trace. model. ctask. args
67- # From `args`, extract the `SamplingContext`, which contains the RNG.
68- sampling_context = args[3 ]
69- rng = sampling_context. rng
70- trace. rng = rng
71- return trace
72- end
73-
74- function Libtask. TapedTask (model:: TracedModel , :: Random.AbstractRNG , args... ; kwargs... ) # RNG ?
75- return Libtask. TapedTask (model. evaluator[1 ], model. evaluator[2 : end ]. .. ; kwargs... )
61+ function Libtask. TapedTask (taped_globals, model:: TracedModel ; kwargs... )
62+ return Libtask. TapedTask (
63+ taped_globals, model. evaluator[1 ], model. evaluator[2 : end ]. .. ; kwargs...
64+ )
7665end
7766
7867abstract type ParticleInference <: InferenceAlgorithm end
@@ -402,11 +391,11 @@ end
402391
403392function trace_local_varinfo_maybe (varinfo)
404393 try
405- trace = AdvancedPS . current_trace ()
406- return trace. model. f. varinfo
394+ trace = Libtask . get_taped_globals (Any) . other
395+ return ( trace === nothing ? varinfo : trace . model. f. varinfo) :: AbstractVarInfo
407396 catch e
408397 # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
409- if e == KeyError (:__trace ) || current_task () . storage isa Nothing
398+ if e == KeyError (:task_variable )
410399 return varinfo
411400 else
412401 rethrow (e)
@@ -416,11 +405,10 @@ end
416405
417406function trace_local_rng_maybe (rng:: Random.AbstractRNG )
418407 try
419- trace = AdvancedPS. current_trace ()
420- return trace. rng
408+ return Libtask. get_taped_globals (Any). rng
421409 catch e
422410 # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
423- if e == KeyError (:__trace ) || current_task () . storage isa Nothing
411+ if e == KeyError (:task_variable )
424412 return rng
425413 else
426414 rethrow (e)
@@ -481,6 +469,25 @@ function AdvancedPS.Trace(
481469
482470 tmodel = TracedModel (model, sampler, newvarinfo, rng)
483471 newtrace = AdvancedPS. Trace (tmodel, rng)
484- AdvancedPS. addreference! (newtrace. model. ctask. task, newtrace)
485472 return newtrace
486473end
474+
475+ # We need to tell Libtask which calls may have `produce` calls within them. In practice most
476+ # of these won't be needed, because of inlining and the fact that `might_produce` is only
477+ # called on `:invoke` expressions rather than `:call`s, but since those are implementation
478+ # details of the compiler, we set a bunch of methods as might_produce = true. We start with
479+ # `acclogp_observe!!` which is what calls `produce` and go up the call stack.
480+ Libtask. might_produce (:: Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}} ) = true
481+ Libtask. might_produce (:: Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}} ) = true
482+ Libtask. might_produce (:: Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}} ) = true
483+ function Libtask. might_produce (
484+ :: Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}}
485+ )
486+ return true
487+ end
488+ function Libtask. might_produce (
489+ :: Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}}
490+ )
491+ return true
492+ end
493+ Libtask. might_produce (:: Type{<:Tuple{<:DynamicPPL.Model,Vararg}} ) = true
0 commit comments