Skip to content

Commit

Permalink
Use Gaussian (#28)
Browse files Browse the repository at this point in the history
* Format, use `Gaussian`

* Fix the maths

* Format

* Tweaks

* Update script.jl

* Update script.jl

---------

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
  • Loading branch information
FredericWantiez and yebai authored Nov 17, 2023
1 parent c4b1f61 commit 2ac231b
Showing 1 changed file with 29 additions and 19 deletions.
48 changes: 29 additions & 19 deletions examples/kalman-filter/script.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
# Kalman filter using Kalman.jl
using Distributions
using GaussianDistributions
using GaussianDistributions:
using Kalman
## Kalman filter using Kalman.jl
using GaussianDistributions: correct, Gaussian
using LinearAlgebra
using Statistics
using Plots
using Random
using SSMProblems
Expand All @@ -26,9 +24,19 @@ struct LinearGaussianSSM <: AbstractStateSpaceModel
R::Matrix{Float64}
end

f0(model::LinearGaussianSSM) = MvNormal(model.z, model.P)
f(x::Vector{Float64}, model::LinearGaussianSSM) = MvNormal(model.Φ * x + model.b, model.Q)
g(y::Vector{Float64}, model::LinearGaussianSSM) = MvNormal(model.H * y, model.R)
f0(model::LinearGaussianSSM) = Gaussian(model.z, model.P)
f(x::Vector{Float64}, model::LinearGaussianSSM) = Gaussian(model.Φ * x + model.b, model.Q)
g(y::Vector{Float64}, model::LinearGaussianSSM) = Gaussian(model.H * y, model.R)

function transition!!(rng::AbstractRNG, model::LinearGaussianSSM)
return Gaussian(model.z, model.P)
end

function transition!!(rng::AbstractRNG, model::LinearGaussianSSM, state::Gaussian)
let Φ = model.Φ, Q = model.Q, μ = state.μ, Σ = state.Σ
return Gaussian* μ, Φ * Σ * Φ' + Q)
end
end

# Simulation parameters
SEED = 1
Expand All @@ -55,41 +63,43 @@ for t in 1:T
end

# Kalman filter
function filter(model::LinearGaussianSSM, y::Vector{Any})
function filter(rng::Random.AbstractRNG, model::LinearGaussianSSM, y::Vector{Any})
T = length(y)
p = Gaussian(model.z, model.P)
p = transition!!(rng, model)
ps = [p]
for i in 1:T
p = Φ * p Gaussian(zero(z), Q)
p, yres, _ = Kalman.correct(
Kalman.JosephForm(), p, (Gaussian(y[i], model.R), model.H)
)
p = transition!!(rng, model, p)
p, yres, _ = correct(p, Gaussian(y[i], model.R), model.H)
push!(ps, p)
end
return ps
end

# Run filter and plot results
ps = filter(model, y)
ps = filter(rng, model, y)

p_mean = mean.(ps)
p_cov = sqrt.(cov.(ps))

p1 = scatter(1:T, first.(y); color="red", label="Observations")
plot!(
p1,
0:T,
[mean(p)[1] for p in ps];
first.(p_mean);
color="orange",
label="Filtered x1",
grid=false,
ribbon=[sqrt(cov(p)[1, 1]) for p in ps],
ribbon=getindex.(p_cov, 1, 1),
fillalpha=0.5,
)

plot!(
p1,
0:T,
[mean(p)[2] for p in ps];
last.(p_mean);
color="blue",
label="Filtered x2",
grid=false,
ribbon=[sqrt(cov(p)[2, 2]) for p in ps],
ribbon=getindex.(p_cov, 2, 2),
fillalpha=0.5,
)

0 comments on commit 2ac231b

Please sign in to comment.