Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft DA script #57

Draft
wants to merge 26 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
440252c
added basic particle methods and filters
charlesknipp Aug 9, 2024
9fd4453
added qualifiers
charlesknipp Aug 12, 2024
3fd90c4
added parameter priors
charlesknipp Aug 12, 2024
884b9e3
Merge branch 'main' into ck/particle-methods
charlesknipp Aug 30, 2024
1def6a1
Merge branch 'main' into ck/particle-methods
charlesknipp Sep 24, 2024
a5a2e05
added adaptive resampling to bootstrap filter (WIP)
charlesknipp Sep 25, 2024
57da3ff
Julia formatter changes
charlesknipp Sep 25, 2024
dc713b0
Merge branch 'ck/particle-methods' of https://github.com/TuringLang/S…
charlesknipp Sep 25, 2024
b846fa4
changed eltype for <: StateSpaceModel
charlesknipp Sep 26, 2024
4263ae7
updated naming conventions
charlesknipp Sep 26, 2024
5a2aeb4
formatter
charlesknipp Sep 26, 2024
8db658b
fixed adaptive resampling
charlesknipp Sep 27, 2024
15dfa9f
added particle ancestry
charlesknipp Oct 1, 2024
7e3c93d
formatter issues
charlesknipp Oct 1, 2024
f905a41
fixed metropolis and added rejection resampler
charlesknipp Oct 1, 2024
8ac1455
Keep track of free indices using stack
THargreaves Oct 2, 2024
f11a63e
updated particle types and organized directory
charlesknipp Oct 2, 2024
1fa3c93
weakened SSM type parameter assertions
charlesknipp Oct 4, 2024
8cb4338
improved particle state containment and resampling
charlesknipp Oct 4, 2024
73dd433
added hacky sparse ancestry to example
charlesknipp Oct 5, 2024
f71ab32
fixed RNG in rejection resampling
charlesknipp Oct 6, 2024
25cebf4
improved callbacks and resamplers
charlesknipp Oct 6, 2024
c729879
formatting
charlesknipp Oct 6, 2024
cf8ce02
Draft DA script
FredericWantiez Oct 7, 2024
6723a94
Graph
FredericWantiez Oct 8, 2024
1452069
Format
FredericWantiez Oct 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions examples/particle-mcmc/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sure this is just left over from some experimentation. Did you try running anything with StaticArrays?

StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
93 changes: 93 additions & 0 deletions examples/particle-mcmc/lorenz.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
using Distributions
using Random
using SSMProblems
using UnPack
using OrdinaryDiffEq
using LinearAlgebra
using PDMats
using GLMakie

include("particles.jl")
include("resamplers.jl")
include("simple-filters.jl")

# Model and simulation parameters for the Lorenz-63 system.
Base.@kwdef struct Parameters{T<:Real}
    β::T = 8 / 3   # Lorenz geometric parameter
    ρ::T = 28.0    # Rayleigh-number parameter (chaotic regime at the defaults)
    σ::T = 10.0    # Prandtl-number parameter
    ν::T = 1.0     # Obs noise variance
    dt::T = 0.025  # Time step
end

"""
    lorenz!(du, u, p::Parameters, t)

In-place Lorenz-63 vector field: write the time derivative of state `u` into `du`.
`t` is unused (the system is autonomous) but required by the ODE-solver signature.
"""
function lorenz!(du, u, p::Parameters, t)
    # stdlib property destructuring replaces the third-party UnPack.@unpack macro
    (; β, ρ, σ) = p
    du[1] = σ * (u[2] - u[1])
    du[2] = u[1] * (ρ - u[3]) - u[2]
    du[3] = u[1] * u[2] - β * u[3]
    # in-place derivative functions conventionally return nothing; the previous
    # `return du[3] = ...` returned a value no caller uses
    return nothing
end

# Latent transition process: one fixed-length step of the Lorenz ODE integrator
# followed by additive Gaussian noise with covariance σ.
# Fully parametrized fields: the original left `integrator` untyped and `σ`
# abstractly typed (`AbstractPDMat{T}`), which forces dynamic dispatch on every
# field access. The default constructor still accepts the same arguments.
struct LatentNoiseProcess{T,ΣT<:AbstractPDMat{T},IT} <: LatentDynamics{Vector{T}}
    σ::ΣT          # process-noise covariance
    dt::T          # integration step length
    integrator::IT # pre-initialized ODE integrator, reused via reinit!
