Skip to content

Commit

Permalink
Update test to SSMProblems interface
Browse files Browse the repository at this point in the history
  • Loading branch information
THargreaves committed Apr 15, 2024
1 parent 592a6ec commit 467a9f3
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions test/linear-gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,26 +66,36 @@ end
p0::Float64
end

mutable struct LinearGaussianModel <: AdvancedPS.AbstractStateSpaceModel
mutable struct LinearGaussianModel <: SSMProblems.AbstractStateSpaceModel
X::Vector{Float64}
observations::Vector{Float64}
θ::LinearGaussianParams
LinearGaussianModel(params::LinearGaussianParams) = new(Vector{Float64}(), params)
function LinearGaussianModel(y::Vector{Float64}, θ::LinearGaussianParams)
return new(Vector{Float64}(), y, θ)
end
end

function AdvancedPS.initialization(model::LinearGaussianModel)
return Normal(model.θ.x0, model.θ.p0)
function SSMProblems.transition!!(rng::AbstractRNG, model::LinearGaussianModel)
return rand(rng, Normal(model.θ.x0, model.θ.p0))
end
function AdvancedPS.transition(model::LinearGaussianModel, state, step)
return Normal(model.θ.a * state + model.θ.b, model.θ.q)
function SSMProblems.transition!!(
rng::AbstractRNG, model::LinearGaussianModel, state, step
)
return rand(rng, Normal(model.θ.a * state + model.θ.b, model.θ.q))
end
function AdvancedPS.observation(model::LinearGaussianModel, state, step)
return logpdf(Normal(model.θ.h * state, model.θ.r), ys[step])
function SSMProblems.transition_logdensity(
model::LinearGaussianModel, prev_state, current_state, step
)
return logpdf(Normal(model.θ.a * prev_state + model.θ.b, model.θ.q), current_state)
end
function SSMProblems.emission_logdensity(model::LinearGaussianModel, state, step)
return logpdf(Normal(model.θ.h * state, model.θ.r), model.observations[step])
end

AdvancedPS.isdone(::LinearGaussianModel, step) = step > T

params = LinearGaussianParams(a, b, q, H, R, x0, P0)
model = LinearGaussianModel(params)
model = LinearGaussianModel(ys, params)

@testset "PGAS" begin
pgas = AdvancedPS.PGAS(N_PARTICLES)
Expand Down

0 comments on commit 467a9f3

Please sign in to comment.