Skip to content

Commit

Permalink
added conditional SMC
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesknipp committed Oct 8, 2024
1 parent c729879 commit d13c80c
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 13 deletions.
108 changes: 108 additions & 0 deletions examples/particle-mcmc/conditional-filter.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
using CairoMakie

include("particles.jl")
include("resamplers.jl")
include("simple-filters.jl")

## MULTICALLBACKS ##########################################################################

# borrowed from TuringCallbacks, and repurposed to double check reference trajectories exist
struct MultiCallback{Cs}
callbacks::Cs
end

MultiCallback() = MultiCallback(())
MultiCallback(callbacks...) = MultiCallback(callbacks)

(c::MultiCallback)(args...; kwargs...) = foreach(c -> c(args...; kwargs...), c.callbacks)

## CONDITINAL SMC ##########################################################################

struct ConditionalSMC{F<:AbstractFilter} <: AbstractSampler
filter::F
N::Integer
end

function CSMC(filter::AbstractFilter, N::Integer)
return ConditionalSMC(filter, N)
end

# this is pretty useless without sampling model parameters
function sample(
rng::AbstractRNG,
model::StateSpaceModel,
sampler::ConditionalSMC,
observations::AbstractVector;
kwargs...,
)
# not type stable, but I'm not that concerned with it right now
star_trajectory = nothing
multi_callback = nothing
ll = zero(eltype(model))

for n in 1:(sampler.N)
# store resampling index for testing purposes
multi_callback = MultiCallback(
AncestorCallback(eltype(model.dyn), sampler.filter.N, 1.0),
ResamplerCallback(sampler.filter.N),
)

states, ll = sample(
rng,
model,
sampler.filter,
observations;
ref_state=star_trajectory,
callback=multi_callback,
)

weights = softmax(states.log_weights)
star_trajectory = rand(rng, multi_callback.callbacks[1].tree, weights)
println("n = $n \t $ll")
end

return multi_callback.callbacks[2], ll
end

## PARTICLE GIBBS ##########################################################################

#=
TODO: think about interfacing with AbstractMCMC for something like this
=#

## TESTING #################################################################################

# use a local level trend model
function simulation_model(σx²::T, σy²::T) where {T<:Real}
init = Gaussian(zeros(T, 2), PDMat(diagm(ones(T, 2))))
dyn = LinearGaussianLatentDynamics(T[1 1; 0 1], T[0; 0], [σx² 0; 0 0], init)
obs = LinearGaussianObservationProcess(T[1 0], [σy²;;])
return StateSpaceModel(dyn, obs)
end

# generate model and data
rng = MersenneTwister(1234);
true_params = randexp(rng, Float64, 2);
true_model = simulation_model(true_params...);
_, _, data = sample(rng, true_model, 50);

filter = BF(20; threshold=1.0, resampler=Systematic());
rs_path, ll = sample(rng, true_model, CSMC(filter, 5), data);

# check that y = 1 always has a path
rs_check = begin
fig = Figure(; size=(600, 300))
ax = Axis(
fig[1, 1];
xticks=0:10:50,
yticks=0:10:(filter.N),
limits=(nothing, (-5, filter.N + 5)),
)

paths = get_ancestry(rs_path.tree)
scatterlines!.(
ax, paths, color=(:black, 0.25), markercolor=:black, markersize=5, linewidth=1
)

fig
end
38 changes: 33 additions & 5 deletions examples/particle-mcmc/particles.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using DataStructures: Stack
using StatsBase
using Random

## PARTICLES ###############################################################################

Expand All @@ -26,13 +27,23 @@ Base.keys(pc::ParticleContainer) = LinearIndices(pc.vals)
Base.@propagate_inbounds Base.getindex(pc::ParticleContainer, i::Int) = pc.vals[i]
Base.@propagate_inbounds Base.getindex(pc::ParticleContainer, i::Vector{Int}) = pc.vals[i]

StatsBase.weights(pc::ParticleContainer) = softmax(pc.log_weights)

function reset_weights!(pc::ParticleContainer{T,WT}) where {T,WT<:Real}
fill!(pc.log_weights, zero(WT))
return pc.log_weights
end

function StatsBase.weights(pc::ParticleContainer)
return softmax(pc.log_weights)
function update_ref!(
pc::ParticleContainer{T}, ref_state::Union{Nothing,AbstractVector{T}}, step::Integer=0
) where {T}
# this comes from Nicolas Chopin's package particles
if !isnothing(ref_state)
pc.proposed[1] = ref_state[step + 1]
pc.filtered[1] = ref_state[step + 1]
pc.ancestors[1] = 1
end
return pc
end

