Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update documentation #16

Merged
merged 18 commits into from
Sep 13, 2023
Binary file added docs/src/images/state_space_model.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
87 changes: 86 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,91 @@
# SSMProblems

### API
### Installation
In the `julia` REPL:
```julia
]add SSMProblems
```

### Documentation

`SSMProblems` defines a generic interface for State Space Problems (SSM). The main objective is to provide a consistent
interface to work with SSMs and their logdensities.

Consider a markovian model from[^Murray]:
yebai marked this conversation as resolved.
Show resolved Hide resolved
![state space model](./docs/images/state_space_model.png)

[^Murray]:
> Murray, Lawrence & Lee, Anthony & Jacob, Pierre. (2013). Rethinking resampling in the particle filter on graphics processing units.

The model is fully specified by the following densities:
- __Initialisation__: ``f_0(x)``
- __Transition__: ``f(x)``
- __Emission__: ``g(x)``
Comment on lines +21 to +23
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it can be a bit confusing here to use x which otherwise only denotes (internal) states. Maybe just define it in/via the model equations below and remove this list?


And the dynamics of the model reduces to:
```math
\begin{aligned}
x_t | x_{t-1} &\sim f(x_t | x_{t-1}) \\
y_t | x_t &\sim g(y_t | x_{t})
\end{aligned}
```
assuming ``x_0 \sim f_0(x)``.

The joint law follows:

```math
p(x_{0:T}, y_{0:T}) = f_0(x_0) \prod_t g(y_t | x_t) f(x_t | x_{t-1})
```

Users can define their SSM with `SSMProblems` in the following way:
```julia
struct Model <: AbstractStateSpaceModel end

# Define the structure of the latent space
particleof(::Model) = Float64
dimension(::Model) = 2
Comment on lines +45 to +46
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO these names are quite confusing. It was not clear to me at all that particleof was referring to some type (which type? of the states I assume?) and that dimension only refers to the internal states and does not consider y.


function transition!!(
rng::Random.AbstractRNG,
step,
model::Model,
particle::AbstractParticl{<:AbstractStateSpaceModel}
)
if step == 1
... # Sample from the initial density
end
... # Sample from the transition density
Comment on lines +54 to +57
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if every user has to write this check for every model, this indicates that the API could be better designed. Maybe the easiest would be to have different functions for whether its the initial transition or not?

Or maybe the example just incorrectly suggestes that the code always takes this structure?

end

function emission_logdensity(step, model::Model, particle::AbstractParticle)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function emission_logdensity(step, model::Model, particle::AbstractParticle)
function emission_logdensity(step, model::Model, particle::AbstractParticle, observations)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again I think this needs more explanations - what does this function do? What are the arguments? Should it return something or modify something?

... # Return log density of the model at *time* `step`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean with "log density of the model"? Just g(y_t | x_t) I assume based on the function name?

end

isdone(step, model::Model, particle::AbstractParticle) = ... # Stops the state machine
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is a state machine?

More conceptually/design-wise, why do we have to define such a function? The Markovian model above works just fine without any such restrictions - wouldn't it be more natural to just iterate as long as there are observations?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will consider removing isdone function, and get such information from the size of the observation array.


# Optionally, if the transition density is known, the model can also specify it
function transition_logdensity(step, prev_particle::AbstractParticle, next_particle::AbstractParticle)
... # Scores the forward transition at `x`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean here?

end
```

Package users can then consume the model `logdensity` through calls to `emission_logdensity`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, I'm not sure about the term "model logdensity" here due to the multitude of different log densities that could be considered in this setting.