end

# Observation process: the full latent state observed with additive Gaussian
# noise of covariance σ. The covariance field is concretely parametrized (the
# original abstract `AbstractPDMat{T}` field boxes the value); the default
# constructor still accepts the same argument.
struct ObservationNoiseProcess{T,ΣT<:AbstractPDMat{T}} <: ObservationProcess{Vector{T}}
    σ::ΣT # observation-noise covariance
end

# Transition distribution: re-initialize the shared integrator at `prev_state`,
# advance the deterministic Lorenz flow by exactly dyn.dt, and return a Gaussian
# centered at the propagated state with covariance dyn.σ.
# NOTE(review): this mutates the integrator stored in `dyn`, so propagating
# particles concurrently with one shared integrator is unsafe — confirm callers
# are sequential.
function SSMProblems.distribution(dyn::LatentNoiseProcess, step::Integer, prev_state, extra)
    reinit!(dyn.integrator, prev_state)
    step!(dyn.integrator, dyn.dt, true)
    return MvNormal(dyn.integrator.u, dyn.σ)
end

# Initial-state distribution: Gaussian centered at the deterministic initial
# condition with covariance dyn.σ. The mean is written with Float64 literals to
# match the element type of the simulated trajectory (`u0 = [1.0; 0.0; 0.0]`);
# the original integer-literal mean relied on implicit promotion inside MvNormal.
function SSMProblems.distribution(dyn::LatentNoiseProcess, extra)
    return MvNormal([1.0, 0.0, 0.0], dyn.σ)
end

# Observation distribution: direct observation of the full state corrupted by
# Gaussian noise with covariance obs.σ.
function SSMProblems.distribution(obs::ObservationNoiseProcess, step::Integer, state, extra)
    # obs.σ is already a positive-definite matrix (a PDMat/ScalMat), so the
    # original `obs.σ * I` multiplied by the identity redundantly
    return MvNormal(state, obs.σ)
end

# Simulate some data
u0 = [1.0; 0.0; 0.0]
params = Parameters()

# NOTE(review): `dt` duplicates `params.dt` — keep a single source of truth
dt = 0.025
N = 100   # number of observation steps
Np = 512  # number of particles
tspan = (0.0, dt * N)

# NOTE(review): unseeded RNG — pass a seed for reproducible runs
rng = MersenneTwister()

prob = ODEProblem(lorenz!, u0, tspan, params)
alg = Tsit5()
# NOTE(review): `alg` is defined but `init` constructs a fresh Tsit5() — reuse `alg`
integrator = init(prob, Tsit5(); dt=dt, adaptive=false)
sol = solve(prob, alg; dt=dt, adaptive=false)

# SSM Noise Model
# process-noise variance is set to params.dt here (ScalMat(3, params.dt)) —
# presumably a Brownian-scaling choice; confirm this is intended rather than ν
dyn = LatentNoiseProcess(ScalMat(3, params.dt), params.dt, integrator)
obs = ObservationNoiseProcess(ScalMat(3, params.ν))
model = StateSpaceModel(dyn, obs)
x0, x, y = sample(rng, model, N)

# bootstrap filter with systematic resampling at every step (threshold=1.0);
# NOTE(review): the name `filter` shadows Base.filter in this script's scope
filter = BF(Np; threshold=1.0, resampler=Systematic());
sparse_ancestry = AncestorCallback(eltype(model.dyn), filter.N, 1.0);
tree, llbf = sample(rng, model, filter, y; callback=sparse_ancestry);
lineage = get_ancestry(sparse_ancestry.tree)

# Fancy 3D plot
# fig = Figure()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm obsessed with this plot

# lines(fig[1, 1], hcat(x0, x...))
# for (i, path) in enumerate(lineage)
# lines!(fig[1, 1], hcat(path...), color=:black)
# end

# Plot each state dimension in its own row: the simulated trajectory (with the
# initial state prepended) overlaid with the surviving ancestral particle paths
# in translucent black.
fig = Figure()
for i in eachindex(first(x))
    # `lines` into fig[i, 1] creates the axis for row i; `lines!` adds overlays
    lines(fig[i, 1], hcat(x0, x...)[i, :])
    for path in lineage
        lines!(fig[i, 1], hcat(path...)[i, :]; color=:black, alpha=0.1)
    end
end
176 changes: 176 additions & 0 deletions examples/particle-mcmc/particles.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
using DataStructures: Stack
using StatsBase