## SPARSE PARTICLE STORAGE #################################################################
Expand Down Expand Up @@ -75,16 +86,17 @@ function prune!(tree::ParticleTree, offspring::Vector{Int64})
end
end
end
return tree
end

function insert!(
tree::ParticleTree{T}, states::Vector{T}, a::AbstractVector{Int64}
tree::ParticleTree{T}, states::Vector{T}, ancestors::AbstractVector{Int64}
) where {T}
# parents of new generation
parents = getindex(tree.leaves, a)
parents = getindex(tree.leaves, ancestors)

# ensure there are enough dead branches
if (length(tree.free_indices) < length(a))
if (length(tree.free_indices) < length(ancestors))
@debug "expanding tree"
expand!(tree)
end
Expand Down Expand Up @@ -135,6 +147,22 @@ function get_ancestry(tree::ParticleTree{T}) where {T}
return paths
end

# this could be improved for sure...
function Random.rand(rng::AbstractRNG, tree::ParticleTree, weights::AbstractVector{<:Real})
b = randcat(rng, weights)
leaf = tree.leaves[b]

j = tree.parents[leaf]
xi = tree.states[leaf]

xs = [xi]
while j > 0
push!(xs, tree.states[j])
j = tree.parents[j]
end
return reverse(xs)
end

## ANCESTOR STORAGE CALLBACK ###############################################################

mutable struct AncestorCallback
Expand Down
26 changes: 23 additions & 3 deletions examples/particle-mcmc/resamplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,26 @@ using Distributions

abstract type AbstractResampler end

## CATEGORICAL RESAMPLE ####################################################################

# this is adapted from AdvancedPS
function randcat(rng::AbstractRNG, weights::AbstractVector{WT}) where {WT<:Real}
# pre-calculations
@inbounds v = weights[1]
u = rand(rng, WT)

# initialize sampling algorithm
n = length(weights)
idx = 1

while (v u) && (idx < n)
idx += 1
v += weights[idx]
end

return idx
end

## DOUBLE PRECISION STABLE ALGORITHMS ######################################################

struct Multinomial <: AbstractResampler end
Expand All @@ -20,7 +40,7 @@ function resample(
) where {WT<:Real}
# pre-calculations
@inbounds v = n * weights[1]
u = oftype(v, rand(rng))
u = rand(rng, WT)

# initialize sampling algorithm
a = Vector{Int64}(undef, n)
Expand Down Expand Up @@ -66,7 +86,7 @@ function resample(
for _ in 1:B
j = rand(rng, 1:n)
v = weights[j] / weights[k]
if rand(rng) v
if rand(rng, WT) v
k = j
end
end
Expand All @@ -93,7 +113,7 @@ function resample(
u = rand(rng)
while u > weights[j] / max_weight
j = rand(rng, 1:n)
u = rand(rng)
u = rand(rng, WT)
end
a[i] = j
end
Expand Down
16 changes: 11 additions & 5 deletions examples/particle-mcmc/simple-filters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,12 +214,17 @@ end
resample_threshold(filter::BootstrapFilter) = filter.threshold * filter.N

function initialise(
rng::AbstractRNG, model::StateSpaceModel, filter::BootstrapFilter, extra; kwargs...
rng::AbstractRNG,
model::StateSpaceModel,
filter::BootstrapFilter,
extra;
ref_state::Union{Nothing,AbstractVector}=nothing,
kwargs...,
)
initial_states = map(x -> SSMProblems.simulate(rng, model.dyn, extra), 1:(filter.N))
initial_weights = zeros(eltype(model), filter.N)

return ParticleContainer(initial_states, initial_weights)
return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state)
end

function resample(rng::AbstractRNG, states::ParticleContainer, filter::BootstrapFilter)
Expand All @@ -242,17 +247,18 @@ function predict(
model::StateSpaceModel,
filter::BootstrapFilter,
step::Integer,
states::ParticleContainer,
states::ParticleContainer{T},
extra;
ref_state::Union{Nothing,AbstractVector{T}}=nothing,
kwargs...,
)
) where {T}
states.ancestors = resample(rng, states, filter)
states.proposed = map(
x -> SSMProblems.simulate(rng, model.dyn, step, x, extra),
states.filtered[states.ancestors],
)

return states
return update_ref!(states, ref_state, step)
end

function update(
Expand Down

0 comments on commit d13c80c

Please sign in to comment.