From 467a9f3501ce6bca3ecc34c64437ad311fe0ab45 Mon Sep 17 00:00:00 2001 From: THargreaves Date: Mon, 15 Apr 2024 16:20:14 +0100 Subject: [PATCH] Update test to SSMProblems interface --- test/linear-gaussian.jl | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/test/linear-gaussian.jl b/test/linear-gaussian.jl index e48d443..9872f1f 100644 --- a/test/linear-gaussian.jl +++ b/test/linear-gaussian.jl @@ -66,26 +66,36 @@ end p0::Float64 end - mutable struct LinearGaussianModel <: AdvancedPS.AbstractStateSpaceModel + mutable struct LinearGaussianModel <: SSMProblems.AbstractStateSpaceModel X::Vector{Float64} + observations::Vector{Float64} θ::LinearGaussianParams - LinearGaussianModel(params::LinearGaussianParams) = new(Vector{Float64}(), params) + function LinearGaussianModel(y::Vector{Float64}, θ::LinearGaussianParams) + return new(Vector{Float64}(), y, θ) + end end - function AdvancedPS.initialization(model::LinearGaussianModel) - return Normal(model.θ.x0, model.θ.p0) + function SSMProblems.transition!!(rng::AbstractRNG, model::LinearGaussianModel) + return rand(rng, Normal(model.θ.x0, model.θ.p0)) end - function AdvancedPS.transition(model::LinearGaussianModel, state, step) - return Normal(model.θ.a * state + model.θ.b, model.θ.q) + function SSMProblems.transition!!( + rng::AbstractRNG, model::LinearGaussianModel, state, step + ) + return rand(rng, Normal(model.θ.a * state + model.θ.b, model.θ.q)) end - function AdvancedPS.observation(model::LinearGaussianModel, state, step) - return logpdf(Normal(model.θ.h * state, model.θ.r), ys[step]) + function SSMProblems.transition_logdensity( + model::LinearGaussianModel, prev_state, current_state, step + ) + return logpdf(Normal(model.θ.a * prev_state + model.θ.b, model.θ.q), current_state) + end + function SSMProblems.emission_logdensity(model::LinearGaussianModel, state, step) + return logpdf(Normal(model.θ.h * state, model.θ.r), model.observations[step]) end AdvancedPS.isdone(::LinearGaussianModel, step) = step > T params = LinearGaussianParams(a, b, q, H, R, x0, P0) - model = LinearGaussianModel(params) + model = LinearGaussianModel(ys, params) @testset "PGAS" begin pgas = AdvancedPS.PGAS(N_PARTICLES)