diff --git a/test/Project.toml b/test/Project.toml index 686826f..d7556c9 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,9 +1,14 @@ [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DynamicIterators = "6c76993d-992e-5bf1-9e63-34920a5a5a38" +GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d" +HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" +Kalman = "d59c0ba6-2ef2-5409-8dc5-1fd9a2b46832" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random123 = "74087812-796a-5b5d-8853-05524746bad3" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/linear-gaussian.jl b/test/linear-gaussian.jl new file mode 100644 index 0000000..9872f1f --- /dev/null +++ b/test/linear-gaussian.jl @@ -0,0 +1,111 @@ +""" +Unit tests for the validity of the SMC algorithms included in this package. + +We test each SMC algorithm on a one-dimensional linear Gaussian state space model for which +an analytic filtering distribution can be computed using the Kalman filter provided by the +`Kalman.jl` package. + +The validity of the algorithm is tested by comparing the final estimated filtering +distribution ground truth using a one-sided Kolmogorov-Smirnov test. +""" + +using DynamicIterators +using GaussianDistributions +using HypothesisTests +using Kalman + +function test_algorithm(rng, algorithm, model, N_SAMPLES, Xf) + chains = sample(rng, model, algorithm, N_SAMPLES; progress=false) + particles = hcat([chain.trajectory.model.X for chain in chains]...) + final_particles = particles[:, end] + + test = ExactOneSampleKSTest(final_particles, Normal(Xf.x[end].μ, sqrt(Xf.x[end].Σ))) + return pvalue(test) +end + +@testset "linear-gaussian.jl" begin + T = 3 + N_PARTICLES = 100 + N_SAMPLES = 50 + + # Model dynamics + a = 0.5 + b = 0.2 + q = 0.1 + E = LinearEvolution(a, Gaussian(b, q)) + + H = 1.0 + R = 0.1 + Obs = LinearObservationModel(H, R) + + x0 = 0.0 + P0 = 1.0 + G0 = Gaussian(x0, P0) + + M = LinearStateSpaceModel(E, Obs) + O = LinearObservation(E, H, R) + + # Simulate from model + rng = StableRNG(1234) + initial = rand(rng, StateObs(G0, M.obs)) + trajectory = trace(DynamicIterators.Sampled(M, rng), 1 => initial, endtime(T)) + y_pairs = collect(t => y for (t, (x, y)) in pairs(trajectory)) + ys = [y for (t, (x, y)) in pairs(trajectory)] + + # Ground truth smoothing + Xf, ll = kalmanfilter(M, 1 => G0, y_pairs) + + # Define AdvancedPS model + mutable struct LinearGaussianParams + a::Float64 + b::Float64 + q::Float64 + h::Float64 + r::Float64 + x0::Float64 + p0::Float64 + end + + mutable struct LinearGaussianModel <: SSMProblems.AbstractStateSpaceModel + X::Vector{Float64} + observations::Vector{Float64} + θ::LinearGaussianParams + function LinearGaussianModel(y::Vector{Float64}, θ::LinearGaussianParams) + return new(Vector{Float64}(), y, θ) + end + end + + function SSMProblems.transition!!(rng::AbstractRNG, model::LinearGaussianModel) + return rand(rng, Normal(model.θ.x0, model.θ.p0)) + end + function SSMProblems.transition!!( + rng::AbstractRNG, model::LinearGaussianModel, state, step + ) + return rand(rng, Normal(model.θ.a * state + model.θ.b, model.θ.q)) + end + function SSMProblems.transition_logdensity( + model::LinearGaussianModel, prev_state, current_state, step + ) + return logpdf(Normal(model.θ.a * prev_state + model.θ.b, model.θ.q), current_state) + end + function SSMProblems.emission_logdensity(model::LinearGaussianModel, state, step) + return logpdf(Normal(model.θ.h * state, model.θ.r), model.observations[step]) + end + + AdvancedPS.isdone(::LinearGaussianModel, step) = step > T + + params = LinearGaussianParams(a, b, q, H, R, x0, P0) + model = LinearGaussianModel(ys, params) + + @testset "PGAS" begin + pgas = AdvancedPS.PGAS(N_PARTICLES) + p = test_algorithm(rng, pgas, model, N_SAMPLES, Xf) + @test p > 0.05 + end + + @testset "PG" begin + pg = AdvancedPS.PG(N_PARTICLES) + p = test_algorithm(rng, pg, model, N_SAMPLES, Xf) + @test p > 0.05 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index b2c0990..03dd253 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,6 +3,7 @@ using AbstractMCMC using Distributions using Libtask using Random +using StableRNGs using Test using SSMProblems @@ -22,4 +23,7 @@ using SSMProblems @testset "PG-AS" begin include("pgas.jl") end + @testset "Linear Gaussian SSM tests" begin + include("linear-gaussian.jl") + end end