Skip to content

Commit

Permalink
Merge d603f56 into bfcf29c
Browse files Browse the repository at this point in the history
  • Loading branch information
FredericWantiez authored Jul 29, 2023
2 parents bfcf29c + d603f56 commit e5e6194
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 48 deletions.
Binary file added docs/src/images/state_space_model.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
76 changes: 75 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,80 @@
# SSMProblems

### API
### Installation
In the `julia` REPL:
```julia
]add SSMProblems
```

### Documentation

The package defines a generic interface to work with State Space Problems (SSM). The main objective is to provide a consistent
interface to work with SSMs and their logdensities.

Consider a markovian model from[^Murray]:
![state space model](docs/images/state_space_model.png)

[^Murray]:
> Murray, Lawrence & Lee, Anthony & Jacob, Pierre. (2013). Rethinking resampling in the particle filter on graphics processing
units.

The model is fully specified by the following densities:
- __Initialisation__: ``f_0(x)``
- __Transition__: ``f(x)``
- __Emission__: ``g(x)``

And the dynamics of the model reduces to:
```math
x_t | x_{t-1} \sim f(x_t | x_{t-1})
```
```math
y_t | x_t \sim g(y_t | x_{t})
```
assuming ``x_0 \sim f_0(x)``. The joint law is then fully described:

```math
p(x_{0:T}, y_{0:T}) = f_0{x_0} \prod_t g(y_t | x_t) f(x_t | x_{t-1})
```

Model users can define their `SSM` using the following interface:
```julia

struct Model <: AbstractParticle end

function transition!!(rng, step, model::Model)
if step == 1
... # Sample from the initial density
end
... # Sample from the transition density
end

function emission_logdensity(step, model::Model)
... # Return log density of the model at
end

isdone(step, model::Model) = ... # Define the stopping criterion

# Optionally, if the transition density is known, the model can also specify it
function transition_logdensity(step, particle, x)
... # Scores the forward transition at `x`
end
```

Package users can then consume the model `logdensity` through calls to `emission_logdensity`.

For example, a bootstrap filter targeting the filtering distribution ``p(x_t | y_{0:t})`` using `N` particles would read:
```julia
while !isdone(t, model)
ancestors = resample(rng, logweigths)
particles = particles[ancestors]
for i in 1:N
particles[i] = transition!!(rng, t, particles[i])
logweights[i] += emission_logdensity(t, particles[i])
end
end
```

