Skip to content

Commit

Permalink
Merge 593d305 into bfcf29c
Browse files Browse the repository at this point in the history
  • Loading branch information
FredericWantiez authored Sep 7, 2023
2 parents bfcf29c + 593d305 commit 14a0106
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 59 deletions.
Binary file added docs/src/images/state_space_model.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
88 changes: 87 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,92 @@
# SSMProblems

### API
### Installation
In the `julia` REPL:
```julia
]add SSMProblems
```

### Documentation

`SSMProblems` defines a generic interface for State Space Problems (SSM). The main objective is to provide a consistent
interface to work with SSMs and their logdensities.

Consider a markovian model from[^Murray]:
![state space model](./docs/images/state_space_model.png)

[^Murray]:
> Murray, Lawrence & Lee, Anthony & Jacob, Pierre. (2013). Rethinking resampling in the particle filter on graphics processing units.
The model is fully specified by the following densities:
- __Initialisation__: ``f_0(x)``
- __Transition__: ``f(x)``
- __Emission__: ``g(x)``

And the dynamics of the model reduces to:
```math
x_t | x_{t-1} \sim f(x_t | x_{t-1})
```
```math
y_t | x_t \sim g(y_t | x_{t})
```
assuming ``x_0 \sim f_0(x)``.

The joint law follows:

```math
p(x_{0:T}, y_{0:T}) = f_0(x_0) \prod_t g(y_t | x_t) f(x_t | x_{t-1})
```

Model users can define their `SSM` using the following interface:
```julia

struct Model <: AbstractStateSpaceModel end

# Define the structure of the latent space
particleof(::Model) = Float64
dimension(::Model) = 2

function transition!!(
rng::Random.AbstractRNG,
step,
model::Model,
particle::AbstractParticl{<:AbstractStateSpaceModel}
)
if step == 1
... # Sample from the initial density
end
... # Sample from the transition density
end

function emission_logdensity(step, model::Model, particle::AbstractParticle)
... # Return log density of the model at *time* `step`
end

isdone(step, model::Model, particle::AbstractParticle) = ... # Stops the state machine

# Optionally, if the transition density is known, the model can also specify it
function transition_logdensity(step, prev_particle::AbstractParticle, next_particle::AbstractParticle)
... # Scores the forward transition at `x`
end
```

Package users can then consume the model `logdensity` through calls to `emission_logdensity`.

For example, a bootstrap filter targeting the filtering distribution ``p(x_t | y_{0:t})`` using `N` particles would roughly follow:
```julia
struct Particle{T<:AbstractStateSpaceModel} <: AbstractParticle{T} end

while !all(map(particle -> isdone(t, model, particles), particles)):
ancestors = resample(rng, logweigths)
particles = particles[ancestors]
for i in 1:N
particles[i] = transition!!(rng, t, model, particles[i])
logweights[i] += emission_logdensity(t, model, particles[i])
end
end
```