For example, a bootstrap filter targeting the filtering distribution ``p(x_t | y_{0:t})`` using `N` particles would roughly follow:
```julia
struct Particle{T<:AbstractStateSpaceModel} <: AbstractParticle{T} end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This refers back to a question above - what is the interface and interpretation of AbstractParticle? And why are particles coupled with the model? Intuitively, to me they seem to be a separate entity.


while !all(map(particle -> isdone(t, model, particles), particles)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume this should be

Suggested change
while !all(map(particle -> isdone(t, model, particles), particles)):
while !all(map(particle -> isdone(t, model, particle), particles)):

Conceptually, again I think the size of this loop should be governed rather by the observations than the model or particles.

ancestors = resample(rng, logweigths)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to reset all log weights to zero after the resampling step.

Suggested change
ancestors = resample(rng, logweigths)
ancestors = resample(rng, logweigths)
logweights = zero(logweights)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is resample, also a part of the interface? I think more explanations would make this example much clearer.

particles = particles[ancestors]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shadowing the variable makes it more difficult to follow the code, I assume. Maybe just save the particles at each time point for illustration purposes to make the sequential nature and the time-dependency of the particles clearer.

for i in 1:N
particles[i] = transition!!(rng, t, model, particles[i])
logweights[i] += emission_logdensity(t, model, particles[i])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't/shouldn't you just set

Suggested change
logweights[i] += emission_logdensity(t, model, particles[i])
logweights[i] = emission_logdensity(t, model, particles[i])

end
end
```