### Interface
```@autodocs
Modules = [SSMProblems]
Order = [:type, :function]
Expand Down
125 changes: 84 additions & 41 deletions examples/smc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,43 @@ using Plots
using StatsFuns

# Particle Filter implementation
struct Particle{T} # Here we just need a tree
struct Particle{T<:AbstractStateSpaceModel,V} <: AbstractParticle{T}
parent::Union{Particle,Nothing}
state::T
state::V
end

Particle(state::T) where {T} = Particle(nothing, state)
Particle() = Particle(nothing, nothing)
function Particle(parent, state, model::T) where {T<:AbstractStateSpaceModel}
return Particle{T,particleof(model)}(parent, state)
end
function Particle(model::T) where {T<:AbstractStateSpaceModel}
N = dimension(model)
V = particleof(model)
state = N == 1 ? zero(V) : zeros(V, N)
return Particle{T,V}(nothing, state)
end
Base.show(io::IO, p::Particle) = print(io, "Particle($(p.state))")

"""
linearize(particle)
Return the trace of a particle, i.e. the sequence of states from the root to the particle.
"""
function linearize(particle::Particle{T}) where {T}
trace = T[]
parent = particle.parent
while !isnothing(parent)
push!(trace, parent.state)
parent = parent.parent
function linearize(particle::Particle{T,V}) where {T,V}
trace = V[]
current = particle
while !isnothing(current)
push!(trace, current.state)
current = current.parent
end
return trace
return trace[1:(end - 1)]
end

ParticleContainer = AbstractVector{<:Particle}
const ParticleContainer{T} = AbstractVector{<:Particle{T}}

# Specialize `isdone` to the concrete `Particle` type
function isdone(t, model::AbstractStateSpaceModel, particles::ParticleContainer)
return all(map(particle -> isdone(t, model, particle), particles))
end

ess(weights) = inv(sum(abs2, weights))
get_weights(logweights::T) where {T<:AbstractVector{<:Real}} = StatsFuns.softmax(logweights)
Expand All @@ -40,11 +52,17 @@ function systematic_resampling(
return rand(rng, Distributions.Categorical(weights), n)
end

function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, threshold=0.5)
function sweep!(
rng::AbstractRNG,
model::AbstractStateSpaceModel,
particles::ParticleContainer,
resampling,
threshold=0.5,
)
t = 1
N = length(particles)
logweights = zeros(length(particles))
while !isdone(t, particles[1].state)
while !isdone(t, model, particles)

# Resample step
weights = get_weights(logweights)
Expand All @@ -57,9 +75,9 @@ function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, thre
# Mutation step
for i in eachindex(particles)
parent = particles[i]
mutated = transition!!(rng, t, parent.state)
particles[i] = Particle(parent, mutated)
logweights[i] += emission_logdensity(t, particles[i].state)
mutated = transition!!(rng, t, model, parent)
particles[i] = mutated
logweights[i] += emission_logdensity(t, model, particles[i])
end

t += 1
Expand All @@ -70,49 +88,74 @@ function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, thre
return particles[idx]
end

function sweep!(rng::AbstractRNG, n::Int, resampling, threshold=0.5)
particles = [Particle(0.0) for _ in 1:n]
return sweep!(rng, particles, resampling, threshold)
function sample(
rng::AbstractRNG,
n::Int,
model::AbstractStateSpaceModel;
resampling=systematic_resampling,
threshold=0.5,
)
particles = fill(Particle(model), N)
samples = sweep!(rng, model, particles, resampling, threshold)
return samples
end

# Inference code
Base.@kwdef struct Parameters
Base.@kwdef struct LinearSSM <: AbstractStateSpaceModel
v::Float64 = 0.2 # Transition noise stdev
u::Float64 = 0.7 # Observation noise stdev
end

# Simulation
T = 250
seed = 1
N = 1000
N = 1_000
rng = MersenneTwister(seed)
params = Parameters(; v=0.2, u=0.7)

function transition!!(rng::AbstractRNG, t::Int, state=nothing)
if isnothing(state)
return rand(rng, Normal(0, 1))
end
return rand(rng, Normal(state, params.v))
end

function emission_logdensity(t, state)
return logpdf(Normal(state, params.u), observations[t])
end
model = LinearSSM(0.2, 0.7)

isdone(t, state) = t > T
isdone(t, ::Nothing) = false
f0(t) = Normal(0, 1)
f(t, x) = Normal(x, model.v)
g(t, x) = Normal(x, model.u)

# Generate synthtetic data
x, observations = zeros(T), zeros(T)
x[1] = rand(rng, Normal(0, 1))
x[1] = rand(rng, f0(1))
for t in 1:T
observations[t] = rand(rng, Normal(x[t], params.u))
observations[t] = rand(rng, g(t, x[t]))
if t < T
x[t + 1] = rand(rng, Normal(x[t], params.v))
x[t + 1] = rand(rng, f(t, x[t]))
end
end

samples = sweep!(rng, fill(Particle(x[1]), N), systematic_resampling)
function transition!!(
rng::AbstractRNG, t::Int, model::LinearSSM, particle::AbstractParticle
)
if t == 1
return Particle(particle, rand(rng, f0(t)), model)
else
return Particle(particle, rand(rng, f(t, particle.state)), model)
end
end

function emission_logdensity(t, model::LinearSSM, particle::AbstractParticle)
return logpdf(g(t, particle.state), observations[t])
end

# isdone
isdone(t, ::LinearSSM, ::AbstractParticle) = t > T

# Type of latent space
# particleof(::S) :: S -> T
# f(t, x) : Int -> T -> T
particleof(::LinearSSM) = Float64
dimension(::LinearSSM) = 1

samples = sample(rng, N, LinearSSM())
traces = reverse(hcat(map(linearize, samples)...))

scatter(traces; color=:black, opacity=0.3, label=false)
plot!(x; label="True state")
#scatter(traces; color=:black, opacity=0.3, label=false)
plot(x; label="True state")
plot!(mean(traces; dims=2); label="Posterior mean")

gui()
25 changes: 19 additions & 6 deletions src/SSMProblems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,50 @@ module SSMProblems

"""
"""
abstract type AbstractParticle end
abstract type AbstractStateSpaceModel end
abstract type AbstractParticle{T<:AbstractStateSpaceModel} end
abstract type AbstractParticleCache end

"""
transition!!(rng, step, particle[, cache])
transition!!(rng, step, model, particle[, cache])
Simulate the particle for the next time step from the forward dynamics.
"""
function transition!! end

"""
transition_logdensity(step, particle, x[, cache])
transition_logdensity(step, model, particle, x[, cache])
(Optional) Computes the log-density of the forward transition if the density is available.
"""
function transition_logdensity end

"""
emission_logdensity(step, particle[, cache])
emission_logdensity(step, model, particle[, cache])
Compute the log potential of current particle. This effectively "reweight" each particle.
"""
function emission_logdensity end

"""
isdone(step, particle[, cache])
isdone(step, model, particle[, cache])
Determine whether we have reached the last time step of the Markov process. Return `true` if yes, otherwise return `false`.
"""
function isdone end

export transition!!, transition_logdensity, emission_logdensity, isdone, AbstractParticle
"""
particleof(::Type{AbstractStateSpaceModel})
"""
particleof(::Type{AbstractStateSpaceModel}) = Nothing
particleof(model::AbstractStateSpaceModel) = particleof(typeof(model))

export transition!!,
transition_logdensity,
emission_logdensity,
isdone,
AbstractParticle,
AbstractStateSpaceModel,
dimension

end

0 comments on commit e5e6194

Please sign in to comment.