From 8518ef47a522eb811fcc56f48b6f4e3d9788024e Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Fri, 3 Nov 2023 21:38:00 +0000 Subject: [PATCH] Expose API for Turing integration (#90) * Move things around * Update Project.toml * Add test, fix comments --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- Project.toml | 2 +- ext/AdvancedPSLibtaskExt.jl | 55 ++++++++++++++++++------------------- src/model.jl | 31 +++++++++++++++------ src/rng.jl | 2 ++ test/container.jl | 17 +++++++++++- 5 files changed, 67 insertions(+), 40 deletions(-) diff --git a/Project.toml b/Project.toml index 09862f4e..3b8842bc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "AdvancedPS" uuid = "576499cb-2369-40b2-a588-c64705576edc" authors = ["TuringLang"] -version = "0.5.1" +version = "0.5.2" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/ext/AdvancedPSLibtaskExt.jl b/ext/AdvancedPSLibtaskExt.jl index 2f60ef72..d6062fdb 100644 --- a/ext/AdvancedPSLibtaskExt.jl +++ b/ext/AdvancedPSLibtaskExt.jl @@ -19,30 +19,29 @@ end """ LibtaskModel{F} -State wrapper to hold `Libtask.CTask` model initiated from `f` +State wrapper to hold `Libtask.CTask` model initiated from `f`. """ -struct LibtaskModel{F1,F2} - f::F1 - ctask::Libtask.TapedTask{F2} - - LibtaskModel(f::F1, ctask::Libtask.TapedTask{F2}) where {F1,F2} = new{F1,F2}(f, ctask) -end - -function LibtaskModel(f, args...) - return LibtaskModel( +function AdvancedPS.LibtaskModel( + f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args... +) # Changed the API, need to take care of the RNG properly + return AdvancedPS.LibtaskModel( f, - Libtask.TapedTask(f, args...; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(f)}), + Libtask.TapedTask( + f, rng, args...; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(f)} + ), ) end -Base.copy(model::LibtaskModel) = LibtaskModel(model.f, copy(model.ctask)) +function Base.copy(model::AdvancedPS.LibtaskModel) + return AdvancedPS.LibtaskModel(model.f, copy(model.ctask)) +end -const LibtaskTrace{R} = AdvancedPS.Trace{<:LibtaskModel,R} +const LibtaskTrace{R} = AdvancedPS.Trace{<:AdvancedPS.LibtaskModel,R} function AdvancedPS.Trace( model::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args... ) - return AdvancedPS.Trace(LibtaskModel(model, args...), rng) + return AdvancedPS.Trace(AdvancedPS.LibtaskModel(model, rng, args...), rng) end # step to the next observe statement and @@ -56,7 +55,7 @@ function AdvancedPS.advance!(t::LibtaskTrace, isref::Bool=false) end # create a backward reference in task_local_storage -function addreference!(task::Task, trace::LibtaskTrace) +function AdvancedPS.addreference!(task::Task, trace::LibtaskTrace) if task.storage === nothing task.storage = IdDict() end @@ -65,9 +64,7 @@ function addreference!(task::Task, trace::LibtaskTrace) return task end -current_trace() = current_task().storage[:__trace] - -function update_rng!(trace::LibtaskTrace) +function AdvancedPS.update_rng!(trace::LibtaskTrace) rng, = trace.model.ctask.args trace.rng = rng return trace @@ -76,12 +73,12 @@ end # Task copying version of fork for Trace. function AdvancedPS.fork(trace::LibtaskTrace, isref::Bool=false) newtrace = copy(trace) - update_rng!(newtrace) + AdvancedPS.update_rng!(newtrace) isref && AdvancedPS.delete_retained!(newtrace.model.f) isref && delete_seeds!(newtrace) # add backward reference - addreference!(newtrace.model.ctask.task, newtrace) + AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace) return newtrace end @@ -94,11 +91,11 @@ function AdvancedPS.forkr(trace::LibtaskTrace) ctask = Libtask.TapedTask( newf, trace.rng; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(trace.model.f)} ) - new_tapedmodel = LibtaskModel(newf, ctask) + new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask) # add backward reference newtrace = AdvancedPS.Trace(new_tapedmodel, trace.rng) - addreference!(ctask.task, newtrace) + AdvancedPS.addreference!(ctask.task, newtrace) AdvancedPS.gen_refseed!(newtrace) return newtrace end @@ -135,9 +132,8 @@ function AbstractMCMC.step( AdvancedPS.forkr(copy(state.trajectory)) else trng = AdvancedPS.TracedRNG() - gen_model = LibtaskModel(deepcopy(model), trng) - trace = AdvancedPS.Trace(LibtaskModel(deepcopy(model), trng), trng) - addreference!(gen_model.ctask.task, trace) # Do we need it here ? + trace = AdvancedPS.Trace(deepcopy(model), trng) + AdvancedPS.addreference!(trace.model.ctask.task, trace) # TODO: Do we need it here ? trace end end @@ -174,9 +170,8 @@ function AbstractMCMC.sample( traces = map(1:(sampler.nparticles)) do i trng = AdvancedPS.TracedRNG() - gen_model = LibtaskModel(deepcopy(model), trng) - trace = AdvancedPS.Trace(LibtaskModel(deepcopy(model), trng), trng) - addreference!(gen_model.ctask.task, trace) # Do we need it here ? + trace = AdvancedPS.Trace(deepcopy(model), trng) + AdvancedPS.addreference!(trace.model.ctask.task, trace) # Do we need it here ? trace end @@ -202,7 +197,9 @@ function AdvancedPS.replay(particle::AdvancedPS.Particle) trng = deepcopy(particle.rng) Random123.set_counter!(trng.rng, 0) trng.count = 1 - trace = AdvancedPS.Trace(LibtaskModel(deepcopy(particle.model.f), trng), trng) + trace = AdvancedPS.Trace( + AdvancedPS.LibtaskModel(deepcopy(particle.model.f), trng), trng + ) score = AdvancedPS.advance!(trace, true) while !isnothing(score) score = AdvancedPS.advance!(trace, true) diff --git a/src/model.jl b/src/model.jl index 9ed13e59..fab9a34f 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,5 +1,5 @@ """ - Trace{F,R} + Trace{F,R} """ mutable struct Trace{F,R} model::F @@ -10,24 +10,37 @@ const Particle = Trace const SSMTrace{R} = Trace{<:AbstractStateSpaceModel,R} const GenericTrace{R} = Trace{<:AbstractGenericModel,R} -# reset log probability reset_logprob!(::AdvancedPS.Particle) = nothing - reset_model(f) = deepcopy(f) delete_retained!(f) = nothing -Base.copy(trace::Trace) = Trace(copy(trace.model), deepcopy(trace.rng)) +""" + copy(trace::Trace) -# This is required to make it visible from outside extensions -function observe end -function replay end +Copy a trace. The `TracedRNG` is deep-copied. The inner model is shallow-copied. +""" +Base.copy(trace::Trace) = Trace(copy(trace.model), deepcopy(trace.rng)) """ - gen_refseed!(particle::Particle) + gen_refseed!(particle::Particle) -Generate a new seed for the reference particle +Generate a new seed for the reference particle. """ function gen_refseed!(particle::Particle) seed = split(state(particle.rng.rng), 1) return safe_set_refseed!(particle.rng, seed[1]) end + +# A few internal functions used in the Libtask extension. Since it is not possible to access objects defined +# in an extension, we just define dummy in the main module and implement them in the extension. +function observe end +function replay end +function addreference! end + +current_trace() = current_task().storage[:__trace] + +# We need this one to be visible outside of the extension for dispatching (Turing.jl). +struct LibtaskModel{F,T} + f::F + ctask::T +end diff --git a/src/rng.jl b/src/rng.jl index 1fbf0158..a43e94a0 100644 --- a/src/rng.jl +++ b/src/rng.jl @@ -117,3 +117,5 @@ Random123.set_counter!(r::TracedRNG, n::Integer) = r.count = n Increase the model step counter by `n` """ inc_counter!(r::TracedRNG, n::Integer=1) = r.count += n + +function update_rng! end diff --git a/test/container.jl b/test/container.jl index bada9b9c..7a312291 100644 --- a/test/container.jl +++ b/test/container.jl @@ -130,7 +130,7 @@ # Test task copy version of trace trng = AdvancedPS.TracedRNG() - tr = AdvancedPS.Trace(Model(Ref(0)), trng, trng) + tr = AdvancedPS.Trace(Model(Ref(0)), trng) consume(tr.model.ctask) consume(tr.model.ctask) @@ -143,6 +143,21 @@ @test consume(a.model.ctask) == 4 end + @testset "current trace" begin + struct TaskIdModel <: AdvancedPS.AbstractGenericModel end + + function (model::TaskIdModel)(rng::Random.AbstractRNG) + # Just print the task it's running in + id = objectid(AdvancedPS.current_trace()) + return Libtask.produce(id) + end + + trace = AdvancedPS.Trace(TaskIdModel(), AdvancedPS.TracedRNG()) + AdvancedPS.addreference!(trace.model.ctask.task, trace) + + @test AdvancedPS.advance!(trace, false) === objectid(trace) + end + @testset "seed container" begin seed = 1 n = 3