Skip to content

Commit

Permalink
Test deploy (#50)
Browse files Browse the repository at this point in the history
* Fix #53

* Test deploy

* Missing base package

* Develop local

* Formatx

* Typos

* Print

* Update src/model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Color

* Seed

* Particle Gibbs

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored May 26, 2022
1 parent cc2b444 commit 6b8bc0c
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 39 deletions.
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
AdvancedPS = "576499cb-2369-40b2-a588-c64705576edc"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"

Expand Down
3 changes: 2 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ mkpath(EXAMPLES_OUT)
# Install and precompile all packages
# Workaround for https://github.com/JuliaLang/Pkg.jl/issues/2219
examples = filter!(isdir, readdir(joinpath(@__DIR__, "..", "examples"); join=true))
let script = "using Pkg; Pkg.activate(ARGS[1]); Pkg.instantiate()"
above = joinpath(@__DIR__, "..")
let script = "using Pkg; Pkg.activate(ARGS[1]); Pkg.instantiate(); Pkg.develop(path=\"$(above)\");"
for example in examples
if !success(`$(Base.julia_cmd()) -e $script $example`)
error(
Expand Down
94 changes: 68 additions & 26 deletions examples/gaussian-ssm/script.jl
Original file line number Diff line number Diff line change
@@ -1,51 +1,79 @@
# # Gaussian State Space Model
# # Particle Gibbs with Ancestor Sampling
using AdvancedPS
using Random
using Distributions
using Plots

# Parameters
# We consider the following linear state-space model with Gaussian innovations. The latent state is a simple gaussian random walk
# and the observation is linear in the latent states, namely:
#
# ```math
# x_{t+1} = a x_{t} + \epsilon_t \quad \epsilon_t \sim \mathcal{N}(0,q^2)
# ```
# ```math
# y_{t} = x_{t} + \nu_{t} \quad \nu_{t} \sim \mathcal{N}(0, r^2)
# ```
#
# Here we assume the static parameters $\theta = (a, q^2, r^2)$ are known and we are only interested in sampling from the latent state $x_t$.
# To use particle gibbs with the ancestor sampling step we need to provide both the transition and observation densities.
# From the definition above we get:
# ```math
# x_{t+1} \sim f_{\theta}(x_{t+1}|x_t) = \mathcal{N}(a x_t, q^2)
# ```
# ```math
# y_t \sim g_{\theta}(y_t|x_t) = \mathcal{N}(x_t, q^2)
# ```
# as well as the initial distribution $f_0(x) = \mathcal{N}(0, q^2/(1-a^2))$.

# We are ready to use `AdvancedPS` with our model. We first need to define a model type that subtypes `AdvancedPS.AbstractStateSpaceModel`.
Parameters = @NamedTuple begin
a::Float64
q::Float64
r::Float64
end

a = 0.9 # Scale
q = 0.32 # State variance
r = 1 # Observation variance
Tₘ = 300 # Number of observation
Nₚ = 50 # Number of particles
Nₛ = 1000 # Number of samples
seed = 9 # Reproduce everything

θ₀ = Parameters((a, q, r))

mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractStateSpaceModel
X::Vector{Float64}
θ::Parameters
NonLinearTimeSeries::Parameters) = new(Vector{Float64}(), θ)
NonLinearTimeSeries() = new(Vector{Float64}(), θ₀)
end

# and the densities defined above.
f(m::NonLinearTimeSeries, state, t) = Normal(m.θ.a * state, m.θ.q) # Transition density
g(m::NonLinearTimeSeries, state, t) = Normal(state, m.θ.r) # Observation density
f₀(m::NonLinearTimeSeries) = Normal(0, m.θ.q / (1 - m.θ.a^2)) # Initial state density
g(m::NonLinearTimeSeries, state, t) = Normal(state, m.θ.r) # Observation density
f₀(m::NonLinearTimeSeries) = Normal(0, m.θ.q^2 / (1 - m.θ.a^2)) # Initial state density
#md nothing #hide

# To implement `AdvancedPS.AbstractStateSpaceModel` we need to define a few functions to specify the dynamics of the system:
# - `AdvancedPS.initialization` the initial state density
# - `AdvancedPS.transition` the state transition density
# - `AdvancedPS.observation` the observation score given the observed data
# - `AdvancedPS.isdone` signals the end of the execution for the model
AdvancedPS.initialization(model::NonLinearTimeSeries) = f₀(model)
AdvancedPS.transition(model::NonLinearTimeSeries, state, step) = f(model, state, step)
function AdvancedPS.observation(model::NonLinearTimeSeries, state, step)
return logpdf(g(model, state, step), y[step])
end # Return log-pdf of the obs
end
AdvancedPS.isdone(::NonLinearTimeSeries, step) = step > Tₘ

# Generate some synthetic data
# Everything is now ready to simulate some data.

a = 0.9 # Scale
q = 0.32 # State variance
r = 1 # Observation variance
Tₘ = 300 # Number of observation
Nₚ = 50 # Number of particles
Nₛ = 500 # Number of samples
seed = 9 # Reproduce everything

θ₀ = Parameters((a, q, r))

rng = Random.MersenneTwister(seed)

x = zeros(Tₘ)
y = zeros(Tₘ)

reference = NonLinearTimeSeries(θ₀) # Reference model
reference = NonLinearTimeSeries(θ₀)
x[1] = rand(rng, f₀(reference))
for t in 1:Tₘ
if t < Tₘ
Expand All @@ -54,22 +82,36 @@ for t in 1:Tₘ
y[t] = rand(rng, g(reference, x[t], t))
end

plot(x; label="x", color=:black)
plot(y; label="y", color=:black)
# Let's have a look at the simulated data from the latent state dynamics
plot(x; label="x")
xlabel!("t")

# and the observation data
plot(y; label="y")
xlabel!("x")

# Setup up a particle Gibbs sampler
#
model = NonLinearTimeSeries(θ₀)
pgas = AdvancedPS.PGAS(Nₚ)
chains = sample(rng, model, pgas, Nₛ)
chains = sample(rng, model, pgas, Nₛ; progress=false)
#md nothing #hide

# The actual sampled trajectory is in the trajectory inner model
particles = hcat([chain.trajectory.model.X for chain in chains]...) # Concat all sampled states
mean_trajectory = mean(particles; dims=2)
#md nothing #hide

#
scatter(particles; label=false, opacity=0.01, color=:black)
plot!(x; color=:red, label="Original Trajectory")
plot!(mean_trajectory; color=:blue, label="Posterior mean", opacity=0.9)
plot!(mean_trajectory; color=:orange, label="Mean trajectory", opacity=0.9)
xlabel!("t")
ylabel!("State")

# Compute mixing rate
# By sampling an ancestor from the reference particle in the particle gibbs sampler we split the
# trajectory of the reference particle.
update_rate = sum(abs.(diff(particles; dims=2)) .> 0; dims=2) / Nₛ
plot(update_rate; label=false, ylim=[0, 1])
hline!([1 - 1 / Nₚ]; label=false)
plot(update_rate; label=false, ylim=[0, 1], legend=:bottomleft)
hline!([1 - 1 / Nₚ]; label="N: $(Nₚ)")
xlabel!("Iteration")
ylabel!("Update rate")
10 changes: 10 additions & 0 deletions examples/particle-gibbs/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedPS = "576499cb-2369-40b2-a588-c64705576edc"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
106 changes: 106 additions & 0 deletions examples/particle-gibbs/script.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# # Particle Gibbs with Ancestor Sampling
using AdvancedPS
using Random
using Distributions
using Plots
using AbstractMCMC
using Random123
using Libtask: TArray
using Libtask

Parameters = @NamedTuple begin
a::Float64
q::Float64
r::Float64
T::Int
end

mutable struct NonLinearTimeSeries <: AbstractMCMC.AbstractModel
X::TArray
θ::Parameters
NonLinearTimeSeries::Parameters) = new(TArray(Float64, θ.T), θ)
end

