Skip to content

Commit

Permalink
Merge branch 'master' into fred/extension
Browse files Browse the repository at this point in the history
  • Loading branch information
yebai authored Sep 17, 2023
2 parents 85f9338 + 72a1e55 commit 631900a
Showing 1 changed file with 40 additions and 11 deletions.
51 changes: 40 additions & 11 deletions examples/particle-gibbs/script.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# # Particle Gibbs for non-linear models
using AdvancedPS
using Random
using Distributions
Expand All @@ -13,34 +12,33 @@ using Libtask
# x_{t+1} = a x_t + v_t \quad v_{t} \sim \mathcal{N}(0, r^2)
# ```
# ```math
# y_{t} = e_t \exp(\frac{1}{2}x_t) \quad v_{t} \sim \mathcal{N}(0, 1)
# y_{t} = e_t \exp(\frac{1}{2}x_t) \quad e_t \sim \mathcal{N}(0, 1)
# ```
#
# Here we assume the static parameters $\theta = (q^2, r^2)$ are known and we are only interested in sampling from the latent state $x_t$.
# We can reformulate the above in terms of transition and observation densities:
# ```math
# x_{t+1} \sim f_{\theta}(x_{t+1}|x_t) = \mathcal{N}(a x_t, q^2)
# x_{t+1} \sim f_{\theta}(x_{t+1}|x_t) = \mathcal{N}(a x_t, r^2)
# ```
# ```math
# y_t \sim g_{\theta}(y_t|x_t) = \mathcal{N}(0, \exp(\frac{1}{2}x_t)^2)
# ```
# with the initial distribution $f_0(x) = \mathcal{N}(0, q^2)$.
# Here we assume the static parameters $\theta = (q^2, r^2)$ are known and we are only interested in sampling from the latent state $x_t$.
Parameters = @NamedTuple begin
a::Float64
q::Float64
T::Int
end

mutable struct NonLinearTimeSeries <: AbstractMCMC.AbstractModel
mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractStateSpaceModel
X::Array
θ::Parameters
NonLinearTimeSeries::Parameters) = new(zeros(Float64, θ.T), θ)
end

f(model::NonLinearTimeSeries, state, t) = Normal(model.θ.a * state, model.θ.q)
g(model::NonLinearTimeSeries, state, t) = Normal(0, exp(0.5 * state)^2)
g(model::NonLinearTimeSeries, state, t) = Normal(0, exp(0.5 * state))
f₀(model::NonLinearTimeSeries) = Normal(0, model.θ.q)
#md nothing #hide

# Let's simulate some data
a = 0.9 # State Variance
Expand Down Expand Up @@ -88,8 +86,8 @@ end

# Here we use the particle gibbs kernel without adaptive resampling.
model = NonLinearTimeSeries(θ₀)
pgas = AdvancedPS.PG(Nₚ, 1.0)
chains = sample(rng, model, pgas, Nₛ; progress=false);
pg = AdvancedPS.PG(Nₚ, 1.0)
chains = sample(rng, model, pg, Nₛ; progress=false);
#md nothing #hide

# Each sampled trajectory holds a NonLinearTimeSeries model
Expand All @@ -98,8 +96,7 @@ mean_trajectory = mean(particles; dims=2)
#md nothing #hide

# We can now plot all the generated traces.
# Beyond the last few timesteps all the trajectories collapse into one. Using the ancestor updating step can help
# with the degeneracy problem.
# Beyond the last few timesteps all the trajectories collapse into one. Using the ancestor updating step can help with the degeneracy problem, as we show below.
scatter(particles; label=false, opacity=0.01, color=:black, xlabel="t", ylabel="state")
plot!(x; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
Expand All @@ -119,3 +116,35 @@ plot(
ylabel="Update rate",
)
hline!([1 - 1 / Nₚ]; label="N: $(Nₚ)")

# Let's see if ancestor sampling can help with the degeneracy problem. We use the same number of particles, but replace the sampler with PGAS.
# To use this sampler we need to define the transition and observation densities as well as the initial distribution in the following way:
AdvancedPS.initialization(model::NonLinearTimeSeries) = f₀(model)
AdvancedPS.transition(model::NonLinearTimeSeries, state, step) = f(model, state, step)
function AdvancedPS.observation(model::NonLinearTimeSeries, state, step)
return logpdf(g(model, state, step), y[step])
end
AdvancedPS.isdone(::NonLinearTimeSeries, step) = step > Tₘ

# We can now sample from the model using the PGAS sampler and collect the trajectories.
pg = AdvancedPS.PGAS(Nₚ)
chains = sample(model, pg, Nₛ);
particles = hcat([trajectory.model.f.X for trajectory in trajectories]...)
mean_trajectory = mean(particles; dims=2)

# The ancestor sampling has helped with the degeneracy problem and we now have a much more diverse set of trajectories, also at earlier time periods.
scatter(particles; label=false, opacity=0.01, color=:black, xlabel="t", ylabel="state")
plot!(x; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)

# The update rate is now much higher throughout time.
update_rate = sum(abs.(diff(particles; dims=2)) .> 0; dims=2) / Nₛ
plot(
update_rate;
label=false,
ylim=[0, 1],
legend=:bottomleft,
xlabel="Iteration",
ylabel="Update rate",
)
hline!([1 - 1 / Nₚ]; label="N: $(Nₚ)")

0 comments on commit 631900a

Please sign in to comment.