Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update documentation #16

Merged
merged 18 commits into from
Sep 13, 2023
Binary file added docs/src/images/state_space_model.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
76 changes: 75 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,80 @@
# SSMProblems

### API
### Installation
In the `julia` REPL:
```julia
]add SSMProblems
```

### Documentation

The package defines a generic interface to work with State Space Problems (SSM). The main objective is to provide a consistent
interface to work with SSMs and their logdensities.

Consider a markovian model from[^Murray]:
yebai marked this conversation as resolved.
Show resolved Hide resolved
![state space model](docs/images/state_space_model.png)
yebai marked this conversation as resolved.
Show resolved Hide resolved

[^Murray]:
> Murray, Lawrence & Lee, Anthony & Jacob, Pierre. (2013). Rethinking resampling in the particle filter on graphics processing
units.

The model is fully specified by the following densities:
- __Initialisation__: ``f_0(x)``
- __Transition__: ``f(x)``
- __Emission__: ``g(x)``
Comment on lines +21 to +23
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it can be a bit confusing here to use x which otherwise only denotes (internal) states. Maybe just define it in/via the model equations below and remove this list?


And the dynamics of the model reduces to:
```math
x_t | x_{t-1} \sim f(x_t | x_{t-1})
```
```math
y_t | x_t \sim g(y_t | x_{t})
FredericWantiez marked this conversation as resolved.
Show resolved Hide resolved
```
assuming ``x_0 \sim f_0(x)``. The joint law is then fully described:

```math
p(x_{0:T}, y_{0:T}) = f_0{x_0} \prod_t g(y_t | x_t) f(x_t | x_{t-1})
yebai marked this conversation as resolved.
Show resolved Hide resolved
yebai marked this conversation as resolved.
Show resolved Hide resolved
```

Model users can define their `SSM` using the following interface:
FredericWantiez marked this conversation as resolved.
Show resolved Hide resolved
```julia

FredericWantiez marked this conversation as resolved.
Show resolved Hide resolved
struct Model <: AbstractParticle end

function transition!!(rng, step, model::Model)
if step == 1
... # Sample from the initial density
end
... # Sample from the transition density
Comment on lines +54 to +57
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if every user has to write this check for every model, this indicates that the API could be better designed. Maybe the easiest would be to have different functions for whether its the initial transition or not?

Or maybe the example just incorrectly suggestes that the code always takes this structure?

end

function emission_logdensity(step, model::Model)
... # Return log density of the model at
end

isdone(step, model::Model) = ... # Define the stopping criterion

# Optionally, if the transition density is known, the model can also specify it
function transition_logdensity(step, particle, x)
yebai marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again more explanations are needed IMO - what are the arguments and their types?

... # Scores the forward transition at `x`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean here?

end
```

Package users can then consume the model `logdensity` through calls to `emission_logdensity`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, I'm not sure about the term "model logdensity" here due to the multitude of different log densities that could be considered in this setting.


For example, a bootstrap filter targeting the filtering distribution ``p(x_t | y_{0:t})`` using `N` particles would read:
```julia
while !isdone(t, model)
ancestors = resample(rng, logweigths)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to reset all log weights to zero after the resampling step.

Suggested change
ancestors = resample(rng, logweigths)
ancestors = resample(rng, logweigths)
logweights = zero(logweights)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is resample, also a part of the interface? I think more explanations would make this example much clearer.

particles = particles[ancestors]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shadowing the variable makes it more difficult to follow the code, I assume. Maybe just save the particles at each time point for illustration purposes to make the sequential nature and the time-dependency of the particles clearer.

for i in 1:N
particles[i] = transition!!(rng, t, particles[i])
logweights[i] += emission_logdensity(t, particles[i])
end
end
```