### Interface
```@autodocs
Modules = [SSMProblems]
Order = [:type, :function]
Expand Down
122 changes: 77 additions & 45 deletions examples/smc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,23 @@ using Distributions
using Plots
using StatsFuns

# Particle Filter implementation
struct Particle{T} # Here we just need a tree
parent::Union{Particle,Nothing}
# Particle Filter
abstract type Node{T} end

struct Root{T} <: Node{T} end
Root(T) = Root{T}()
Root() = Root(Any)

struct Particle{T} <: Node{T}
parent::Node{T}
state::T
end

Particle(state::T) where {T} = Particle(nothing, state)
Particle() = Particle(nothing, nothing)
Base.show(io::IO, p::Particle) = print(io, "Particle($(p.state))")
Particle(state::T) where {T} = Particle(Root(T), state)

Base.show(io::IO, p::Particle{T}) where {T} = print(io, "Particle{$T}($(p.state))")

const ParticleContainer{T} = AbstractVector{<:Particle{T}}

"""
linearize(particle)
Expand All @@ -21,16 +29,14 @@ Return the trace of a particle, i.e. the sequence of states from the root to the
"""
function linearize(particle::Particle{T}) where {T}
trace = T[]
parent = particle.parent
while !isnothing(parent)
push!(trace, parent.state)
parent = parent.parent
current = particle
while !isa(current, Root)
push!(trace, current.state)
current = current.parent
end
return trace
end

ParticleContainer = AbstractVector{<:Particle}

ess(weights) = inv(sum(abs2, weights))
get_weights(logweights::T) where {T<:AbstractVector{<:Real}} = StatsFuns.softmax(logweights)

Expand All @@ -40,79 +46,105 @@ function systematic_resampling(
return rand(rng, Distributions.Categorical(weights), n)
end

function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, threshold=0.5)
t = 1
function sweep!(
rng::AbstractRNG,
model::AbstractStateSpaceModel,
particles::ParticleContainer,
observations::AbstractArray,
resampling=systematic_resampling,
threshold=0.5,
)
N = length(particles)
logweights = zeros(length(particles))
while !isdone(t, particles[1].state)
logweights = zeros(N)

for (timestep, observation) in enumerate(observations)
# Resample step
weights = get_weights(logweights)
if ess(weights) <= threshold * N
idx = resampling(rng, weights)
particles = particles[idx]
logweights = zeros(length(particles))
fill!(logweights, 0)
end

# Mutation step
for i in eachindex(particles)
parent = particles[i]
mutated = transition!!(rng, t, parent.state)
particles[i] = Particle(parent, mutated)
logweights[i] += emission_logdensity(t, particles[i].state)
latent_state = transition!!(rng, model, timestep, particles[i].state)
particles[i] = Particle(particles[i], latent_state)
logweights[i] += emission_logdensity(
model, timestep, particles[i].state, observation
)
end

t += 1
end

# Return unweighted set
idx = resampling(rng, get_weights(logweights))
return particles[idx]
end

function sweep!(rng::AbstractRNG, n::Int, resampling, threshold=0.5)
particles = [Particle(0.0) for _ in 1:n]
return sweep!(rng, particles, resampling, threshold)
# Turing style sample method
function sample(
rng::AbstractRNG,
model::AbstractStateSpaceModel,
n::Int,
observations::AbstractVector;
resampling=systematic_resampling,
threshold=0.5,
)
particles = map(1:N) do i
state = transition!!(rng, model)
Particle(state)
end
samples = sweep!(rng, model, particles, observations, resampling, threshold)
return samples
end

# Inference code
Base.@kwdef struct Parameters
Base.@kwdef struct LinearSSM <: AbstractStateSpaceModel
v::Float64 = 0.2 # Transition noise stdev
u::Float64 = 0.7 # Observation noise stdev
end

# Simulation
T = 250
seed = 1
N = 1000
N = 1_000
rng = MersenneTwister(seed)
params = Parameters(; v=0.2, u=0.7)

function transition!!(rng::AbstractRNG, t::Int, state=nothing)
if isnothing(state)
return rand(rng, Normal(0, 1))
end
return rand(rng, Normal(state, params.v))
end

function emission_logdensity(t, state)
return logpdf(Normal(state, params.u), observations[t])
end
model = LinearSSM(0.2, 0.7)

isdone(t, state) = t > T
isdone(t, ::Nothing) = false
f0() = Normal(0, 1)
f(t::Int, x::Float64) = Normal(x, model.v)
g(t::Int, y::Float64) = Normal(y, model.u)

# Generate synthtetic data
x, observations = zeros(T), zeros(T)
x[1] = rand(rng, Normal(0, 1))
x[1] = rand(rng, f0())
for t in 1:T
observations[t] = rand(rng, Normal(x[t], params.u))
observations[t] = rand(rng, g(t, x[t]))
if t < T
x[t + 1] = rand(rng, Normal(x[t], params.v))
x[t + 1] = rand(rng, f(t, x[t]))
end
end

samples = sweep!(rng, fill(Particle(x[1]), N), systematic_resampling)
# Model dynamics
function transition!!(rng::AbstractRNG, model::LinearSSM)
return rand(rng, f0())
end

function transition!!(rng::AbstractRNG, model::LinearSSM, timestep::Int, state::Float64)
return rand(rng, f(timestep, state))
end

function emission_logdensity(
model::LinearSSM, timestep::Int, state::Float64, observation::Float64
)
return logpdf(g(timestep, state), observation)
end

# Sample latent state trajectories
samples = sample(rng, LinearSSM(), N, observations)
traces = reverse(hcat(map(linearize, samples)...))

scatter(traces; color=:black, opacity=0.3, label=false)
plot!(x; label="True state")
plot!(mean(traces; dims=2); label="Posterior mean")
20 changes: 7 additions & 13 deletions src/SSMProblems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,32 @@ A unified interface to define State Space Models interfaces in the context of Pa
module SSMProblems

"""
AbstractStateSpaceModel
"""
abstract type AbstractParticle end
abstract type AbstractStateSpaceModel end
abstract type AbstractParticleCache end

"""
transition!!(rng, step, particle[, cache])
transition!!(rng, model[, timestep, state, cache])
Simulate the particle for the next time step from the forward dynamics.
"""
function transition!! end

"""
transition_logdensity(step, particle, x[, cache])
transition_logdensity(model, timestep, prev_state, next_state[, cache])
(Optional) Computes the log-density of the forward transition if the density is available.
"""
function transition_logdensity end

"""
emission_logdensity(step, particle[, cache])
emission_logdensity(model, timestep, state, observation[, cache])
Compute the log potential of current particle. This effectively "reweight" each particle.
Compute the log potential of the current particle. This effectively "reweight" each particle.
"""
function emission_logdensity end

"""
isdone(step, particle[, cache])
Determine whether we have reached the last time step of the Markov process. Return `true` if yes, otherwise return `false`.
"""
function isdone end

export transition!!, transition_logdensity, emission_logdensity, isdone, AbstractParticle
export AbstractStateSpaceModel

end

0 comments on commit 14a0106

Please sign in to comment.