Skip to content

Commit

Permalink
Utils module
Browse files Browse the repository at this point in the history
  • Loading branch information
FredericWantiez committed Sep 13, 2023
1 parent 5bcb4f1 commit 09f1687
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 6 deletions.
8 changes: 4 additions & 4 deletions examples/smc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ end
function sweep!(
rng::AbstractRNG,
model::AbstractStateSpaceModel,
particles::SSMProblems.ParticleContainer,
particles::SSMProblems.Utils.ParticleContainer,
observations::AbstractArray,
resampling=systematic_resampling,
threshold=0.5,
Expand All @@ -37,7 +37,7 @@ function sweep!(
# Mutation step
for i in eachindex(particles)
latent_state = transition!!(rng, model, particles[i].state, timestep)
particles[i] = SSMProblems.Particle(particles[i], latent_state)
particles[i] = SSMProblems.Utils.Particle(particles[i], latent_state)
logweights[i] += emission_logdensity(
model, particles[i].state, observation, timestep
)
Expand All @@ -60,7 +60,7 @@ function sample(
)
particles = map(1:N) do i
state = transition!!(rng, model)
SSMProblems.Particle(state)
SSMProblems.Utils.Particle(state)
end
samples = sweep!(rng, model, particles, observations, resampling, threshold)
return samples
Expand Down Expand Up @@ -109,7 +109,7 @@ end

# Sample latent state trajectories
samples = sample(rng, LinearSSM(), N, observations)
traces = reverse(hcat(map(SSMProblems.linearize, samples)...))
traces = reverse(hcat(map(SSMProblems.Utils.linearize, samples)...))

scatter(traces; color=:black, opacity=0.3, label=false)
plot!(x; label="True state")
Expand Down
2 changes: 1 addition & 1 deletion src/SSMProblems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ abstract type AbstractStateSpaceModel end
abstract type AbstractParticleCache end

"""
transition!!(rng, model[, timestep, state, cache])
transition!!(rng, model[, state, timestep, cache])
Simulate the particle for the next time step from the forward dynamics.
"""
Expand Down
13 changes: 12 additions & 1 deletion src/utils/particles.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
# Concrete Particles as LinkedLists
"""
Common concrete implementations of Particle types for Particle Filter kernels.
"""
module Utils

abstract type Node{T} end

struct Root{T} <: Node{T} end
Root(T) = Root{T}()
Root() = Root(Any)

"""
Particle{T}
Particle as immutable LinkedList.
"""
struct Particle{T} <: Node{T}
parent::Node{T}
state::T
Expand All @@ -30,3 +39,5 @@ function linearize(particle::Particle{T}) where {T}
end
return trace
end

end

0 comments on commit 09f1687

Please sign in to comment.