## PARTICLES ###############################################################################

# Weighted particle collection for one filtering step.
# `filtered` holds the current particle states, `proposed` is a same-shape
# scratch buffer for the next proposal, `ancestors` maps each particle to its
# parent index, and `log_weights` carries the unnormalized log-weights.
mutable struct ParticleContainer{T,WT<:Real}
    filtered::Vector{T}
    proposed::Vector{T}
    ancestors::Vector{Int64}
    log_weights::Vector{WT}

    function ParticleContainer(
        initial_states::Vector{T}, log_weights::Vector{WT}
    ) where {T,WT<:Real}
        # `proposed` starts uninitialized; ancestors start as the identity map
        # (the eachindex range is converted to Vector{Int64} by `new`)
        return new{T,WT}(
            initial_states, similar(initial_states), eachindex(log_weights), log_weights
        )
    end
end

# Collection-style accessors over the *filtered* particles.
# NOTE(review): the originals referenced a nonexistent `pc.vals` field — the
# struct's fields are filtered/proposed/ancestors/log_weights — so every one of
# these methods threw on first use; they now index the filtered states.
Base.collect(pc::ParticleContainer) = pc.filtered
Base.length(pc::ParticleContainer) = length(pc.filtered)
Base.keys(pc::ParticleContainer) = LinearIndices(pc.filtered)

# not sure if this is kosher, since it doesn't follow the convention of Base.getindex
Base.@propagate_inbounds Base.getindex(pc::ParticleContainer, i::Int) = pc.filtered[i]
Base.@propagate_inbounds Base.getindex(pc::ParticleContainer, i::Vector{Int}) = pc.filtered[i]

"""
    reset_weights!(pc::ParticleContainer)

Reset every log-weight to zero (uniform weights) in place and return the
log-weight vector.
"""
function reset_weights!(pc::ParticleContainer{T,WT}) where {T,WT<:Real}
    lw = pc.log_weights
    fill!(lw, zero(WT))
    return lw
end

# Normalized particle weights: exponentiate and normalize the log-weights in a
# numerically stable way via softmax.
# NOTE(review): `softmax` is provided by LogExpFunctions/StatsFuns, not exported
# by StatsBase — confirm which `using` actually brings it into scope here.
function StatsBase.weights(pc::ParticleContainer)
    return softmax(pc.log_weights)
end

## SPARSE PARTICLE STORAGE #################################################################

# Push every element of `a` onto the stack, so the last element of `a` ends up on top.
# NOTE(review): this is type piracy — neither `append!` nor `Stack` is owned by this
# code; prefer a locally-named helper. It also returns the mapped vector, not `s`.
Base.append!(s::Stack, a::AbstractVector) = map(x -> push!(s, x), a)

# Sparse particle-ancestry storage: states of all generations share one
# pre-allocated buffer, with parent links and a stack of recyclable slots freed
# by pruning dead branches. NOTE(review): appears to follow the O(N log N)
# path-storage scheme from the particle-genealogy literature — confirm reference.
mutable struct ParticleTree{T}
    states::Vector{T}          # node storage; capacity M, doubled by `expand!`
    parents::Vector{Int64}     # parent index per node; 0 marks a root/unset node
    leaves::Vector{Int64}      # node indices of the current generation
    offspring::Vector{Int64}   # child count per node; 0 means the branch died out
    free_indices::Stack{Int64} # recycled slots available for future insertions

    function ParticleTree(states::Vector{T}, M::Integer) where {T}
        nodes = Vector{T}(undef, M)
        initial_free_indices = Stack{Int64}()
        # slots beyond the initial generation are free; pushed in descending order
        # so the lowest free index is popped first
        append!(initial_free_indices, M:-1:(length(states) + 1))
        @inbounds nodes[1:length(states)] = states
        return new{T}(
            nodes, zeros(Int64, M), 1:length(states), zeros(Int64, M), initial_free_indices
        )
    end
end

# Total number of allocated node slots (capacity M, not just live nodes).
function Base.length(tree::ParticleTree)
    return length(tree.states)
end

# Valid linear indices into the tree's node storage.
function Base.keys(tree::ParticleTree)
    return LinearIndices(tree.states)
end

