Skip to content

Commit

Permalink
changing names a bit (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
FredericWantiez authored Jul 28, 2023
1 parent bfcf29c commit 7181eb3
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 27 deletions.
58 changes: 34 additions & 24 deletions examples/smc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@ using StatsFuns
# Particle Filter implementation
struct Particle{T} # Here we just need a tree
parent::Union{Particle,Nothing}
state::T
model::T
end

Particle(state::T) where {T} = Particle(nothing, state)
Particle(model::T) where {T} = Particle(nothing, model)
Particle() = Particle(nothing, nothing)
Base.show(io::IO, p::Particle) = print(io, "Particle($(p.state))")
Base.show(io::IO, p::Particle) = print(io, "Particle($(p.model))")

"""
linearize(particle)
Return the trace of a particle, i.e. the sequence of states from the root to the particle.
"""
function linearize(particle::Particle{T}) where {T}
trace = T[]
function linearize(particle::Particle)
trace = []
parent = particle.parent
while !isnothing(parent)
push!(trace, parent.state)
push!(trace, parent.model)
parent = parent.parent
end
return trace
Expand All @@ -44,7 +44,7 @@ function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, thre
t = 1
N = length(particles)
logweights = zeros(length(particles))
while !isdone(t, particles[1].state)
while !isdone(t, particles[1].model)

# Resample step
weights = get_weights(logweights)
Expand All @@ -57,9 +57,9 @@ function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, thre
# Mutation step
for i in eachindex(particles)
parent = particles[i]
mutated = transition!!(rng, t, parent.state)
mutated = transition!!(rng, t, parent.model)
particles[i] = Particle(parent, mutated)
logweights[i] += emission_logdensity(t, particles[i].state)
logweights[i] += emission_logdensity(t, particles[i].model)
end

t += 1
Expand All @@ -76,43 +76,53 @@ function sweep!(rng::AbstractRNG, n::Int, resampling, threshold=0.5)
end

# Inference code
struct LinearSSM{T} <: AbstractStateSpaceModel
state::T
end

LinearSSM() = LinearSSM(nothing)

Base.@kwdef struct Parameters
v::Float64 = 0.2 # Transition noise stdev
u::Float64 = 0.7 # Observation noise stdev
end

# Simulation
T = 250
seed = 1
seed = 12
N = 1000
rng = MersenneTwister(seed)
params = Parameters(; v=0.2, u=0.7)

function transition!!(rng::AbstractRNG, t::Int, state=nothing)
if isnothing(state)
return rand(rng, Normal(0, 1))
function emission_density(x)
return Normal(x, params.u)
end

function transition!!(rng::AbstractRNG, t::Int, model::LinearSSM{T}=LinearSSM()) where {T}
if t == 1
return LinearSSM(rand(rng, Normal(0, 1)))
end
return rand(rng, Normal(state, params.v))
return LinearSSM(rand(rng, Normal(model.state, params.v)))
end

function emission_logdensity(t, state)
return logpdf(Normal(state, params.u), observations[t])
function emission_logdensity(t, model::LinearSSM)
return logpdf(emission_density(model.state), observations[t])
end

isdone(t, state) = t > T
isdone(t, ::Nothing) = false
isdone(t, ::LinearSSM) = t > T

x, observations = zeros(T), zeros(T)
x[1] = rand(rng, Normal(0, 1))
x, observations = Vector{LinearSSM}(undef, T), zeros(T)
x[1] = transition!!(rng, 1)
for t in 1:T
observations[t] = rand(rng, Normal(x[t], params.u))
observations[t] = rand(emission_density(x[t].state))
if t < T
x[t + 1] = rand(rng, Normal(x[t], params.v))
x[t + 1] = transition!!(rng, t, x[t])
end
end

samples = sweep!(rng, fill(Particle(x[1]), N), systematic_resampling)
traces = reverse(hcat(map(linearize, samples)...))
traces = map(model -> model.state, reverse(hcat(map(linearize, samples)...)))

scatter(traces; color=:black, opacity=0.3, label=false)
plot!(x; label="True state")
plot!(map(model -> getfield(model, :state), x); label="True state")
plot!(mean(traces; dims=2); label="Posterior mean")
7 changes: 4 additions & 3 deletions src/SSMProblems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ module SSMProblems

"""
"""
abstract type AbstractParticle end
abstract type AbstractParticleCache end
abstract type AbstractStateSpaceModel end
abstract type AbstractSSMCache end

"""
transition!!(rng, step, particle[, cache])
Expand Down Expand Up @@ -36,6 +36,7 @@ Determine whether we have reached the last time step of the Markov process. Retu
"""
function isdone end

export transition!!, transition_logdensity, emission_logdensity, isdone, AbstractParticle
export transition!!,
transition_logdensity, emission_logdensity, isdone, AbstractStateSpaceModel

end

0 comments on commit 7181eb3

Please sign in to comment.