Skip to content

Maximum Likelihood Demonstration #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
04ecb4e
fixed type stability of linear filter
charlesknipp Jan 8, 2025
152917b
added MLE demonstration
charlesknipp Jan 8, 2025
efe1573
flipped sign of objective function
charlesknipp Jan 8, 2025
e99be16
Merge remote-tracking branch 'origin/main' into ck/maximum-likelihood
charlesknipp Mar 7, 2025
acbc4e1
reorganized and added Mooncake MWE
charlesknipp Mar 7, 2025
25eb4d6
fixed KF type stability in Enzyme
charlesknipp Mar 10, 2025
6d061b0
add MWE for Kalman filtering
charlesknipp Mar 11, 2025
f8f0e59
replaced second order optimizer and added backend testing
charlesknipp Mar 11, 2025
bd96222
Make HierarchicalSSM compatible with the BF (#70)
THargreaves Mar 11, 2025
c2b12db
Remove redundant computation from batch KF
THargreaves Mar 12, 2025
ba214d6
Factor out intermediate storage (#72)
THargreaves Mar 13, 2025
a3ceb96
Remove prototype CUDA kernels from package
THargreaves Mar 13, 2025
d47b1c7
Make RB particle creation more robust
THargreaves Mar 13, 2025
e429e6b
Export callback methods and types
THargreaves Mar 13, 2025
da6a54b
fixed type stability of linear filter
charlesknipp Jan 8, 2025
39d9f32
added MLE demonstration
charlesknipp Jan 8, 2025
30a4be3
flipped sign of objective function
charlesknipp Jan 8, 2025
935df10
reorganized and added Mooncake MWE
charlesknipp Mar 7, 2025
2cc7d63
fixed KF type stability in Enzyme
charlesknipp Mar 10, 2025
5c9149e
add MWE for Kalman filtering
charlesknipp Mar 11, 2025
4936c52
replaced second order optimizer and added backend testing
charlesknipp Mar 11, 2025
e8e6c02
Merge branch 'ck/maximum-likelihood' of https://github.com/TuringLang…
charlesknipp Mar 13, 2025
56c6dc0
fixed Mooncake errors for Bootstrap filter
charlesknipp Mar 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 24 additions & 34 deletions GeneralisedFilters/src/GeneralisedFilters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,16 @@ using StatsBase
using CUDA
using NNlib

# Filtering utilities
include("callbacks.jl")
include("containers.jl")
include("resamplers.jl")

## FILTERING BASE ##########################################################################

abstract type AbstractFilter <: AbstractSampler end
abstract type AbstractBatchFilter <: AbstractFilter end

"""
instantiate(model, alg, initial; kwargs...)

Create an intermediate storage object to store the proposed/filtered states at each step.
"""
function instantiate end

# Default method
function instantiate(model, alg, initial; kwargs...)
return Intermediate(initial, initial)
end

"""
initialise([rng,] model, alg; kwargs...)

Expand All @@ -37,9 +30,9 @@ Propose an initial state distribution.
function initialise end

"""
step([rng,] model, alg, iter, intermediate, observation; kwargs...)
step([rng,] model, alg, iter, state, observation; kwargs...)

Perform a combined predict and update call of the filtering on the intermediate storage.
Perform a combined predict and update call of the filtering on the state.
"""
function step end

Expand Down Expand Up @@ -70,24 +63,22 @@ function filter(
model::AbstractStateSpaceModel,
alg::AbstractFilter,
observations::AbstractVector;
callback=nothing,
callback::Union{AbstractCallback,Nothing}=nothing,
kwargs...,
)
initial = initialise(rng, model, alg; kwargs...)
intermediate = instantiate(model, alg, initial; kwargs...)
isnothing(callback) || callback(model, alg, intermediate, observations; kwargs...)
state = initialise(rng, model, alg; kwargs...)
isnothing(callback) || callback(model, alg, state, observations, PostInit; kwargs...)

log_evidence = initialise_log_evidence(alg, model)

for t in eachindex(observations)
intermediate, ll_increment = step(
rng, model, alg, t, intermediate, observations[t]; callback, kwargs...
state, ll_increment = step(
rng, model, alg, t, state, observations[t]; callback, kwargs...
)
log_evidence += ll_increment
isnothing(callback) ||
callback(model, alg, t, intermediate, observations; kwargs...)
end

return intermediate.filtered, log_evidence
return state, log_evidence
end

function initialise_log_evidence(::AbstractFilter, model::AbstractStateSpaceModel)
Expand All @@ -112,16 +103,20 @@ function step(
model::AbstractStateSpaceModel,
alg::AbstractFilter,
iter::Integer,
intermediate,
state,
observation;
callback::Union{AbstractCallback,Nothing}=nothing,
kwargs...,
)
intermediate.proposed = predict(rng, model, alg, iter, intermediate.filtered; kwargs...)
intermediate.filtered, ll_increment = update(
model, alg, iter, intermediate.proposed, observation; kwargs...
)
state = predict(rng, model, alg, iter, state; kwargs...)
isnothing(callback) ||
callback(model, alg, iter, state, observation, PostPredict; kwargs...)

return intermediate, ll_increment
state, ll_increment = update(model, alg, iter, state, observation; kwargs...)
isnothing(callback) ||
callback(model, alg, iter, state, observation, PostUpdate; kwargs...)

return state, ll_increment
end

## SMOOTHING BASE ##########################################################################
Expand All @@ -131,11 +126,6 @@ abstract type AbstractSmoother <: AbstractSampler end
# function smooth end
# function backward end

# Filtering utilities
include("callbacks.jl")
include("containers.jl")
include("resamplers.jl")

# Model types
include("models/linear_gaussian.jl")
include("models/discrete.jl")
Expand Down
75 changes: 43 additions & 32 deletions GeneralisedFilters/src/algorithms/bootstrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,31 @@ function step(
model::AbstractStateSpaceModel,
alg::AbstractParticleFilter,
iter::Integer,
intermediate,
state,
observation;
ref_state::Union{Nothing,AbstractVector}=nothing,
callback::Union{AbstractCallback,Nothing}=nothing,
kwargs...,
)
intermediate.proposed, intermediate.ancestors = resample(
rng, alg.resampler, intermediate.filtered
)
state = resample(rng, alg.resampler, state)
isnothing(callback) ||
callback(model, alg, iter, state, observation, PostResample; kwargs...)

intermediate.proposed = predict(
rng, model, alg, iter, intermediate.proposed; ref_state=ref_state, kwargs...
)
# TODO: this is quite inelegant and should be refactored
state = predict(rng, model, alg, iter, state; ref_state=ref_state, kwargs...)

# TODO: this is quite inelegant and should be refactored. It also might introduce bugs
# with callbacks that track the ancestry (and use PostResample)
if !isnothing(ref_state)
CUDA.@allowscalar intermediate.ancestors[1] = 1
CUDA.@allowscalar state.ancestors[1] = 1
end
isnothing(callback) ||
callback(model, alg, iter, state, observation, PostPredict; kwargs...)

intermediate.filtered, ll_increment = update(
model, alg, iter, intermediate.proposed, observation; kwargs...
)
state, ll_increment = update(model, alg, iter, state, observation; kwargs...)
isnothing(callback) ||
callback(model, alg, iter, state, observation, PostUpdate; kwargs...)

return intermediate, ll_increment
return state, ll_increment
end

struct BootstrapFilter{RS<:AbstractResampler} <: AbstractParticleFilter
Expand All @@ -46,13 +49,6 @@ function BootstrapFilter(
return BootstrapFilter{ESSResampler}(N, conditional_resampler)
end

function instantiate(
::StateSpaceModel{T}, filter::BootstrapFilter, initial; kwargs...
) where {T}
N = filter.N
return ParticleIntermediate(initial, initial, Vector{Int}(undef, N))
end

function initialise(
rng::AbstractRNG,
model::StateSpaceModel{T},
Expand All @@ -71,36 +67,51 @@ function predict(
model::StateSpaceModel,
filter::BootstrapFilter,
step::Integer,
filtered::ParticleDistribution;
state::ParticleDistribution;
ref_state::Union{Nothing,AbstractVector}=nothing,
kwargs...,
)
new_particles = map(
x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...), collect(filtered)
state.particles = map(
x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...), collect(state)
)
# Don't need to deep copy weights as filtered will be overwritten in the update step
proposed = ParticleDistribution(new_particles, filtered.log_weights)

return update_ref!(proposed, ref_state, step)
return update_ref!(state, ref_state, step)
end

function update(
model::StateSpaceModel{T},
filter::BootstrapFilter,
step::Integer,
proposed::ParticleDistribution,
state::ParticleDistribution,
observation;
kwargs...,
) where {T}
old_ll = logsumexp(state.log_weights)

log_increments = map(
x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...),
collect(proposed),
collect(state),
)

new_weights = proposed.log_weights + log_increments
filtered = ParticleDistribution(deepcopy(proposed.particles), new_weights)
state.log_weights += log_increments

ll_increment = logsumexp(filtered.log_weights) - logsumexp(proposed.log_weights)
ll_increment = logsumexp(state.log_weights) - old_ll

return filtered, ll_increment
return state, ll_increment
end

# Application of bootstrap filter to hierarchical models
"""
    filter(rng, model::HierarchicalSSM, alg::BootstrapFilter, observations;
           ref_state=nothing, kwargs...)

Apply the bootstrap filter to a hierarchical state space model.

The hierarchical model is flattened into a single `StateSpaceModel` whose
dynamics combine the outer dynamics with the inner model's dynamics, and whose
observation process wraps the inner model's observations; filtering is then
delegated to the existing `filter` method for that flat model.

# Arguments
- `rng::AbstractRNG`: random number source, passed through to the delegate call.
- `model::HierarchicalSSM`: hierarchical model providing `outer_dyn` and `inner_model`.
- `alg::BootstrapFilter`: the particle filtering algorithm to run.
- `observations::AbstractVector`: the observation sequence to filter on.

# Keywords
- `ref_state`: optional reference trajectory for conditional SMC; forwarded
  unchanged to the delegate `filter` call. `nothing` disables conditioning.
- `kwargs...`: forwarded to the delegate `filter` call.

# Returns
Whatever the delegate `filter(rng, ::StateSpaceModel, ...)` method returns —
presumably a `(state, log_evidence)` tuple; confirm against that method.
"""
function filter(
    rng::AbstractRNG,
    model::HierarchicalSSM,
    alg::BootstrapFilter,
    observations::AbstractVector;
    ref_state::Union{Nothing,AbstractVector}=nothing,
    kwargs...,
)
    # Flatten the two-level model into one SSM so the generic filtering code
    # path can be reused without a hierarchical special case.
    ssm = StateSpaceModel(
        HierarchicalDynamics(model.outer_dyn, model.inner_model.dyn),
        HierarchicalObservations(model.inner_model.obs),
    )
    return filter(rng, ssm, alg, observations; ref_state=ref_state, kwargs...)
end
Loading
Loading