diff --git a/test/Project.toml b/test/Project.toml index 9fcb69d..c6d16e8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,9 +1,14 @@ [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DynamicIterators = "6c76993d-992e-5bf1-9e63-34920a5a5a38" +GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d" +HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" +Kalman = "d59c0ba6-2ef2-5409-8dc5-1fd9a2b46832" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random123 = "74087812-796a-5b5d-8853-05524746bad3" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] diff --git a/test/linear-gaussian.jl b/test/linear-gaussian.jl new file mode 100644 index 0000000..c200efb --- /dev/null +++ b/test/linear-gaussian.jl @@ -0,0 +1,103 @@ +""" +Unit tests for the validity of the SMC algorithms included in this package. + +We test each SMC algorithm on a one-dimensional linear Gaussian state space model for which +an analytic filtering distribution can be computed using the Kalman filter provided by the +`Kalman.jl` package. + +The validity of the algorithm is tested by comparing the final estimated filtering +distribution ground truth using a one-sided Kolmogorov-Smirnov test. +""" + +using DynamicIterators +using GaussianDistributions +using HypothesisTests +using Kalman + +function test_algorithm(rng, algorithm, model, N_SAMPLES, Xf) + chains = sample(rng, model, algorithm, N_SAMPLES; progress=false) + particles = hcat([chain.trajectory.model.X for chain in chains]...) + final_particles = particles[:, end] + + test = ExactOneSampleKSTest( + final_particles, Normal(Xf.x[end].μ[1], sqrt(Xf.x[end].Σ[1, 1])) + ) + return pvalue(test) +end + +@testset "linear-gaussian.jl" begin + T = 3 + N_PARTICLES = 20 + N_SAMPLES = 50 + + # Model dynamics (in matrix form, despite being one-dimensional, to work with Kalman.jl) + Φ = [0.5;;] + b = [0.2] + Q = [0.1;;] + E = LinearEvolution(Φ, Gaussian(b, Q)) + + H = [1.0;;] + R = [0.1;;] + Obs = LinearObservationModel(H, R) + + x0 = [0.0] + P0 = [1.0;;] + G0 = Gaussian(x0, P0) + + M = LinearStateSpaceModel(E, Obs) + O = LinearObservation(E, H, R) + + # Simulate from model + rng = StableRNG(1234) + initial = rand(rng, StateObs(G0, M.obs)) + trajectory = trace(DynamicIterators.Sampled(M), 1 => initial, endtime(T)) + y_pairs = collect(t => y for (t, (x, y)) in pairs(trajectory)) + ys = stack(y for (t, (x, y)) in pairs(trajectory)) + + # Ground truth smoothing + Xf, ll = kalmanfilter(M, 1 => G0, y_pairs) + + # Define AdvancedPS model + mutable struct LinearGaussianParams + a::Float64 + b::Float64 + q::Float64 + h::Float64 + r::Float64 + x0::Float64 + p0::Float64 + end + + mutable struct LinearGaussianModel <: AdvancedPS.AbstractStateSpaceModel + X::Vector{Float64} + θ::LinearGaussianParams + LinearGaussianModel(params::LinearGaussianParams) = new(Vector{Float64}(), params) + end + + function AdvancedPS.initialization(model::LinearGaussianModel) + return Normal(model.θ.x0, model.θ.p0) + end + function AdvancedPS.transition(model::LinearGaussianModel, state, step) + return Normal(model.θ.a * state + model.θ.b, model.θ.q) + end + function AdvancedPS.observation(model::LinearGaussianModel, state, step) + return logpdf(Normal(model.θ.h * state, model.θ.r), ys[step]) + end + + AdvancedPS.isdone(::LinearGaussianModel, step) = step > T + + params = LinearGaussianParams(Φ[1, 1], b[1], Q[1, 1], H[1, 1], R[1, 1], x0[1], P0[1, 1]) + model = LinearGaussianModel(params) + + @testset "PGAS" begin + pgas = AdvancedPS.PGAS(N_PARTICLES) + p = test_algorithm(rng, pgas, model, N_SAMPLES, Xf) + @test p > 0.05 + end + + @testset "PG" begin + pg = AdvancedPS.PG(N_PARTICLES) + p = test_algorithm(rng, pg, model, N_SAMPLES, Xf) + @test p > 0.05 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index d705806..337a582 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using AbstractMCMC using Distributions using Libtask using Random +using StableRNGs using Test @testset "AdvancedPS.jl" begin @@ -21,4 +22,7 @@ using Test @testset "PG-AS" begin include("pgas.jl") end + @testset "Linear Gaussian SSM tests" begin + include("linear-gaussian.jl") + end end