f(model::NonLinearTimeSeries, state, t) = Normal(model.θ.a * state, model.θ.q) # Transition density
g(model::NonLinearTimeSeries, state, t) = Normal(state, model.θ.r) # Observation density
f₀(model::NonLinearTimeSeries) = Normal(0, model.θ.q) # Initial state density

# Everything is now ready to simulate some data.

a = 0.9 # Scale
q = 0.32 # State variance
r = 1 # Observation variance
Tₘ = 300 # Number of observation
Nₚ = 20 # Number of particles
Nₛ = 500 # Number of samples
seed = 9 # Reproduce everything

θ₀ = Parameters((a, q, r, Tₘ))

rng = Random.MersenneTwister(seed)

x = zeros(Tₘ)
y = zeros(Tₘ)

reference = NonLinearTimeSeries(θ₀)
x[1] = 0
for t in 1:Tₘ
if t < Tₘ
x[t + 1] = rand(rng, f(reference, x[t], t))
end
y[t] = rand(rng, g(reference, x[t], t))
end

function (model::NonLinearTimeSeries)(rng::Random.AbstractRNG)
x₀ = rand(rng, f₀(model))
model.X[1] = x₀
score = logpdf(g(model, x₀, 1), y[1])
Libtask.produce(score)

for t in 2:(model.θ.T)
state = rand(rng, f(model, model.X[t - 1], t - 1))
model.X[t] = state
score = logpdf(g(model, state, t), y[t])
Libtask.produce(score)
end
end