# Prune dead branches given the `offspring` counts of the current leaves,
# returning freed node slots to `free_indices` for reuse.
function prune!(tree::ParticleTree, offspring::Vector{Int64})
    # insert new offspring counts
    setindex!(tree.offspring, offspring, tree.leaves)

    # update each branch
    @inbounds for i in eachindex(offspring)
        j = tree.leaves[i]
        # walk toward the root, releasing each node whose subtree has died out;
        # stop at the root sentinel (parent 0) or at the first still-live ancestor,
        # decrementing its child count as we detach from it
        while (j > 0) && (tree.offspring[j] == 0)
            push!(tree.free_indices, j)
            j = tree.parents[j]
            if j > 0
                tree.offspring[j] -= 1
            end
        end
    end
end

# Insert a new generation of `states` whose ancestor indices (into the current
# leaves) are given by `a`; doubles the tree's capacity first if there are not
# enough free slots.
# NOTE(review): this shadows Base.insert! (different signature) rather than
# extending it — consider renaming to avoid confusion.
function insert!(
    tree::ParticleTree{T}, states::Vector{T}, a::AbstractVector{Int64}
) where {T}
    # parents of new generation
    parents = getindex(tree.leaves, a)

    # ensure there are enough dead branches
    if (length(tree.free_indices) < length(a))
        @debug "expanding tree"
        expand!(tree)
    end

    # find places for new states
    @inbounds for i in eachindex(states)
        tree.leaves[i] = pop!(tree.free_indices)
    end

    # insert new generation and update parent child relationships
    setindex!(tree.states, states, tree.leaves)
    setindex!(tree.parents, parents, tree.leaves)
    return tree
end

# Double the tree's node capacity: grow the state buffer in place and
# zero-extend the bookkeeping vectors, pushing the new slots onto the free stack
# (highest index first, so the lowest new index is popped first).
function expand!(tree::ParticleTree)
    M = length(tree)
    resize!(tree.states, 2 * M)

    # new allocations must be zero valued, this is not a perfect solution
    tree.parents = [tree.parents; zero(tree.parents)]
    tree.offspring = [tree.offspring; zero(tree.offspring)]
    append!(tree.free_indices, (2 * M):-1:(M + 1))
    return tree
end

"""
    get_offspring(a::AbstractVector{Int64})

Count how many times each index appears in the ancestor vector `a`, i.e. the
number of offspring of each particle. Returns a vector the same length as `a`.
"""
function get_offspring(a::AbstractVector{Int64})
    counts = zero(a)
    for ancestor in a
        counts[ancestor] += 1
    end
    return counts
end

"""
    get_ancestry(tree::ParticleTree{T}) where {T}

Reconstruct the full state path from the root to each current leaf by following
parent links. Returns one root-to-leaf vector of states per leaf.
"""
function get_ancestry(tree::ParticleTree{T}) where {T}
    paths = Vector{Vector{T}}(undef, length(tree.leaves))
    @inbounds for (k, leaf) in enumerate(tree.leaves)
        # collect leaf-to-root, then reverse into chronological order
        path = [tree.states[leaf]]
        node = tree.parents[leaf]
        while node > 0
            push!(path, tree.states[node])
            node = tree.parents[node]
        end
        paths[k] = reverse(path)
    end
    return paths
end

## ANCESTOR STORAGE CALLBACK ###############################################################

# Filtering callback that records the full particle genealogy in a ParticleTree.
mutable struct AncestorCallback
    tree::ParticleTree

    # N: number of particles; C: capacity slack factor. Capacity M = C*N*log(N)
    # — presumably the expected-space bound for pruned ancestry trees; confirm
    # the constant against the path-storage literature.
    function AncestorCallback(::Type{T}, N::Integer, C::Real=1.0) where {T}
        M = floor(Int64, C * N * log(N))
        nodes = Vector{T}(undef, N)
        return new(ParticleTree(nodes, M))
    end
end

# Record one filtering step in the ancestry tree: seed the tree with the initial
# particles on the first step, then prune extinct branches and insert the new
# filtered generation under its ancestors.
function (c::AncestorCallback)(model, filter, step, states, data; kwargs...)
    if step == 1
        # this may be incorrect, but it is functional
        @inbounds c.tree.states[1:(filter.N)] = deepcopy(states.filtered)
    end
    prune!(c.tree, get_offspring(states.ancestors))
    insert!(c.tree, states.filtered, states.ancestors)
    return nothing
end

