diff --git a/docs/src/images/state_space_model.png b/docs/src/images/state_space_model.png new file mode 100644 index 0000000..02f1afb Binary files /dev/null and b/docs/src/images/state_space_model.png differ diff --git a/docs/src/index.md b/docs/src/index.md index de6bc24..541e875 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,6 +1,91 @@ # SSMProblems -### API +### Installation +In the `julia` REPL: +```julia +]add SSMProblems +``` + +### Documentation + +`SSMProblems` defines a generic interface for State Space Problems (SSM). The main objective is to provide a consistent +interface to work with SSMs and their logdensities. + +Consider a markovian model from[^Murray]: +![state space model](./docs/images/state_space_model.png) + +[^Murray]: + > Murray, Lawrence & Lee, Anthony & Jacob, Pierre. (2013). Rethinking resampling in the particle filter on graphics processing units. + +The model is fully specified by the following densities: +- __Initialisation__: ``f_0(x)`` +- __Transition__: ``f(x)`` +- __Emission__: ``g(x)`` + +And the dynamics of the model reduces to: +```math +\begin{aligned} +x_t | x_{t-1} &\sim f(x_t | x_{t-1}) \\ +y_t | x_t &\sim g(y_t | x_{t}) +\end{aligned} +``` +assuming ``x_0 \sim f_0(x)``. + +The joint law follows: + +```math +p(x_{0:T}, y_{0:T}) = f_0(x_0) \prod_t g(y_t | x_t) f(x_t | x_{t-1}) +``` + +Users can define their SSM with `SSMProblems` in the following way: +```julia +struct Model <: AbstractStateSpaceModel end + +# Define the structure of the latent space +particleof(::Model) = Float64 +dimension(::Model) = 2 + +function transition!!( + rng::Random.AbstractRNG, + step, + model::Model, + particle::AbstractParticl{<:AbstractStateSpaceModel} +) + if step == 1 + ... # Sample from the initial density + end + ... # Sample from the transition density +end + +function emission_logdensity(step, model::Model, particle::AbstractParticle) + ... # Return log density of the model at *time* `step` +end + +isdone(step, model::Model, particle::AbstractParticle) = ... # Stops the state machine + +# Optionally, if the transition density is known, the model can also specify it +function transition_logdensity(step, prev_particle::AbstractParticle, next_particle::AbstractParticle) + ... # Scores the forward transition at `x` +end +``` + +Package users can then consume the model `logdensity` through calls to `emission_logdensity`. + +For example, a bootstrap filter targeting the filtering distribution ``p(x_t | y_{0:t})`` using `N` particles would roughly follow: +```julia +struct Particle{T<:AbstractStateSpaceModel} <: AbstractParticle{T} end + +while !all(map(particle -> isdone(t, model, particles), particles)): + ancestors = resample(rng, logweigths) + particles = particles[ancestors] + for i in 1:N + particles[i] = transition!!(rng, t, model, particles[i]) + logweights[i] += emission_logdensity(t, model, particles[i]) + end +end +``` + +### Interface ```@autodocs Modules = [SSMProblems] Order = [:type, :function] diff --git a/examples/smc.jl b/examples/smc.jl index 9945e16..e27d7d6 100644 --- a/examples/smc.jl +++ b/examples/smc.jl @@ -4,33 +4,7 @@ using Distributions using Plots using StatsFuns -# Particle Filter implementation -struct Particle{T} # Here we just need a tree - parent::Union{Particle,Nothing} - state::T -end - -Particle(state::T) where {T} = Particle(nothing, state) -Particle() = Particle(nothing, nothing) -Base.show(io::IO, p::Particle) = print(io, "Particle($(p.state))") - -""" - linearize(particle) - -Return the trace of a particle, i.e. the sequence of states from the root to the particle. -""" -function linearize(particle::Particle{T}) where {T} - trace = T[] - parent = particle.parent - while !isnothing(parent) - push!(trace, parent.state) - parent = parent.parent - end - return trace -end - -ParticleContainer = AbstractVector{<:Particle} - +# Particle Filter ess(weights) = inv(sum(abs2, weights)) get_weights(logweights::T) where {T<:AbstractVector{<:Real}} = StatsFuns.softmax(logweights) @@ -40,29 +14,34 @@ function systematic_resampling( return rand(rng, Distributions.Categorical(weights), n) end -function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, threshold=0.5) - t = 1 +function sweep!( + rng::AbstractRNG, + model::AbstractStateSpaceModel, + particles::SSMProblems.Utils.ParticleContainer, + observations::AbstractArray, + resampling=systematic_resampling, + threshold=0.5, +) N = length(particles) - logweights = zeros(length(particles)) - while !isdone(t, particles[1].state) + logweights = zeros(N) + for (timestep, observation) in enumerate(observations) # Resample step weights = get_weights(logweights) if ess(weights) <= threshold * N idx = resampling(rng, weights) particles = particles[idx] - logweights = zeros(length(particles)) + fill!(logweights, 0) end # Mutation step for i in eachindex(particles) - parent = particles[i] - mutated = transition!!(rng, t, parent.state) - particles[i] = Particle(parent, mutated) - logweights[i] += emission_logdensity(t, particles[i].state) + latent_state = transition!!(rng, model, particles[i].state, timestep) + particles[i] = SSMProblems.Utils.Particle(particles[i], latent_state) + logweights[i] += emission_logdensity( + model, particles[i].state, observation, timestep + ) end - - t += 1 end # Return unweighted set @@ -70,13 +49,25 @@ function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, thre return particles[idx] end -function sweep!(rng::AbstractRNG, n::Int, resampling, threshold=0.5) - particles = [Particle(0.0) for _ in 1:n] - return sweep!(rng, particles, resampling, threshold) +# Turing style sample method +function sample( + rng::AbstractRNG, + model::AbstractStateSpaceModel, + n::Int, + observations::AbstractVector; + resampling=systematic_resampling, + threshold=0.5, +) + particles = map(1:N) do i + state = transition!!(rng, model) + SSMProblems.Utils.Particle(state) + end + samples = sweep!(rng, model, particles, observations, resampling, threshold) + return samples end # Inference code -Base.@kwdef struct Parameters +Base.@kwdef struct LinearSSM <: AbstractStateSpaceModel v::Float64 = 0.2 # Transition noise stdev u::Float64 = 0.7 # Observation noise stdev end @@ -84,35 +75,42 @@ end # Simulation T = 250 seed = 1 -N = 1000 +N = 1_000 rng = MersenneTwister(seed) -params = Parameters(; v=0.2, u=0.7) -function transition!!(rng::AbstractRNG, t::Int, state=nothing) - if isnothing(state) - return rand(rng, Normal(0, 1)) - end - return rand(rng, Normal(state, params.v)) -end +model = LinearSSM(0.2, 0.7) -function emission_logdensity(t, state) - return logpdf(Normal(state, params.u), observations[t]) -end - -isdone(t, state) = t > T -isdone(t, ::Nothing) = false +f0(::LinearSSM) = Normal(0, 1) +f(x::Float64, model::LinearSSM) = Normal(x, model.v) +g(y::Float64, model::LinearSSM) = Normal(y, model.u) +# Generate synthtetic data x, observations = zeros(T), zeros(T) -x[1] = rand(rng, Normal(0, 1)) +x[1] = rand(rng, f0(model)) for t in 1:T - observations[t] = rand(rng, Normal(x[t], params.u)) + observations[t] = rand(rng, g(x[t], model)) if t < T - x[t + 1] = rand(rng, Normal(x[t], params.v)) + x[t + 1] = rand(rng, f(x[t], model)) end end -samples = sweep!(rng, fill(Particle(x[1]), N), systematic_resampling) -traces = reverse(hcat(map(linearize, samples)...)) +# Model dynamics +function transition!!(rng::AbstractRNG, model::LinearSSM) + return rand(rng, f0(model)) +end + +function transition!!(rng::AbstractRNG, model::LinearSSM, state::Float64, ::Int) + return rand(rng, f(state, model)) +end + +function emission_logdensity(model::LinearSSM, state::Float64, observation::Float64, ::Int) + return logpdf(g(state, model), observation) +end + +# Sample latent state trajectories +samples = sample(rng, LinearSSM(), N, observations) +traces = reverse(hcat(map(SSMProblems.Utils.linearize, samples)...)) scatter(traces; color=:black, opacity=0.3, label=false) plot!(x; label="True state") +plot!(mean(traces; dims=2); label="Posterior mean") diff --git a/src/SSMProblems.jl b/src/SSMProblems.jl index fd91a4b..b3f1cac 100644 --- a/src/SSMProblems.jl +++ b/src/SSMProblems.jl @@ -4,38 +4,35 @@ A unified interface to define State Space Models interfaces in the context of Pa module SSMProblems """ + AbstractStateSpaceModel """ -abstract type AbstractParticle end +abstract type AbstractStateSpaceModel end abstract type AbstractParticleCache end """ - transition!!(rng, step, particle[, cache]) + transition!!(rng, model[, state, timestep, cache]) Simulate the particle for the next time step from the forward dynamics. """ function transition!! end """ - transition_logdensity(step, particle, x[, cache]) + transition_logdensity(model, prev_state, current_state[, timestep, cache]) (Optional) Computes the log-density of the forward transition if the density is available. """ function transition_logdensity end """ - emission_logdensity(step, particle[, cache]) + emission_logdensity(model, state, observation[, timestep, cache]) -Compute the log potential of current particle. This effectively "reweight" each particle. +Compute the log potential of the current particle. This effectively "reweight" each particle. """ function emission_logdensity end -""" - isdone(step, particle[, cache]) - -Determine whether we have reached the last time step of the Markov process. Return `true` if yes, otherwise return `false`. -""" -function isdone end +# Include utils and adjacent code +include("utils/particles.jl") -export transition!!, transition_logdensity, emission_logdensity, isdone, AbstractParticle +export AbstractStateSpaceModel end diff --git a/src/utils/particles.jl b/src/utils/particles.jl new file mode 100644 index 0000000..15ecb5d --- /dev/null +++ b/src/utils/particles.jl @@ -0,0 +1,43 @@ +""" + Common concrete implementations of Particle types for Particle Filter kernels. +""" +module Utils + +abstract type Node{T} end + +struct Root{T} <: Node{T} end +Root(T) = Root{T}() +Root() = Root(Any) + +""" + Particle{T} + +Particle as immutable LinkedList. +""" +struct Particle{T} <: Node{T} + parent::Node{T} + state::T +end + +Particle(state::T) where {T} = Particle(Root(T), state) + +Base.show(io::IO, p::Particle{T}) where {T} = print(io, "Particle{$T}($(p.state))") + +const ParticleContainer{T} = AbstractVector{<:Particle{T}} + +""" + linearize(particle) + +Return the trace of a particle, i.e. the sequence of states from the root to the particle. +""" +function linearize(particle::Particle{T}) where {T} + trace = T[] + current = particle + while !isa(current, Root) + push!(trace, current.state) + current = current.parent + end + return trace +end + +end