Skip to content

Commit

Permalink
Merge pull request #17 from TuringLang/container
Browse files Browse the repository at this point in the history
Add Trace without Turing
  • Loading branch information
yebai authored Dec 3, 2020
2 parents 0ca372d + 241e9b5 commit c87cba7
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 159 deletions.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ version = "0.1.0"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
Distributions = "0.23, 0.24"
Libtask = "0.5"
StatsFuns = "0.9"
julia = "1.3"
4 changes: 4 additions & 0 deletions src/AdvancedPS.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
module AdvancedPS

import Distributions
import Libtask
import StatsFuns

include("resampling.jl")
include("container.jl")

end
163 changes: 80 additions & 83 deletions src/container.jl
Original file line number Diff line number Diff line change
@@ -1,81 +1,79 @@
mutable struct Trace{Tspl<:AbstractSampler, Tvi<:AbstractVarInfo, Tmodel<:Model}
model::Tmodel
spl::Tspl
vi::Tvi
ctask::CTask

function Trace{SampleFromPrior}(model::Model, spl::AbstractSampler, vi::AbstractVarInfo)
return new{SampleFromPrior,typeof(vi),typeof(model)}(model, SampleFromPrior(), vi)
end
function Trace{S}(model::Model, spl::S, vi::AbstractVarInfo) where S<:Sampler
return new{S,typeof(vi),typeof(model)}(model, spl, vi)
end
struct Trace{F}
f::F
ctask::Libtask.CTask
end

function Base.copy(trace::Trace)
vi = deepcopy(trace.vi)
res = Trace{typeof(trace.spl)}(trace.model, trace.spl, vi)
res.ctask = copy(trace.ctask)
return res
end
const Particle = Trace

# NOTE: this function is called by `forkr`
function Trace(f, m::Model, spl::AbstractSampler, vi::AbstractVarInfo)
res = Trace{typeof(spl)}(m, spl, deepcopy(vi))
ctask = CTask() do
res = f()
produce(nothing)
return res
end
task = ctask.task
if task.storage === nothing
task.storage = IdDict()
function Trace(f)
ctask = let f=f
Libtask.CTask() do
res = f()
Libtask.produce(nothing)
return res
end
end
task.storage[:turing_trace] = res # create a backward reference in task_local_storage
res.ctask = ctask
return res
end

function Trace(m::Model, spl::AbstractSampler, vi::AbstractVarInfo)
res = Trace{typeof(spl)}(m, spl, deepcopy(vi))
reset_num_produce!(res.vi)
ctask = CTask() do
res = m(vi, spl)
produce(nothing)
return res
end
task = ctask.task
if task.storage === nothing
task.storage = IdDict()
end
task.storage[:turing_trace] = res # create a backward reference in task_local_storage
res.ctask = ctask
return res
# add backward reference
newtrace = Trace(f, ctask)
addreference!(ctask.task, newtrace)

return newtrace
end

# step to the next observe statement, return log likelihood
Libtask.consume(t::Trace) = (increment_num_produce!(t.vi); consume(t.ctask))
Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask))

# step to the next observe statement and
# return the log probability of the transition (or nothing if done)
advance!(t::Trace) = Libtask.consume(t.ctask)

# reset log probability
reset_logprob!(t::Trace) = nothing

reset_model(f) = nothing
delete_retained!(f) = nothing

# Task copying version of fork for Trace.
function fork(trace :: Trace, is_ref :: Bool = false)
function fork(trace::Trace, isref::Bool = false)
newtrace = copy(trace)
is_ref && set_retained_vns_del_by_spl!(newtrace.vi, newtrace.spl)
newtrace.ctask.task.storage[:turing_trace] = newtrace
isref && delete_retained!(newtrace.f)

# add backward reference
addreference!(newtrace.ctask.task, newtrace)

return newtrace
end

# PG requires keeping all randomness for the reference particle
# Create new task and copy randomness
function forkr(trace::Trace)
newtrace = Trace(trace.ctask.task.code, trace.model, trace.spl, deepcopy(trace.vi))
newtrace.spl = trace.spl
reset_num_produce!(newtrace.vi)
newf = reset_model(trace.f)
ctask = let f=trace.ctask.task.code
Libtask.CTask() do
res = f()
Libtask.produce(nothing)
return res
end
end

# add backward reference
newtrace = Trace(newf, ctask)
addreference!(ctask.task, newtrace)

return newtrace
end

current_trace() = current_task().storage[:turing_trace]
# create a backward reference in task_local_storage
function addreference!(task::Task, trace::Trace)
if task.storage === nothing
task.storage = IdDict()
end
task.storage[:__trace] = trace

const Particle = Trace
return task
end

current_trace() = current_task().storage[:__trace]

"""
Data structure for particle filters
Expand Down Expand Up @@ -141,7 +139,7 @@ end
Compute the normalized weights of the particles.
"""
getweights(pc::ParticleContainer) = softmax(pc.logWs)
getweights(pc::ParticleContainer) = StatsFuns.softmax(pc.logWs)

"""
getweight(pc::ParticleContainer, i)
Expand All @@ -155,7 +153,7 @@ getweight(pc::ParticleContainer, i) = exp(pc.logWs[i] - logZ(pc))
Return the logarithm of the normalizing constant of the unnormalized logarithmic weights.
"""
logZ(pc::ParticleContainer) = logsumexp(pc.logWs)
logZ(pc::ParticleContainer) = StatsFuns.logsumexp(pc.logWs)

