Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change name of AbstractParticle #14

Merged
merged 1 commit into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 34 additions & 24 deletions examples/smc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,23 @@ using StatsFuns
# Particle Filter implementation
struct Particle{T} # Here we just need a tree
parent::Union{Particle,Nothing}
state::T
model::T
end

Particle(state::T) where {T} = Particle(nothing, state)
Particle(model::T) where {T} = Particle(nothing, model)
Particle() = Particle(nothing, nothing)
Base.show(io::IO, p::Particle) = print(io, "Particle($(p.state))")
Base.show(io::IO, p::Particle) = print(io, "Particle($(p.model))")

"""
linearize(particle)

Return the trace of a particle, i.e. the sequence of states from the root to the particle.
"""
function linearize(particle::Particle{T}) where {T}
trace = T[]
function linearize(particle::Particle)
trace = []
parent = particle.parent
while !isnothing(parent)
push!(trace, parent.state)
push!(trace, parent.model)
parent = parent.parent
end
return trace
Expand All @@ -44,7 +44,7 @@ function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, thre
t = 1
N = length(particles)
logweights = zeros(length(particles))
while !isdone(t, particles[1].state)
while !isdone(t, particles[1].model)

# Resample step
weights = get_weights(logweights)
Expand All @@ -57,9 +57,9 @@ function sweep!(rng::AbstractRNG, particles::ParticleContainer, resampling, thre
# Mutation step
for i in eachindex(particles)
parent = particles[i]
mutated = transition!!(rng, t, parent.state)
mutated = transition!!(rng, t, parent.model)
particles[i] = Particle(parent, mutated)
logweights[i] += emission_logdensity(t, particles[i].state)
logweights[i] += emission_logdensity(t, particles[i].model)
end

t += 1
Expand All @@ -76,43 +76,53 @@ function sweep!(rng::AbstractRNG, n::Int, resampling, threshold=0.5)
end

# Inference code
struct LinearSSM{T} <: AbstractStateSpaceModel
state::T
end

LinearSSM() = LinearSSM(nothing)

Base.@kwdef struct Parameters
v::Float64 = 0.2 # Transition noise stdev
u::Float64 = 0.7 # Observation noise stdev
end

# Simulation
T = 250
seed = 1
seed = 12
N = 1000
rng = MersenneTwister(seed)
params = Parameters(; v=0.2, u=0.7)

function transition!!(rng::AbstractRNG, t::Int, state=nothing)
if isnothing(state)
return rand(rng, Normal(0, 1))
function emission_density(x)
return Normal(x, params.u)
end

function transition!!(rng::AbstractRNG, t::Int, model::LinearSSM{T}=LinearSSM()) where {T}
if t == 1
return LinearSSM(rand(rng, Normal(0, 1)))
end
return rand(rng, Normal(state, params.v))
return LinearSSM(rand(rng, Normal(model.state, params.v)))
end

function emission_logdensity(t, state)
return logpdf(Normal(state, params.u), observations[t])
function emission_logdensity(t, model::LinearSSM)
return logpdf(emission_density(model.state), observations[t])
end

isdone(t, state) = t > T
isdone(t, ::Nothing) = false
isdone(t, ::LinearSSM) = t > T

x, observations = zeros(T), zeros(T)
x[1] = rand(rng, Normal(0, 1))
x, observations = Vector{LinearSSM}(undef, T), zeros(T)
x[1] = transition!!(rng, 1)
for t in 1:T
observations[t] = rand(rng, Normal(x[t], params.u))
observations[t] = rand(emission_density(x[t].state))
if t < T
x[t + 1] = rand(rng, Normal(x[t], params.v))
x[t + 1] = transition!!(rng, t, x[t])
end
end

samples = sweep!(rng, fill(Particle(x[1]), N), systematic_resampling)
traces = reverse(hcat(map(linearize, samples)...))
traces = map(model -> model.state, reverse(hcat(map(linearize, samples)...)))

scatter(traces; color=:black, opacity=0.3, label=false)
plot!(x; label="True state")
plot!(map(model -> getfield(model, :state), x); label="True state")
plot!(mean(traces; dims=2); label="Posterior mean")
7 changes: 4 additions & 3 deletions src/SSMProblems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ module SSMProblems

"""
"""
abstract type AbstractParticle end
abstract type AbstractParticleCache end
abstract type AbstractStateSpaceModel end
yebai marked this conversation as resolved.
Show resolved Hide resolved
abstract type AbstractSSMCache end

"""
transition!!(rng, step, particle[, cache])
Expand Down Expand Up @@ -36,6 +36,7 @@ Determine whether we have reached the last time step of the Markov process. Retu
"""
function isdone end

export transition!!, transition_logdensity, emission_logdensity, isdone, AbstractParticle
export transition!!,
transition_logdensity, emission_logdensity, isdone, AbstractStateSpaceModel

end
Loading