Skip to content

Commit

Permalink
Add unit test for linear Gaussian SSM
Browse files Browse the repository at this point in the history
  • Loading branch information
THargreaves committed Apr 12, 2024
1 parent 1e5dfdd commit 7568d31
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 0 deletions.
5 changes: 5 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DynamicIterators = "6c76993d-992e-5bf1-9e63-34920a5a5a38"
GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
Kalman = "d59c0ba6-2ef2-5409-8dc5-1fd9a2b46832"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Expand Down
103 changes: 103 additions & 0 deletions test/linear-gaussian.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""
Unit tests for the validity of the SMC algorithms included in this package.
We test each SMC algorithm on a one-dimensional linear Gaussian state space model for which
an analytic filtering distribution can be computed using the Kalman filter provided by the
`Kalman.jl` package.
The validity of the algorithm is tested by comparing the final estimated filtering
distribution ground truth using a one-sided Kolmogorov-Smirnov test.
"""

using DynamicIterators
using GaussianDistributions
using HypothesisTests
using Kalman

function test_algorithm(rng, algorithm, model, N_SAMPLES, Xf)
chains = sample(rng, model, algorithm, N_SAMPLES; progress=false)
particles = hcat([chain.trajectory.model.X for chain in chains]...)
final_particles = particles[:, end]

test = ExactOneSampleKSTest(
final_particles, Normal(Xf.x[end].μ[1], sqrt(Xf.x[end].Σ[1, 1]))
)
return pvalue(test)
end

@testset "linear-gaussian.jl" begin
T = 3
N_PARTICLES = 20
N_SAMPLES = 50

# Model dynamics (in matrix form, despite being one-dimensional, to work with Kalman.jl)
Φ = [0.5;;]
b = [0.2]
Q = [0.1;;]
E = LinearEvolution(Φ, Gaussian(b, Q))

H = [1.0;;]
R = [0.1;;]
Obs = LinearObservationModel(H, R)

x0 = [0.0]
P0 = [1.0;;]
G0 = Gaussian(x0, P0)

M = LinearStateSpaceModel(E, Obs)
O = LinearObservation(E, H, R)

# Simulate from model
rng = StableRNG(1234)
initial = rand(rng, StateObs(G0, M.obs))
trajectory = trace(DynamicIterators.Sampled(M), 1 => initial, endtime(T))
y_pairs = collect(t => y for (t, (x, y)) in pairs(trajectory))
ys = stack(y for (t, (x, y)) in pairs(trajectory))

# Ground truth smoothing
Xf, ll = kalmanfilter(M, 1 => G0, y_pairs)

# Define AdvancedPS model
mutable struct LinearGaussianParams
a::Float64
b::Float64
q::Float64
h::Float64
r::Float64
x0::Float64
p0::Float64
end

mutable struct LinearGaussianModel <: AdvancedPS.AbstractStateSpaceModel
X::Vector{Float64}
θ::LinearGaussianParams
LinearGaussianModel(params::LinearGaussianParams) = new(Vector{Float64}(), params)
end

function AdvancedPS.initialization(model::LinearGaussianModel)
return Normal(model.θ.x0, model.θ.p0)
end
function AdvancedPS.transition(model::LinearGaussianModel, state, step)
return Normal(model.θ.a * state + model.θ.b, model.θ.q)
end
function AdvancedPS.observation(model::LinearGaussianModel, state, step)
return logpdf(Normal(model.θ.h * state, model.θ.r), ys[step])
end

AdvancedPS.isdone(::LinearGaussianModel, step) = step > T

params = LinearGaussianParams(Φ[1, 1], b[1], Q[1, 1], H[1, 1], R[1, 1], x0[1], P0[1, 1])
model = LinearGaussianModel(params)

@testset "PGAS" begin
pgas = AdvancedPS.PGAS(N_PARTICLES)
p = test_algorithm(rng, pgas, model, N_SAMPLES, Xf)
@test p > 0.05
end

@testset "PG" begin
pg = AdvancedPS.PG(N_PARTICLES)
p = test_algorithm(rng, pg, model, N_SAMPLES, Xf)
@test p > 0.05
end
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using AbstractMCMC
using Distributions
using Libtask
using Random
using StableRNGs
using Test

@testset "AdvancedPS.jl" begin
Expand All @@ -21,4 +22,7 @@ using Test
@testset "PG-AS" begin
include("pgas.jl")
end
@testset "Linear Gaussian SSM tests" begin
include("linear-gaussian.jl")
end
end

0 comments on commit 7568d31

Please sign in to comment.