Skip to content

Commit

Permalink
Linear Gaussian unit test (#98)
Browse files Browse the repository at this point in the history
* Add unit test for linear Gaussian SSM

* Replace matrix model dynamics with scalar

* Remove redundant stack

Fixes CI error that came from `stack` not being available in Julia 1.7.

* Increase particle count and ensure reproducibility

* Update test to SSMProblems interface
  • Loading branch information
THargreaves authored Apr 15, 2024
1 parent 2880bd3 commit c7766f5
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 0 deletions.
5 changes: 5 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DynamicIterators = "6c76993d-992e-5bf1-9e63-34920a5a5a38"
GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d"
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
Kalman = "d59c0ba6-2ef2-5409-8dc5-1fd9a2b46832"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

Expand Down
111 changes: 111 additions & 0 deletions test/linear-gaussian.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""
Unit tests for the validity of the SMC algorithms included in this package.
We test each SMC algorithm on a one-dimensional linear Gaussian state space model for which
an analytic filtering distribution can be computed using the Kalman filter provided by the
`Kalman.jl` package.
The validity of the algorithm is tested by comparing the final estimated filtering
distribution ground truth using a one-sided Kolmogorov-Smirnov test.
"""

using DynamicIterators
using GaussianDistributions
using HypothesisTests
using Kalman

function test_algorithm(rng, algorithm, model, N_SAMPLES, Xf)
chains = sample(rng, model, algorithm, N_SAMPLES; progress=false)
particles = hcat([chain.trajectory.model.X for chain in chains]...)
final_particles = particles[:, end]

test = ExactOneSampleKSTest(final_particles, Normal(Xf.x[end].μ, sqrt(Xf.x[end].Σ)))
return pvalue(test)
end

@testset "linear-gaussian.jl" begin
T = 3
N_PARTICLES = 100
N_SAMPLES = 50

# Model dynamics
a = 0.5
b = 0.2
q = 0.1
E = LinearEvolution(a, Gaussian(b, q))

H = 1.0
R = 0.1
Obs = LinearObservationModel(H, R)

x0 = 0.0
P0 = 1.0
G0 = Gaussian(x0, P0)

M = LinearStateSpaceModel(E, Obs)
O = LinearObservation(E, H, R)

# Simulate from model
rng = StableRNG(1234)
initial = rand(rng, StateObs(G0, M.obs))
trajectory = trace(DynamicIterators.Sampled(M, rng), 1 => initial, endtime(T))
y_pairs = collect(t => y for (t, (x, y)) in pairs(trajectory))
ys = [y for (t, (x, y)) in pairs(trajectory)]

# Ground truth smoothing
Xf, ll = kalmanfilter(M, 1 => G0, y_pairs)

# Define AdvancedPS model
mutable struct LinearGaussianParams
a::Float64
b::Float64
q::Float64
h::Float64
r::Float64
x0::Float64
p0::Float64
end

mutable struct LinearGaussianModel <: SSMProblems.AbstractStateSpaceModel
X::Vector{Float64}
observations::Vector{Float64}
θ::LinearGaussianParams
function LinearGaussianModel(y::Vector{Float64}, θ::LinearGaussianParams)
return new(Vector{Float64}(), y, θ)
end
end

function SSMProblems.transition!!(rng::AbstractRNG, model::LinearGaussianModel)
return rand(rng, Normal(model.θ.x0, model.θ.p0))
end
function SSMProblems.transition!!(
rng::AbstractRNG, model::LinearGaussianModel, state, step
)
return rand(rng, Normal(model.θ.a * state + model.θ.b, model.θ.q))
end
function SSMProblems.transition_logdensity(
model::LinearGaussianModel, prev_state, current_state, step
)
return logpdf(Normal(model.θ.a * prev_state + model.θ.b, model.θ.q), current_state)
end
function SSMProblems.emission_logdensity(model::LinearGaussianModel, state, step)
return logpdf(Normal(model.θ.h * state, model.θ.r), model.observations[step])
end

AdvancedPS.isdone(::LinearGaussianModel, step) = step > T

params = LinearGaussianParams(a, b, q, H, R, x0, P0)
model = LinearGaussianModel(ys, params)

@testset "PGAS" begin
pgas = AdvancedPS.PGAS(N_PARTICLES)
p = test_algorithm(rng, pgas, model, N_SAMPLES, Xf)
@test p > 0.05
end

@testset "PG" begin
pg = AdvancedPS.PG(N_PARTICLES)
p = test_algorithm(rng, pg, model, N_SAMPLES, Xf)
@test p > 0.05
end
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using AbstractMCMC
using Distributions
using Libtask
using Random
using StableRNGs
using Test
using SSMProblems

Expand All @@ -22,4 +23,7 @@ using SSMProblems
@testset "PG-AS" begin
include("pgas.jl")
end
@testset "Linear Gaussian SSM tests" begin
include("linear-gaussian.jl")
end
end

2 comments on commit c7766f5

@yebai
Copy link
Member

@yebai yebai commented on c7766f5 Apr 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/104950

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.0 -m "<description of version>" c7766f5be9a414f2b587d73ef38c22c5a4622b92
git push origin v0.6.0

Please sign in to comment.