diff --git a/examples/smc.jl b/examples/smc.jl index 9945e16..24e24af 100644 --- a/examples/smc.jl +++ b/examples/smc.jl @@ -7,23 +7,23 @@ using StatsFuns # Particle Filter implementation struct Particle{T} # Here we just need a tree parent::Union{Particle,Nothing} - state::T + model::T end -Particle(state::T) where {T} = Particle(nothing, state) +Particle(model::T) where {T} = Particle(nothing, model) Particle() = Particle(nothing, nothing) -Base.show(io::IO, p::Particle) = print(io, "Particle($(p.state))") +Base.show(io::IO, p::Particle) = print(io, "Particle($(p.model))") """ linearize(particle) Return the trace of a particle, i.e. the sequence of states from the root to the particle. """ -function linearize(particle::Particle{T}) where {T} - trace = T[] +function linearize(particle::Particle) + trace = [] parent = particle.parent while !isnothing(parent) - push!(trace, parent.state) + push!(trace, parent.model) parent = parent.parent end return trace @@ -44,7 +44,7 @@ function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, thre t = 1 N = length(particles) logweights = zeros(length(particles)) - while !isdone(t, particles[1].state) + while !isdone(t, particles[1].model) # Resample step weights = get_weights(logweights) @@ -57,9 +57,9 @@ function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, thre # Mutation step for i in eachindex(particles) parent = particles[i] - mutated = transition!!(rng, t, parent.state) + mutated = transition!!(rng, t, parent.model) particles[i] = Particle(parent, mutated) - logweights[i] += emission_logdensity(t, particles[i].state) + logweights[i] += emission_logdensity(t, particles[i].model) end t += 1 @@ -76,6 +76,12 @@ function sweep!(rng::AbstractRNG, n::Int, resampling, threshold=0.5) end # Inference code +struct LinearSSM{T} <: AbstractStateSpaceModel + state::T +end + +LinearSSM() = LinearSSM(nothing) + Base.@kwdef struct Parameters v::Float64 = 0.2 # Transition noise stdev u::Float64 = 0.7 # Observation noise stdev @@ -83,36 +89,40 @@ end # Simulation T = 250 -seed = 1 +seed = 12 N = 1000 rng = MersenneTwister(seed) params = Parameters(; v=0.2, u=0.7) -function transition!!(rng::AbstractRNG, t::Int, state=nothing) - if isnothing(state) - return rand(rng, Normal(0, 1)) +function emission_density(x) + return Normal(x, params.u) +end + +function transition!!(rng::AbstractRNG, t::Int, model::LinearSSM{T}=LinearSSM()) where {T} + if t == 1 + return LinearSSM(rand(rng, Normal(0, 1))) end - return rand(rng, Normal(state, params.v)) + return LinearSSM(rand(rng, Normal(model.state, params.v))) end -function emission_logdensity(t, state) - return logpdf(Normal(state, params.u), observations[t]) +function emission_logdensity(t, model::LinearSSM) + return logpdf(emission_density(model.state), observations[t]) end -isdone(t, state) = t > T -isdone(t, ::Nothing) = false +isdone(t, ::LinearSSM) = t > T -x, observations = zeros(T), zeros(T) -x[1] = rand(rng, Normal(0, 1)) +x, observations = Vector{LinearSSM}(undef, T), zeros(T) +x[1] = transition!!(rng, 1) for t in 1:T - observations[t] = rand(rng, Normal(x[t], params.u)) + observations[t] = rand(emission_density(x[t].state)) if t < T - x[t + 1] = rand(rng, Normal(x[t], params.v)) + x[t + 1] = transition!!(rng, t, x[t]) end end samples = sweep!(rng, fill(Particle(x[1]), N), systematic_resampling) -traces = reverse(hcat(map(linearize, samples)...)) +traces = map(model -> model.state, reverse(hcat(map(linearize, samples)...))) scatter(traces; color=:black, opacity=0.3, label=false) -plot!(x; label="True state") +plot!(map(model -> getfield(model, :state), x); label="True state") +plot!(mean(traces; dims=2); label="Posterior mean") diff --git a/src/SSMProblems.jl b/src/SSMProblems.jl index fd91a4b..9f80702 100644 --- a/src/SSMProblems.jl +++ b/src/SSMProblems.jl @@ -5,8 +5,8 @@ module SSMProblems """ """ -abstract type AbstractParticle end -abstract type AbstractParticleCache end +abstract type AbstractStateSpaceModel end +abstract type AbstractSSMCache end """ transition!!(rng, step, particle[, cache]) @@ -36,6 +36,7 @@ Determine whether we have reached the last time step of the Markov process. Retu """ function isdone end -export transition!!, transition_logdensity, emission_logdensity, isdone, AbstractParticle +export transition!!, + transition_logdensity, emission_logdensity, isdone, AbstractStateSpaceModel end