From 09f168703df90a968faafda93cec34dd41f996d1 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Wed, 13 Sep 2023 22:40:55 +0100 Subject: [PATCH] Utils module --- examples/smc.jl | 8 ++++---- src/SSMProblems.jl | 2 +- src/utils/particles.jl | 13 ++++++++++++- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/examples/smc.jl b/examples/smc.jl index e6a2e24..e27d7d6 100644 --- a/examples/smc.jl +++ b/examples/smc.jl @@ -17,7 +17,7 @@ end function sweep!( rng::AbstractRNG, model::AbstractStateSpaceModel, - particles::SSMProblems.ParticleContainer, + particles::SSMProblems.Utils.ParticleContainer, observations::AbstractArray, resampling=systematic_resampling, threshold=0.5, @@ -37,7 +37,7 @@ function sweep!( # Mutation step for i in eachindex(particles) latent_state = transition!!(rng, model, particles[i].state, timestep) - particles[i] = SSMProblems.Particle(particles[i], latent_state) + particles[i] = SSMProblems.Utils.Particle(particles[i], latent_state) logweights[i] += emission_logdensity( model, particles[i].state, observation, timestep ) @@ -60,7 +60,7 @@ function sample( ) particles = map(1:N) do i state = transition!!(rng, model) - SSMProblems.Particle(state) + SSMProblems.Utils.Particle(state) end samples = sweep!(rng, model, particles, observations, resampling, threshold) return samples @@ -109,7 +109,7 @@ end # Sample latent state trajectories samples = sample(rng, LinearSSM(), N, observations) -traces = reverse(hcat(map(SSMProblems.linearize, samples)...)) +traces = reverse(hcat(map(SSMProblems.Utils.linearize, samples)...)) scatter(traces; color=:black, opacity=0.3, label=false) plot!(x; label="True state") diff --git a/src/SSMProblems.jl b/src/SSMProblems.jl index 5854d03..b3f1cac 100644 --- a/src/SSMProblems.jl +++ b/src/SSMProblems.jl @@ -10,7 +10,7 @@ abstract type AbstractStateSpaceModel end abstract type AbstractParticleCache end """ - transition!!(rng, model[, timestep, state, cache]) + transition!!(rng, model[, state, timestep, cache]) Simulate the particle for the next time step from the forward dynamics. """ diff --git a/src/utils/particles.jl b/src/utils/particles.jl index ccb4697..15ecb5d 100644 --- a/src/utils/particles.jl +++ b/src/utils/particles.jl @@ -1,10 +1,19 @@ -# Concrete Particles as LinkedLists +""" + Common concrete implementations of Particle types for Particle Filter kernels. +""" +module Utils + abstract type Node{T} end struct Root{T} <: Node{T} end Root(T) = Root{T}() Root() = Root(Any) +""" + Particle{T} + +Particle as immutable LinkedList. +""" struct Particle{T} <: Node{T} parent::Node{T} state::T @@ -30,3 +39,5 @@ function linearize(particle::Particle{T}) where {T} end return trace end + +end