"""
effectiveSampleSize(pc::ParticleContainer)
Expand All @@ -168,7 +166,7 @@ function effectiveSampleSize(pc::ParticleContainer)
end

"""
resample_propagate!(pc::ParticleContainer[, randcat = resample_systematic, ref = nothing;
resample_propagate!(pc::ParticleContainer[, randcat = resample, ref = nothing;
weights = getweights(pc)])
Resample and propagate the particles in `pc`.
Expand All @@ -179,7 +177,7 @@ of the particle `weights`. For Particle Gibbs sampling, one can provide a refere
"""
function resample_propagate!(
pc::ParticleContainer,
randcat = Turing.Inference.resample_systematic,
randcat = resample,
ref::Union{Particle, Nothing} = nothing;
weights = getweights(pc)
)
Expand Down Expand Up @@ -231,6 +229,22 @@ function resample_propagate!(
pc
end

function resample_propagate!(
pc::ParticleContainer,
resampler::ResampleWithESSThreshold,
ref::Union{Particle,Nothing} = nothing;
weights = getweights(pc)
)
# Compute the effective sample size ``1 / ∑ wᵢ²`` with normalized weights ``wᵢ``
ess = inv(sum(abs2, weights))

if ess resampler.threshold * length(pc)
resample_propagate!(pc, resampler.resampler, ref; weights = weights)
end

pc
end

"""
reweight!(pc::ParticleContainer)
Expand All @@ -249,19 +263,18 @@ function reweight!(pc::ParticleContainer)
# the execution of the model is finished.
# Here ``yᵢ`` are observations, ``xᵢ`` variables of the particle filter, and
# ``θᵢ`` are variables of other samplers.
score = Libtask.consume(p)
score = advance!(p)

if score === nothing
numdone += 1
else
# Increase the unnormalized logarithmic weights, accounting for the variables
# of other samplers.
increase_logweight!(pc, i, score + getlogp(p.vi))
# Increase the unnormalized logarithmic weights.
increase_logweight!(pc, i, score)

# Reset the accumulator of the log probability in the model so that we can
# accumulate log probabilities of variables of other samplers until the next
# observation.
resetlogp!(p.vi)
reset_logprob!(p)
end
end

Expand Down Expand Up @@ -333,19 +346,3 @@ function sweep!(pc::ParticleContainer, resampler)

return logevidence
end

function resample_propagate!(
pc::ParticleContainer,
resampler::ResampleWithESSThreshold,
ref::Union{Particle,Nothing} = nothing;
weights = getweights(pc)
)
# Compute the effective sample size ``1 / ∑ wᵢ²`` with normalized weights ``wᵢ``
ess = inv(sum(abs2, weights))

if ess resampler.threshold * length(pc)
resample_propagate!(pc, resampler.resampler, ref; weights = weights)
end

pc
end
22 changes: 13 additions & 9 deletions src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@ function ResampleWithESSThreshold(resampler = resample)
ResampleWithESSThreshold(resampler, 0.5)
end

# Default resampling scheme
function resample(w::AbstractVector{<:Real}, num_particles::Integer=length(w))
return resample_systematic(w, num_particles)
end

# More stable, faster version of rand(Categorical)
function randcat(p::AbstractVector{<:Real})
T = eltype(p)
Expand All @@ -36,11 +31,17 @@ function randcat(p::AbstractVector{<:Real})
return s
end

function resample_multinomial(w::AbstractVector{<:Real}, num_particles::Integer)
function resample_multinomial(
w::AbstractVector{<:Real},
num_particles::Integer = length(w),
)
return rand(Distributions.sampler(Distributions.Categorical(w)), num_particles)
end

function resample_residual(w::AbstractVector{<:Real}, num_particles::Integer)
function resample_residual(
w::AbstractVector{<:Real},
num_particles::Integer = length(weights),
)
# Pre-allocate array for resampled particles
indices = Vector{Int}(undef, num_particles)

Expand Down Expand Up @@ -79,7 +80,7 @@ are selected according to the multinomial distribution defined by the normalized
i.e., `xᵢ = j` if and only if
``uᵢ \\in [\\sum_{s=1}^{j-1} weights_{s}, \\sum_{s=1}^{j} weights_{s})``.
"""
function resample_stratified(weights::AbstractVector{<:Real}, n::Integer)
function resample_stratified(weights::AbstractVector{<:Real}, n::Integer = length(weights))
# check input
m = length(weights)
m > 0 || error("weight vector is empty")
Expand Down Expand Up @@ -124,7 +125,7 @@ numbers `u₁`, ..., `uₙ` where ``uₖ = (u + k − 1) / n``. Based on these n
normalized `weights`, i.e., `xᵢ = j` if and only if
``uᵢ \\in [\\sum_{s=1}^{j-1} weights_{s}, \\sum_{s=1}^{j} weights_{s})``.
"""
function resample_systematic(weights::AbstractVector{<:Real}, n::Integer)
function resample_systematic(weights::AbstractVector{<:Real}, n::Integer = length(weights))
# check input
m = length(weights)
m > 0 || error("weight vector is empty")
Expand Down Expand Up @@ -157,3 +158,6 @@ function resample_systematic(weights::AbstractVector{<:Real}, n::Integer)

return samples
end

# Default resampling scheme
const resample = resample_systematic
4 changes: 3 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[deps]
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
julia = "1.3"
Libtask = "0.5"
julia = "1.3"
Loading

2 comments on commit c87cba7

@yebai
Copy link
Member Author

@yebai yebai commented on c87cba7 Dec 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/25755

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.0 -m "<description of version>" c87cba73a30b3cf7e9b2ad207cceeaaae55de378
git push origin v0.1.0

Please sign in to comment.