Skip to content

Commit

Permalink
Merge branch 'master' into fred/tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yebai authored Oct 29, 2022
2 parents 57325db + c99f920 commit 7905882
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "AdvancedPS"
uuid = "576499cb-2369-40b2-a588-c64705576edc"
authors = ["TuringLang"]
version = "0.4.0"
version = "0.4.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
7 changes: 7 additions & 0 deletions examples/gaussian-process/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
AdvancedPS = "576499cb-2369-40b2-a588-c64705576edc"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
73 changes: 73 additions & 0 deletions examples/gaussian-process/script.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# # Gaussian Process innovation
using LinearAlgebra
using Random
using AdvancedPS
using AbstractGPs
using Plots
using Distributions
using Libtask

Parameters = @NamedTuple begin
a::Float64
q::Float64
kernel
end

mutable struct GPSSM <: AdvancedPS.AbstractStateSpaceModel
X::Vector{Float64}
θ::Parameters

GPSSM(params::Parameters) = new(Vector{Float64}(), params)
end

seed = 1
T = 100
Nₚ = 20
Nₛ = 250
a = 0.9
q = 0.5

params = Parameters((a, q, SqExponentialKernel()))

f(model::GPSSM, x, t) = Normal(model.θ.a * x, model.θ.q)
h(model::GPSSM) = Normal(0, model.θ.q)
g(model::GPSSM, x, t) = Normal(0, exp(0.5 * x)^2)

rng = Random.MersenneTwister(seed)
ref_model = GPSSM(params)

x = zeros(T)
y = similar(x)
x[1] = rand(rng, h(ref_model))
for t in 1:T
if t < T
x[t + 1] = rand(rng, f(ref_model, x[t], t))
end
y[t] = rand(rng, g(ref_model, x[t], t))
end

function gp_update(model::GPSSM, state, step)
gp = GP(model.θ.kernel)
prior = gp(1:(step - 1))
post = posterior(prior, model.X[1:(step - 1)])
μ, σ = mean_and_cov(post, [step])
return Normal(μ[1], σ[1])
end

Libtask.tape_copy(model::GPSSM) = deepcopy(model)

AdvancedPS.initialization(::GPSSM) = h(model)
AdvancedPS.transition(model::GPSSM, state, step) = gp_update(model, state, step)
AdvancedPS.observation(model::GPSSM, state, step) = logpdf(g(model, state, step), y[step])
AdvancedPS.isdone(::GPSSM, step) = step > T

model = GPSSM(params)
pg = AdvancedPS.PGAS(Nₚ)
chains = sample(rng, model, pg, Nₛ)

particles = hcat([chain.trajectory.model.X for chain in chains]...)
mean_trajectory = mean(particles; dims=2);

scatter(particles; label=false, opacity=0.01, color=:black, xlabel="t", ylabel="state")
plot!(x; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
17 changes: 12 additions & 5 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
State wrapper to hold `Libtask.CTask` model initiated from `f`
"""
struct GenericModel{F} <: AbstractMCMC.AbstractModel
f::F
ctask::Libtask.TapedTask{F}
struct GenericModel{F1,F2} <: AbstractMCMC.AbstractModel
f::F1
ctask::Libtask.TapedTask{F2}

GenericModel(f::F1, ctask::Libtask.TapedTask{F2}) where {F1,F2} = new{F1,F2}(f, ctask)
end

GenericModel(f, args...) = GenericModel(f, Libtask.TapedTask(f, args...))
Expand Down Expand Up @@ -48,11 +50,16 @@ end

current_trace() = current_task().storage[:__trace]

function update_rng!(trace::GenericTrace)
rng, = trace.model.ctask.args
trace.rng = rng
return trace
end

# Task copying version of fork for Trace.
function fork(trace::GenericTrace, isref::Bool=false)
newtrace = copy(trace)
rng, = newtrace.model.ctask.args
newtrace.rng = rng
update_rng!(newtrace)
isref && delete_retained!(newtrace.model.f)
isref && delete_seeds!(newtrace)

Expand Down

0 comments on commit 7905882

Please sign in to comment.