### Interface
```@autodocs
Modules = [SSMProblems]
Order = [:type, :function]
Expand Down
89 changes: 57 additions & 32 deletions examples/smc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ using Plots
using StatsFuns

# Particle Filter implementation
struct Particle{T} # Here we just need a tree
struct Particle{T} <: AbstractParticle
parent::Union{Particle,Nothing}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The compiler has many tricks and optimizations for this specific type of Union{Nothing,...} but it means that the type of parent can't be inferred.

state::T
model::T
end

Particle(state::T) where {T} = Particle(nothing, state)
Particle(model::T) where {T} = Particle(nothing, model)
Particle() = Particle(nothing, nothing)
Base.show(io::IO, p::Particle) = print(io, "Particle($(p.state))")
Base.show(io::IO, p::Particle) = print(io, "Particle($(p.model))")

"""
linearize(particle)
Expand All @@ -21,16 +21,22 @@ Return the trace of a particle, i.e. the sequence of states from the root to the
"""
function linearize(particle::Particle{T}) where {T}
trace = T[]
parent = particle.parent
parent = particle
while !isnothing(parent)
push!(trace, parent.state)
push!(trace, parent.model)
parent = parent.parent
end
return trace
return trace[1:(end - 1)]
end

ParticleContainer = AbstractVector{<:Particle}

# Specialize `isdone` to the concrete `Particle` type
function isdone(t, particles::AbstractVector{<:Particle})
return all(map(particle -> isdone(t, particle), particles))
end
isdone(t, particle::AbstractParticle) = isdone(t, particle.model)

ess(weights) = inv(sum(abs2, weights))
get_weights(logweights::T) where {T<:AbstractVector{<:Real}} = StatsFuns.softmax(logweights)

Expand All @@ -44,7 +50,7 @@ function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, thre
t = 1
N = length(particles)
logweights = zeros(length(particles))
while !isdone(t, particles[1].state)
while !isdone(t, particles)

# Resample step
weights = get_weights(logweights)
Expand All @@ -57,9 +63,9 @@ function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, thre
# Mutation step
for i in eachindex(particles)
parent = particles[i]
mutated = transition!!(rng, t, parent.state)
mutated = transition!!(rng, t, parent.model)
particles[i] = Particle(parent, mutated)
logweights[i] += emission_logdensity(t, particles[i].state)
logweights[i] += emission_logdensity(t, particles[i].model)
end

t += 1
Expand All @@ -70,9 +76,16 @@ function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, thre
return particles[idx]
end

function sweep!(rng::AbstractRNG, n::Int, resampling, threshold=0.5)
particles = [Particle(0.0) for _ in 1:n]
return sweep!(rng, particles, resampling, threshold)
function sample(
rng::AbstractRNG,
n::Int,
model::AbstractStateSpaceModel;
resampling=systematic_resampling,
threshold=0.5,
)
particles = fill(Particle(model), N)
samples = sweep!(rng, particles, resampling, threshold)
return samples
end

# Inference code
Expand All @@ -81,38 +94,50 @@ Base.@kwdef struct Parameters
u::Float64 = 0.7 # Observation noise stdev
end

struct LinearSSM{T} <: AbstractStateSpaceModel
state::T
end

LinearSSM() = LinearSSM(zero(Float64))
dimension(::LinearSSM) = 1

# Simulation
T = 250
seed = 1
N = 1000
N = 1_000
rng = MersenneTwister(seed)
params = Parameters(; v=0.2, u=0.7)

function transition!!(rng::AbstractRNG, t::Int, state=nothing)
if isnothing(state)
return rand(rng, Normal(0, 1))
end
return rand(rng, Normal(state, params.v))
end

function emission_logdensity(t, state)
return logpdf(Normal(state, params.u), observations[t])
end

isdone(t, state) = t > T
isdone(t, ::Nothing) = false
f0(t) = Normal(0, 1)
f(t, x) = Normal(x, params.v)
g(t, x) = Normal(x, params.u)

# Generate synthtetic data
x, observations = zeros(T), zeros(T)
x[1] = rand(rng, Normal(0, 1))
x[1] = rand(rng, f0(1))
for t in 1:T
observations[t] = rand(rng, Normal(x[t], params.u))
observations[t] = rand(rng, g(t, x[t]))
if t < T
x[t + 1] = rand(rng, Normal(x[t], params.v))
x[t + 1] = rand(rng, f(t, x[t]))
end
end

samples = sweep!(rng, fill(Particle(x[1]), N), systematic_resampling)
traces = reverse(hcat(map(linearize, samples)...))
function transition!!(rng::AbstractRNG, t::Int, model::LinearSSM)
if t == 1
return LinearSSM(rand(rng, f0(t)))
else
return LinearSSM(rand(rng, f(t, model.state)))
end
end

function emission_logdensity(t, model::LinearSSM)
return logpdf(g(t, model.state), observations[t])
end

isdone(t, state::LinearSSM) = t > T

samples = sample(rng, N, LinearSSM())
traces = map(model -> model.state, reverse(hcat(map(linearize, samples)...)))

scatter(traces; color=:black, opacity=0.3, label=false)
plot!(x; label="True state")
27 changes: 21 additions & 6 deletions src/SSMProblems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,52 @@ module SSMProblems

"""
"""
abstract type AbstractParticle end
abstract type AbstractStateSpaceModel end
abstract type AbstractParticle end
yebai marked this conversation as resolved.
Show resolved Hide resolved
abstract type AbstractParticleCache end

"""
transition!!(rng, step, particle[, cache])
transition!!(rng, step, model, particle[, cache])
yebai marked this conversation as resolved.
Show resolved Hide resolved

Simulate the particle for the next time step from the forward dynamics.
"""
function transition!! end

"""
transition_logdensity(step, particle, x[, cache])
transition_logdensity(step, model, particle, x[, cache])

(Optional) Computes the log-density of the forward transition if the density is available.
"""
function transition_logdensity end

"""
emission_logdensity(step, particle[, cache])
emission_logdensity(step, model, particle[, cache])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably make observation an explicit argument.

Suggested change
emission_logdensity(step, model, particle[, cache])
emission_logdensity(step, model, particle, observation [, cache])

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add an optional API for sampling from the emission process, so we can easily generate data from the SSM model.


Compute the log potential of current particle. This effectively "reweight" each particle.
"""
function emission_logdensity end

"""
isdone(step, particle[, cache])
isdone(step, model, particle[, cache])

Determine whether we have reached the last time step of the Markov process. Return `true` if yes, otherwise return `false`.
"""
function isdone end

export transition!!, transition_logdensity, emission_logdensity, isdone, AbstractParticle
"""
dimension(::Type{AbstractStateSpaceModel})

Returns the dimension of the state space for a given model type
"""
dimension(::Type{<:AbstractStateSpaceModel}) = nothing
dimension(model::AbstractStateSpaceModel) = dimension(typeof(model))

export transition!!,
transition_logdensity,
emission_logdensity,
isdone,
AbstractParticle,
AbstractStateSpaceModel,
dimension

end
Loading