diff --git a/examples/particle-mcmc/conditional-filter.jl b/examples/particle-mcmc/conditional-filter.jl new file mode 100644 index 0000000..8d84d74 --- /dev/null +++ b/examples/particle-mcmc/conditional-filter.jl @@ -0,0 +1,108 @@ +using CairoMakie + +include("particles.jl") +include("resamplers.jl") +include("simple-filters.jl") + +## MULTICALLBACKS ########################################################################## + +# borrowed from TuringCallbacks, and repurposed to double check reference trajectories exist +struct MultiCallback{Cs} + callbacks::Cs +end + +MultiCallback() = MultiCallback(()) +MultiCallback(callbacks...) = MultiCallback(callbacks) + +(c::MultiCallback)(args...; kwargs...) = foreach(c -> c(args...; kwargs...), c.callbacks) + +## CONDITINAL SMC ########################################################################## + +struct ConditionalSMC{F<:AbstractFilter} <: AbstractSampler + filter::F + N::Integer +end + +function CSMC(filter::AbstractFilter, N::Integer) + return ConditionalSMC(filter, N) +end + +# this is pretty useless without sampling model parameters +function sample( + rng::AbstractRNG, + model::StateSpaceModel, + sampler::ConditionalSMC, + observations::AbstractVector; + kwargs..., +) + # not type stable, but I'm not that concerned with it right now + star_trajectory = nothing + multi_callback = nothing + ll = zero(eltype(model)) + + for n in 1:(sampler.N) + # store resampling index for testing purposes + multi_callback = MultiCallback( + AncestorCallback(eltype(model.dyn), sampler.filter.N, 1.0), + ResamplerCallback(sampler.filter.N), + ) + + states, ll = sample( + rng, + model, + sampler.filter, + observations; + ref_state=star_trajectory, + callback=multi_callback, + ) + + weights = softmax(states.log_weights) + star_trajectory = rand(rng, multi_callback.callbacks[1].tree, weights) + println("n = $n \t $ll") + end + + return multi_callback.callbacks[2], ll +end + +## PARTICLE GIBBS ########################################################################## + +#= + TODO: think about interfacing with AbstractMCMC for something like this +=# + +## TESTING ################################################################################# + +# use a local level trend model +function simulation_model(σx²::T, σy²::T) where {T<:Real} + init = Gaussian(zeros(T, 2), PDMat(diagm(ones(T, 2)))) + dyn = LinearGaussianLatentDynamics(T[1 1; 0 1], T[0; 0], [σx² 0; 0 0], init) + obs = LinearGaussianObservationProcess(T[1 0], [σy²;;]) + return StateSpaceModel(dyn, obs) +end + +# generate model and data +rng = MersenneTwister(1234); +true_params = randexp(rng, Float64, 2); +true_model = simulation_model(true_params...); +_, _, data = sample(rng, true_model, 50); + +filter = BF(20; threshold=1.0, resampler=Systematic()); +rs_path, ll = sample(rng, true_model, CSMC(filter, 5), data); + +# check that y = 1 always has a path +rs_check = begin + fig = Figure(; size=(600, 300)) + ax = Axis( + fig[1, 1]; + xticks=0:10:50, + yticks=0:10:(filter.N), + limits=(nothing, (-5, filter.N + 5)), + ) + + paths = get_ancestry(rs_path.tree) + scatterlines!.( + ax, paths, color=(:black, 0.25), markercolor=:black, markersize=5, linewidth=1 + ) + + fig +end diff --git a/examples/particle-mcmc/particles.jl b/examples/particle-mcmc/particles.jl index f43b532..4feb826 100644 --- a/examples/particle-mcmc/particles.jl +++ b/examples/particle-mcmc/particles.jl @@ -1,5 +1,6 @@ using DataStructures: Stack using StatsBase +using Random ## PARTICLES ############################################################################### @@ -26,13 +27,23 @@ Base.keys(pc::ParticleContainer) = LinearIndices(pc.vals) Base.@propagate_inbounds Base.getindex(pc::ParticleContainer, i::Int) = pc.vals[i] Base.@propagate_inbounds Base.getindex(pc::ParticleContainer, i::Vector{Int}) = pc.vals[i] +StatsBase.weights(pc::ParticleContainer) = softmax(pc.log_weights) + function reset_weights!(pc::ParticleContainer{T,WT}) where {T,WT<:Real} fill!(pc.log_weights, zero(WT)) return pc.log_weights end -function StatsBase.weights(pc::ParticleContainer) - return softmax(pc.log_weights) +function update_ref!( + pc::ParticleContainer{T}, ref_state::Union{Nothing,AbstractVector{T}}, step::Integer=0 +) where {T} + # this comes from Nicolas Chopin's package particles + if !isnothing(ref_state) + pc.proposed[1] = ref_state[step + 1] + pc.filtered[1] = ref_state[step + 1] + pc.ancestors[1] = 1 + end + return pc end ## SPARSE PARTICLE STORAGE ################################################################# @@ -75,16 +86,17 @@ function prune!(tree::ParticleTree, offspring::Vector{Int64}) end end end + return tree end function insert!( - tree::ParticleTree{T}, states::Vector{T}, a::AbstractVector{Int64} + tree::ParticleTree{T}, states::Vector{T}, ancestors::AbstractVector{Int64} ) where {T} # parents of new generation - parents = getindex(tree.leaves, a) + parents = getindex(tree.leaves, ancestors) # ensure there are enough dead branches - if (length(tree.free_indices) < length(a)) + if (length(tree.free_indices) < length(ancestors)) @debug "expanding tree" expand!(tree) end @@ -135,6 +147,22 @@ function get_ancestry(tree::ParticleTree{T}) where {T} return paths end +# this could be improved for sure... +function Random.rand(rng::AbstractRNG, tree::ParticleTree, weights::AbstractVector{<:Real}) + b = randcat(rng, weights) + leaf = tree.leaves[b] + + j = tree.parents[leaf] + xi = tree.states[leaf] + + xs = [xi] + while j > 0 + push!(xs, tree.states[j]) + j = tree.parents[j] + end + return reverse(xs) +end + ## ANCESTOR STORAGE CALLBACK ############################################################### mutable struct AncestorCallback diff --git a/examples/particle-mcmc/resamplers.jl b/examples/particle-mcmc/resamplers.jl index c678481..c9a3b95 100644 --- a/examples/particle-mcmc/resamplers.jl +++ b/examples/particle-mcmc/resamplers.jl @@ -3,6 +3,26 @@ using Distributions abstract type AbstractResampler end +## CATEGORICAL RESAMPLE #################################################################### + +# this is adapted from AdvancedPS +function randcat(rng::AbstractRNG, weights::AbstractVector{WT}) where {WT<:Real} + # pre-calculations + @inbounds v = weights[1] + u = rand(rng, WT) + + # initialize sampling algorithm + n = length(weights) + idx = 1 + + while (v ≤ u) && (idx < n) + idx += 1 + v += weights[idx] + end + + return idx +end + ## DOUBLE PRECISION STABLE ALGORITHMS ###################################################### struct Multinomial <: AbstractResampler end @@ -20,7 +40,7 @@ function resample( ) where {WT<:Real} # pre-calculations @inbounds v = n * weights[1] - u = oftype(v, rand(rng)) + u = rand(rng, WT) # initialize sampling algorithm a = Vector{Int64}(undef, n) @@ -66,7 +86,7 @@ function resample( for _ in 1:B j = rand(rng, 1:n) v = weights[j] / weights[k] - if rand(rng) ≤ v + if rand(rng, WT) ≤ v k = j end end @@ -93,7 +113,7 @@ function resample( u = rand(rng) while u > weights[j] / max_weight j = rand(rng, 1:n) - u = rand(rng) + u = rand(rng, WT) end a[i] = j end diff --git a/examples/particle-mcmc/simple-filters.jl b/examples/particle-mcmc/simple-filters.jl index 11c460b..507852a 100644 --- a/examples/particle-mcmc/simple-filters.jl +++ b/examples/particle-mcmc/simple-filters.jl @@ -214,12 +214,17 @@ end resample_threshold(filter::BootstrapFilter) = filter.threshold * filter.N function initialise( - rng::AbstractRNG, model::StateSpaceModel, filter::BootstrapFilter, extra; kwargs... + rng::AbstractRNG, + model::StateSpaceModel, + filter::BootstrapFilter, + extra; + ref_state::Union{Nothing,AbstractVector}=nothing, + kwargs..., ) initial_states = map(x -> SSMProblems.simulate(rng, model.dyn, extra), 1:(filter.N)) initial_weights = zeros(eltype(model), filter.N) - return ParticleContainer(initial_states, initial_weights) + return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state) end function resample(rng::AbstractRNG, states::ParticleContainer, filter::BootstrapFilter) @@ -242,17 +247,18 @@ function predict( model::StateSpaceModel, filter::BootstrapFilter, step::Integer, - states::ParticleContainer, + states::ParticleContainer{T}, extra; + ref_state::Union{Nothing,AbstractVector{T}}=nothing, kwargs..., -) +) where {T} states.ancestors = resample(rng, states, filter) states.proposed = map( x -> SSMProblems.simulate(rng, model.dyn, step, x, extra), states.filtered[states.ancestors], ) - return states + return update_ref!(states, ref_state, step) end function update(