Libtask.tape_copy(model::NonLinearTimeSeries) = deepcopy(model)

Random.seed!(rng, seed)
model = NonLinearTimeSeries(θ₀)
pgas = AdvancedPS.PG(Nₚ)
chains = sample(rng, model, pgas, Nₛ; progress=false)

# Utility to replay a particle trajectory
function replay(particle::AdvancedPS.Particle)
trng = deepcopy(particle.rng)
Random123.set_counter!(trng.rng, 0)
trng.count = 1
model = NonLinearTimeSeries(θ₀)
trace = AdvancedPS.Trace(AdvancedPS.GenericModel(model, trng), trng)
score = AdvancedPS.advance!(trace, true)
while !isnothing(score)
score = AdvancedPS.advance!(trace, true)
end
return trace
end

trajectories = map(chains) do sample
replay(sample.trajectory)
end

particles = hcat([trajectory.model.f.X for trajectory in trajectories]...) # Concat all sampled states
mean_trajectory = mean(particles; dims=2)

scatter(particles; label=false, opacity=0.01, color=:black)
plot!(x; color=:red, label="Original Trajectory")
plot!(mean_trajectory; color=:orange, label="Mean trajectory", opacity=0.9)
xlabel!("t")
ylabel!("State")

update_rate = sum(abs.(diff(particles; dims=2)) .> 0; dims=2) / Nₛ
plot(update_rate; label=false, ylim=[0, 1], legend=:bottomleft)
hline!([1 - 1 / Nₚ]; label="N: $(Nₚ)")
xlabel!("Iteration")
ylabel!("Update rate")
24 changes: 12 additions & 12 deletions src/smc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ function AbstractMCMC.sample(

traces = map(1:(sampler.nparticles)) do i
trng = TracedRNG()
tmodel = GenericModel(model, trng)
tmodel = GenericModel(deepcopy(model), trng)
Trace(tmodel, trng)
end

Expand Down Expand Up @@ -85,15 +85,6 @@ struct PGSample{T,L}
logevidence::L
end

struct PGAS{R} <: AbstractMCMC.AbstractSampler
"""Number of particles."""
nparticles::Int
"""Resampling algorithm."""
resampler::R
end

PGAS(nparticles::Int) = PGAS(nparticles, ResampleWithESSThreshold(1.0))

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::AbstractMCMC.AbstractModel,
Expand All @@ -108,10 +99,10 @@ function AbstractMCMC.step(
traces = map(1:nparticles) do i
if i == nparticles && isref
# Create reference trajectory.
forkr(state.trajectory)
forkr(copy(state.trajectory))
else
trng = TracedRNG()
Trace(GenericModel(model, trng), trng)
Trace(GenericModel(deepcopy(model), trng), trng)
end
end

Expand All @@ -127,6 +118,15 @@ function AbstractMCMC.step(
return PGSample(newtrajectory, logevidence), PGState(newtrajectory)
end

struct PGAS{R} <: AbstractMCMC.AbstractSampler
"""Number of particles."""
nparticles::Int
"""Resampling algorithm."""
resampler::R
end

PGAS(nparticles::Int) = PGAS(nparticles, ResampleWithESSThreshold(1.0))

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::AbstractStateSpaceModel,
Expand Down

0 comments on commit 6b8bc0c

Please sign in to comment.