### Interface
```@autodocs
Modules = [SSMProblems]
Order = [:type, :function]
Expand Down
120 changes: 59 additions & 61 deletions examples/smc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,7 @@ using Distributions
using Plots
using StatsFuns

# Particle Filter implementation
struct Particle{T} # Here we just need a tree
parent::Union{Particle,Nothing}
state::T
end

Particle(state::T) where {T} = Particle(nothing, state)
Particle() = Particle(nothing, nothing)
Base.show(io::IO, p::Particle) = print(io, "Particle($(p.state))")

"""
linearize(particle)

Return the trace of a particle, i.e. the sequence of states from the root to the particle.
"""
function linearize(particle::Particle{T}) where {T}
trace = T[]
parent = particle.parent
while !isnothing(parent)
push!(trace, parent.state)
parent = parent.parent
end
return trace
end

ParticleContainer = AbstractVector{<:Particle}

# Particle Filter
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to keep this code block in the module?

ess(weights) = inv(sum(abs2, weights))
get_weights(logweights::T) where {T<:AbstractVector{<:Real}} = StatsFuns.softmax(logweights)

Expand All @@ -40,79 +14,103 @@ function systematic_resampling(
return rand(rng, Distributions.Categorical(weights), n)
end

function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, threshold=0.5)
t = 1
function sweep!(
rng::AbstractRNG,
model::AbstractStateSpaceModel,
particles::SSMProblems.Utils.ParticleContainer,
observations::AbstractArray,
resampling=systematic_resampling,
threshold=0.5,
)
N = length(particles)
logweights = zeros(length(particles))
while !isdone(t, particles[1].state)
logweights = zeros(N)

for (timestep, observation) in enumerate(observations)
# Resample step
weights = get_weights(logweights)
if ess(weights) <= threshold * N
idx = resampling(rng, weights)
particles = particles[idx]
logweights = zeros(length(particles))
fill!(logweights, 0)
end

# Mutation step
for i in eachindex(particles)
parent = particles[i]
mutated = transition!!(rng, t, parent.state)
particles[i] = Particle(parent, mutated)
logweights[i] += emission_logdensity(t, particles[i].state)
latent_state = transition!!(rng, model, particles[i].state, timestep)
particles[i] = SSMProblems.Utils.Particle(particles[i], latent_state)
logweights[i] += emission_logdensity(
model, particles[i].state, observation, timestep
)
end

t += 1
end

# Return unweighted set
idx = resampling(rng, get_weights(logweights))
return particles[idx]
end

function sweep!(rng::AbstractRNG, n::Int, resampling, threshold=0.5)
particles = [Particle(0.0) for _ in 1:n]
return sweep!(rng, particles, resampling, threshold)
# Turing style sample method
function sample(
rng::AbstractRNG,
model::AbstractStateSpaceModel,
n::Int,
observations::AbstractVector;
resampling=systematic_resampling,
threshold=0.5,
)
particles = map(1:N) do i
state = transition!!(rng, model)
SSMProblems.Utils.Particle(state)
end
samples = sweep!(rng, model, particles, observations, resampling, threshold)
return samples
end

# Inference code
Base.@kwdef struct Parameters
Base.@kwdef struct LinearSSM <: AbstractStateSpaceModel
v::Float64 = 0.2 # Transition noise stdev
u::Float64 = 0.7 # Observation noise stdev
end

# Simulation
T = 250
seed = 1
N = 1000
N = 1_000
rng = MersenneTwister(seed)
params = Parameters(; v=0.2, u=0.7)

function transition!!(rng::AbstractRNG, t::Int, state=nothing)
if isnothing(state)
return rand(rng, Normal(0, 1))
end
return rand(rng, Normal(state, params.v))
end
model = LinearSSM(0.2, 0.7)

function emission_logdensity(t, state)
return logpdf(Normal(state, params.u), observations[t])
end

isdone(t, state) = t > T
isdone(t, ::Nothing) = false
f0(::LinearSSM) = Normal(0, 1)
f(x::Float64, model::LinearSSM) = Normal(x, model.v)
g(y::Float64, model::LinearSSM) = Normal(y, model.u)

# Generate synthtetic data
x, observations = zeros(T), zeros(T)
x[1] = rand(rng, Normal(0, 1))
x[1] = rand(rng, f0(model))
for t in 1:T
observations[t] = rand(rng, Normal(x[t], params.u))
observations[t] = rand(rng, g(x[t], model))
if t < T
x[t + 1] = rand(rng, Normal(x[t], params.v))
x[t + 1] = rand(rng, f(x[t], model))
end
end

samples = sweep!(rng, fill(Particle(x[1]), N), systematic_resampling)
traces = reverse(hcat(map(linearize, samples)...))
# Model dynamics
function transition!!(rng::AbstractRNG, model::LinearSSM)
return rand(rng, f0(model))
end

function transition!!(rng::AbstractRNG, model::LinearSSM, state::Float64, ::Int)
return rand(rng, f(state, model))
end

function emission_logdensity(model::LinearSSM, state::Float64, observation::Float64, ::Int)
return logpdf(g(state, model), observation)
end

# Sample latent state trajectories
samples = sample(rng, LinearSSM(), N, observations)
traces = reverse(hcat(map(SSMProblems.Utils.linearize, samples)...))

scatter(traces; color=:black, opacity=0.3, label=false)
plot!(x; label="True state")
plot!(mean(traces; dims=2); label="Posterior mean")
21 changes: 9 additions & 12 deletions src/SSMProblems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,35 @@ A unified interface to define State Space Models interfaces in the context of Pa
module SSMProblems

"""
AbstractStateSpaceModel
"""
abstract type AbstractParticle end
abstract type AbstractStateSpaceModel end
abstract type AbstractParticleCache end

"""
transition!!(rng, step, particle[, cache])
transition!!(rng, model[, state, timestep, cache])

Simulate the particle for the next time step from the forward dynamics.
"""
function transition!! end

"""
transition_logdensity(step, particle, x[, cache])
transition_logdensity(model, prev_state, current_state[, timestep, cache])

(Optional) Computes the log-density of the forward transition if the density is available.
"""
function transition_logdensity end

"""
emission_logdensity(step, particle[, cache])
emission_logdensity(model, state, observation[, timestep, cache])

Compute the log potential of current particle. This effectively "reweight" each particle.
Compute the log potential of the current particle. This effectively "reweight" each particle.
"""
function emission_logdensity end

"""
isdone(step, particle[, cache])

Determine whether we have reached the last time step of the Markov process. Return `true` if yes, otherwise return `false`.
"""
function isdone end
# Include utils and adjacent code
include("utils/particles.jl")

export transition!!, transition_logdensity, emission_logdensity, isdone, AbstractParticle
export AbstractStateSpaceModel

end
43 changes: 43 additions & 0 deletions src/utils/particles.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""
Common concrete implementations of Particle types for Particle Filter kernels.
"""
module Utils

abstract type Node{T} end

struct Root{T} <: Node{T} end
Root(T) = Root{T}()
Root() = Root(Any)

"""
Particle{T}

Particle as immutable LinkedList.
"""
struct Particle{T} <: Node{T}
parent::Node{T}
state::T
end

Particle(state::T) where {T} = Particle(Root(T), state)

Base.show(io::IO, p::Particle{T}) where {T} = print(io, "Particle{$T}($(p.state))")

const ParticleContainer{T} = AbstractVector{<:Particle{T}}

"""
linearize(particle)

Return the trace of a particle, i.e. the sequence of states from the root to the particle.
"""
function linearize(particle::Particle{T}) where {T}
trace = T[]
current = particle
while !isa(current, Root)
push!(trace, current.state)
current = current.parent
end
return trace
end

end
Loading