# Callback that records only the resampling genealogy (particle indices 1:N per
# step) rather than the states — useful for inspecting degeneracy cheaply.
mutable struct ResamplerCallback
    tree::ParticleTree

    # capacity heuristic M = C*N*log(N), matching the ancestry-tree sizing above
    function ResamplerCallback(N::Integer, C::Real=1.0)
        M = floor(Int64, C * N * log(N))
        nodes = collect(1:N)
        return new(ParticleTree(nodes, M))
    end
end

# Record the resampling indices for each step after the first (at step 1 the
# tree already holds the identity generation from the constructor).
function (c::ResamplerCallback)(model, filter, step, states, data; kwargs...)
    if step != 1
        prune!(c.tree, get_offspring(states.ancestors))
        insert!(c.tree, collect(1:(filter.N)), states.ancestors)
    end
    return nothing
end
102 changes: 102 additions & 0 deletions examples/particle-mcmc/resamplers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
using Random
using Distributions

# Supertype of all resampling schemes. Concrete subtypes implement
# `resample(rng, resampler, weights, [n])`, returning `n` ancestor indices
# drawn according to the (normalized) `weights`.
abstract type AbstractResampler end

## DOUBLE PRECISION STABLE ALGORITHMS ######################################################

# Multinomial resampling: each ancestor index drawn i.i.d. from Categorical(weights).
struct Multinomial <: AbstractResampler end

function resample(
    rng::AbstractRNG, ::Multinomial, weights::AbstractVector{WT}, n::Int64=length(weights)
) where {WT<:Real}
    return rand(rng, Distributions.Categorical(weights), n)
end

# Systematic resampling: a single uniform offset stratifies all n draws, giving
# lower variance than multinomial resampling at O(n) cost. Assumes `weights` is
# normalized to sum to one.
struct Systematic <: AbstractResampler end

function resample(
    rng::AbstractRNG, ::Systematic, weights::AbstractVector{WT}, n::Int64=length(weights)
) where {WT<:Real}
    # running cumulative weight scaled by n; one uniform draw u positions every stratum
    @inbounds v = n * weights[1]
    u = oftype(v, rand(rng))

    # initialize sampling algorithm
    a = Vector{Int64}(undef, n)
    idx = 1
    m = length(weights)

    @inbounds for i in 1:n
        # advance idx until the scaled cumulative weight reaches u; the `idx < m`
        # guard fixes an out-of-bounds walk when floating-point error makes
        # sum(weights) fall slightly short of 1 — the final stratum is then
        # assigned the last index instead of erroring
        while v < u && idx < m
            idx += 1
            v += n * weights[idx]
        end
        a[i] = idx
        u += one(u)
    end

    return a
end

## SINGLE PRECISION STABLE ALGORITHMS ######################################################

# Metropolis resampler configuration.
# ε sets the termination tolerance used to choose the number of inner Metropolis
# iterations B in `resample` (smaller ε ⇒ more iterations) — presumably the bias
# bound from the parallel-resampling literature; confirm reference.
struct Metropolis <: AbstractResampler
    ε::Float64
    function Metropolis(ε::Float64=0.01)
        return new(ε)
    end
end

# TODO: this should be done in the log domain and also parallelized
function resample(
rng::AbstractRNG,
resampler::Metropolis,
weights::AbstractVector{WT},
n::Int64=length(weights);
) where {WT<:Real}
# pre-calculations
β = mean(weights)
B = Int64(cld(log(resampler.ε), log(1 - β)))

# initialize the algorithm
a = Vector{Int64}(undef, n)

@inbounds for i in 1:n
k = i
for _ in 1:B
j = rand(rng, 1:n)
v = weights[j] / weights[k]
if rand(rng) ≤ v
k = j
end
end
a[i] = k
end

return a
end

# Rejection resampling: accept index j with probability w_j / max(w), starting
# each slot's proposal at its own index so no cumulative sum is required.
struct Rejection <: AbstractResampler end

# TODO: this should be done in the log domain and also parallelized
function resample(
    rng::AbstractRNG, ::Rejection, weights::AbstractVector{WT}, n::Int64=length(weights)
) where {WT<:Real}
    # acceptance probabilities are scaled by the largest weight
    w_max = maximum(weights)

    ancestors = Vector{Int64}(undef, n)

    @inbounds for i in 1:n
        # first proposal is the particle's own index; resample uniformly on reject
        candidate = i
        draw = rand(rng)
        while draw > weights[candidate] / w_max
            candidate = rand(rng, 1:n)
            draw = rand(rng)
        end
        ancestors[i] = candidate
    end

    return ancestors
end
Loading
Loading