
Commit

Clean up types, makes API area cleaner (but broken)
FredericWantiez committed Oct 1, 2023
1 parent 4f78aac commit 006c4fd
Showing 8 changed files with 66 additions and 88 deletions.
21 changes: 9 additions & 12 deletions examples/particle-gibbs/script.jl
@@ -7,15 +7,14 @@ using Random123
using Libtask

# We consider the following stochastic volatility model:
#
#
# ```math
# x_{t+1} = a x_t + v_t \quad v_{t} \sim \mathcal{N}(0, r^2)
# ```
# ```math
# y_{t} = e_t \exp(\frac{1}{2}x_t) \quad v_{t} \sim \mathcal{N}(0, 1)
# y_{t} = e_t \exp(\frac{1}{2}x_t) \quad e_t \sim \mathcal{N}(0, 1)
# ```
#
# Here we assume the static parameters $\theta = (q^2, r^2)$ are known and we are only interested in sampling from the latent state $x_t$.
# We can reformulate the above in terms of transition and observation densities:
# ```math
# x_{t+1} \sim f_{\theta}(x_{t+1}|x_t) = \mathcal{N}(a x_t, r^2)
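For readers skimming the diff, here is a minimal, self-contained sketch of simulating this stochastic volatility model from the two densities above. The numeric parameter values, the seed, and the use of a plain `MersenneTwister` are illustrative assumptions, not values taken from the script.

```julia
using Distributions
using Random

# Illustrative parameters -- hypothetical values, not the ones used in the script.
a, r², T = 0.9, 0.5, 200
rng = Random.MersenneTwister(1234)

x = zeros(T)  # latent log-volatility states
y = zeros(T)  # observations
x[1] = rand(rng, Normal(0, sqrt(r²)))
for t in 1:T
    # y_t = e_t * exp(x_t / 2) with e_t ~ N(0, 1)
    y[t] = rand(rng, Normal()) * exp(x[t] / 2)
    # x_{t+1} = a * x_t + v_t with v_t ~ N(0, r²)
    t < T && (x[t + 1] = a * x[t] + rand(rng, Normal(0, sqrt(r²))))
end
```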
@@ -31,7 +30,7 @@ Parameters = @NamedTuple begin
T::Int
end

mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractStateSpaceModel
mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractGenericModel
X::Array
θ::Parameters
NonLinearTimeSeries(θ::Parameters) = new(zeros(Float64, θ.T), θ)
@@ -67,7 +66,7 @@ end
# Here are the latent and observation series:
plot(x; label="x", xlabel="t")

#
#
plot(y; label="y", xlabel="t")

# Each model takes an `AbstractRNG` as input and generates the logpdf of the current transition:
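The body of that definition is collapsed in this hunk. A plausible sketch of the callable-model pattern it refers to, modelled on the `(m::DummyModel)(rng)` examples in `test/smc.jl` further down, could look like the following; the field names `model.θ.a` and `model.θ.r`, the zero initial state, and the reference to the observed series `y` from earlier in the script are assumptions, not the committed code.

```julia
# Hypothetical sketch of the callable model -- the committed definition is collapsed above.
function (model::NonLinearTimeSeries)(rng::Random.AbstractRNG)
    x = 0.0  # assumed initial state
    for t in 1:model.θ.T
        # Draw the next latent state from the transition density f_θ(x_{t+1} | x_t);
        # model.θ.r is assumed to be the standard deviation.
        x = rand(rng, Normal(model.θ.a * x, model.θ.r))
        model.X[t] = x
        # Score observation y_t under g_θ(y_t | x_t); `observe` produces the log-weight.
        AdvancedPS.observe(Normal(0, exp(x / 2)), y[t])
    end
end
```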
@@ -88,23 +87,21 @@ end
# Here we use the particle Gibbs kernel without adaptive resampling.
model = NonLinearTimeSeries(θ₀)
pg = AdvancedPS.PG(Nₚ, 1.0)
chains = sample(rng, model, pg, Nₛ; progress=false);
chains = sample(rng, model, pg, Nₛ; progress=true);
#md nothing #hide

# Each sampled trajectory holds a NonLinearTimeSeries model
particles = hcat([chain.trajectory.X for chain in chains]...) # Concat all sampled states
mean_trajectory = mean(particles; dims=2)
#md nothing #hide

# We can now plot all the generated traces.
# Beyond the last few timesteps all the trajectories collapse into one. Using the ancestor updating step can help
# with the degeneracy problem.
scatter(particles; label=false, opacity=0.01, color=:black, xlabel="t", ylabel="state")
# Beyond the last few timesteps all the trajectories collapse into one. Using the ancestor updating step can help with the degeneracy problem, as we show below.
scatter(particles; label=false, opacity=1.01, color=:black, xlabel="t", ylabel="state")
plot!(x; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)

# We can also check the mixing as defined in the Gaussian State Space model example. As seen on the
# scatter plot above, we are mostly left with a single trajectory before timestep 150. The orange
# scatter plot above, we are mostly left with a single trajectory before timestep 150. The orange
# bar is the optimal mixing rate for the number of particles we use.
update_rate = sum(abs.(diff(particles; dims=2)) .> 0; dims=2) / Nₛ
#md nothing #hide
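For reference, the "optimal mixing rate" mentioned above is commonly taken to be 1 − 1/Nₚ for Nₚ particles, as in the Gaussian state-space example. A hedged sketch of the comparison plot follows; the plotting style is an assumption, not part of this commit.

```julia
# Sketch: plot the per-timestep update rate against the 1 - 1/Nₚ reference line.
# (Plot styling here is assumed; it is not part of this commit.)
plot(
    update_rate;
    label=false,
    ylim=[0, 1],
    legend=:bottomright,
    xlabel="t",
    ylabel="update rate",
)
hline!([1 - 1 / Nₚ]; label="1 - 1/Nₚ", color=:darkorange)
```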
@@ -131,7 +128,7 @@ AdvancedPS.isdone(::NonLinearTimeSeries, step) = step > Tₘ
# We can now sample from the model using the PGAS sampler and collect the trajectories.
pg = AdvancedPS.PGAS(Nₚ)
chains = sample(model, pg, Nₛ);
particles = hcat([trajectory.model.f.X for trajectory in trajectories]...)
particles = hcat([AdvancedPS.replay(chain.trajectory).model.X for chain in chains]...)
mean_trajectory = mean(particles; dims=2)

# The ancestor sampling has helped with the degeneracy problem and we now have a much more diverse set of trajectories, also at earlier time periods.
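The plotting call itself is collapsed below; presumably it mirrors the earlier scatter/overlay. A sketch, reusing the variables defined just above, under that assumption:

```julia
# Sketch mirroring the earlier degeneracy plot, now for the PGAS-sampled particles.
scatter(particles; label=false, opacity=0.01, color=:black, xlabel="t", ylabel="state")
plot!(x; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
```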
18 changes: 12 additions & 6 deletions ext/AdvancedPSLibtaskExt.jl
@@ -39,6 +39,12 @@ Base.copy(model::LibtaskModel) = LibtaskModel(model.f, copy(model.ctask))

const LibtaskTrace{R} = AdvancedPS.Trace{<:LibtaskModel,R}

function AdvancedPS.Trace(
model::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
)
return AdvancedPS.Trace(LibtaskModel(model, args...), rng)
end

# step to the next observe statement and
# return the log probability of the transition (or nothing if done)
function AdvancedPS.advance!(t::LibtaskTrace, isref::Bool=false)
@@ -114,7 +120,7 @@ AbstractMCMC interface. We need libtask to sample from arbitrary callable Abstra

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::AbstractMCMC.AbstractModel,
model::AdvancedPS.AbstractGenericModel,
sampler::AdvancedPS.PG,
state::Union{AdvancedPS.PGState,Nothing}=nothing;
kwargs...,
@@ -145,20 +151,20 @@ function AbstractMCMC.step(
# Pick a particle to be retained.
newtrajectory = rand(rng, particles)

replayed = replay(newtrajectory)
replayed = AdvancedPS.replay(newtrajectory)
return AdvancedPS.PGSample(replayed.model.f, logevidence),
AdvancedPS.PGState(newtrajectory)
end

function AbstractMCMC.sample(
model::AbstractMCMC.AbstractModel, sampler::AdvancedPS.SMC; kwargs...
model::AdvancedPS.AbstractGenericModel, sampler::AdvancedPS.SMC; kwargs...
)
return AbstractMCMC.sample(Random.GLOBAL_RNG, model, sampler; kwargs...)
end

function AbstractMCMC.sample(
rng::Random.AbstractRNG,
model::AbstractMCMC.AbstractModel,
model::AdvancedPS.AbstractGenericModel,
sampler::AdvancedPS.SMC;
kwargs...,
)
@@ -180,7 +186,7 @@ function AbstractMCMC.sample(
# Perform particle sweep.
logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler)

replayed = map(particle -> replay(particle).model.f, particles.vals)
replayed = map(particle -> AdvancedPS.replay(particle).model.f, particles.vals)

return AdvancedPS.SMCSample(
collect(replayed), AdvancedPS.getweights(particles), logevidence
@@ -192,7 +198,7 @@ end
Rewind the particle and regenerate all sampled values. This ensures that the returned trajectory has the correct values.
"""
function replay(particle::AdvancedPS.Particle)
function AdvancedPS.replay(particle::AdvancedPS.Particle)
trng = deepcopy(particle.rng)
Random123.set_counter!(trng.rng, 0)
trng.count = 1
5 changes: 4 additions & 1 deletion src/AdvancedPS.jl
@@ -6,9 +6,12 @@ using Random: Random
using StatsFuns: StatsFuns
using Random123: Random123

abstract type AbstractParticleModel <: AbstractMCMC.AbstractModel end

""" Abstract type for an abstract model formulated in the state space form
"""
abstract type AbstractStateSpaceModel <: AbstractMCMC.AbstractModel end
abstract type AbstractStateSpaceModel <: AbstractParticleModel end
abstract type AbstractGenericModel <: AbstractParticleModel end

include("resampling.jl")
include("rng.jl")
3 changes: 3 additions & 0 deletions src/model.jl
@@ -8,6 +8,7 @@ end

const Particle = Trace
const SSMTrace{R} = Trace{<:AbstractStateSpaceModel,R}
const GenericTrace{R} = Trace{<:AbstractGenericModel,R}

# reset log probability
reset_logprob!(::AdvancedPS.Particle) = nothing
@@ -17,7 +18,9 @@ delete_retained!(f) = nothing

Base.copy(trace::Trace) = Trace(copy(trace.model), deepcopy(trace.rng))

# This is required to make it visible from outside extensions
function observe end
function replay end

"""
gen_refseed!(particle::Particle)
33 changes: 0 additions & 33 deletions src/smc.jl
@@ -84,39 +84,6 @@ struct PGSample{T,L}
logevidence::L
end

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::AbstractStateSpaceModel,
sampler::PG,
state::Union{PGState,Nothing}=nothing;
kwargs...,
)
# Create a new set of particles.
nparticles = sampler.nparticles
isref = !isnothing(state)

traces = map(1:nparticles) do i
if i == nparticles && isref
# Create reference trajectory.
forkr(deepcopy(state.trajectory))
else
trng = TracedRNG()
Trace(deepcopy(model), trng)
end
end

particles = ParticleContainer(traces, TracedRNG(), rng)

# Perform a particle sweep.
reference = isref ? particles.vals[nparticles] : nothing
logevidence = sweep!(rng, particles, sampler.resampler, reference)

# Pick a particle to be retained.
newtrajectory = rand(rng, particles)

return PGSample(newtrajectory, logevidence), PGState(newtrajectory)
end

struct PGAS{R} <: AbstractMCMC.AbstractSampler
"""Number of particles."""
nparticles::Int
56 changes: 29 additions & 27 deletions test/container.jl
@@ -112,33 +112,35 @@
@test all(iszero, pc.logWs)
end

#@testset "trace" begin
# n = Ref(0)
# function f2(rng)
# t = [0]
# while true
# n[] += 1
# produce(t[1])
# n[] += 1
# t[1] = 1 + t[1]
# end
# end

# # Test task copy version of trace
# trng = AdvancedPS.TracedRNG()
# tmodel = AdvancedPS.GenericModel(f2, trng)
# tr = AdvancedPS.Trace(tmodel, trng)

# consume(tr.model.ctask)
# consume(tr.model.ctask)

# a = AdvancedPS.fork(tr)
# consume(a.model.ctask)
# consume(a.model.ctask)

# @test consume(tr.model.ctask) == 2
# @test consume(a.model.ctask) == 4
#end
@testset "trace" begin
struct Model <: AdvancedPS.AbstractGenericModel
val::Ref{Int}
end

function (model::Model)(rng::Random.AbstractRNG)
t = [0]
while true
model.val[] += 1
produce(t[1])
model.val[] += 1
t[1] += 1
end
end

# Test task copy version of trace
trng = AdvancedPS.TracedRNG()
tr = AdvancedPS.Trace(Model(Ref(0)), trng, trng)

consume(tr.model.ctask)
consume(tr.model.ctask)

a = AdvancedPS.fork(tr)
consume(a.model.ctask)
consume(a.model.ctask)

@test consume(tr.model.ctask) == 2
@test consume(a.model.ctask) == 4
end

@testset "seed container" begin
seed = 1
2 changes: 1 addition & 1 deletion test/pgas.jl
@@ -82,7 +82,7 @@
seed = 10
rng = Random.MersenneTwister(seed)

for sampler in [AdvancedPS.PGAS(10), AdvancedPS.PG(10)]
for sampler in [AdvancedPS.PGAS(10)]
Random.seed!(rng, seed)
chain1 = sample(rng, model, sampler, 10)
vals1 = hcat([chain.trajectory.model.X for chain in chain1]...)
16 changes: 8 additions & 8 deletions test/smc.jl
@@ -21,7 +21,7 @@

# Smoke tests
@testset "models" begin
mutable struct NormalModel <: AbstractMCMC.AbstractModel
mutable struct NormalModel <: AdvancedPS.AbstractGenericModel
a::Float64
b::Float64

@@ -45,7 +45,7 @@
sample(NormalModel(), AdvancedPS.SMC(100))

# failing test
mutable struct FailSMCModel <: AbstractMCMC.AbstractModel
mutable struct FailSMCModel <: AdvancedPS.AbstractGenericModel
a::Float64
b::Float64

@@ -66,7 +66,7 @@
@testset "logevidence" begin
Random.seed!(100)

mutable struct TestModel <: AbstractMCMC.AbstractModel
mutable struct TestModel <: AdvancedPS.AbstractGenericModel
a::Float64
x::Bool
b::Float64
@@ -120,7 +120,7 @@
@testset "logevidence" begin
Random.seed!(100)

mutable struct TestModel <: AbstractMCMC.AbstractModel
mutable struct TestModel <: AdvancedPS.AbstractGenericModel
a::Float64
x::Bool
b::Float64
@@ -152,14 +152,14 @@
end

@testset "Replay reference" begin
mutable struct Model <: AbstractMCMC.AbstractModel
mutable struct DummyModel <: AdvancedPS.AbstractGenericModel
a::Float64
b::Float64

Model() = new()
DummyModel() = new()
end

function (m::Model)(rng)
function (m::DummyModel)(rng)
m.a = rand(rng, Normal())
AdvancedPS.observe(Normal(), m.a)

@@ -168,7 +168,7 @@
end

pg = AdvancedPS.PG(1)
first, second = sample(Model(), pg, 2)
first, second = sample(DummyModel(), pg, 2)

first_model = first.trajectory
second_model = second.trajectory
