Skip to content

Commit

Permalink
Merge fbea67e into c87cba7
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Dec 5, 2020
2 parents c87cba7 + fbea67e commit 4555653
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 43 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
name = "AdvancedPS"
uuid = "576499cb-2369-40b2-a588-c64705576edc"
authors = ["TuringLang"]
version = "0.1.0"
version = "0.2.0"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
Expand Down
1 change: 1 addition & 0 deletions src/AdvancedPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module AdvancedPS

import Distributions
import Libtask
import Random
import StatsFuns

include("resampling.jl")
Expand Down
24 changes: 13 additions & 11 deletions src/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ function effectiveSampleSize(pc::ParticleContainer)
end

"""
resample_propagate!(pc::ParticleContainer[, randcat = resample, ref = nothing;
weights = getweights(pc)])
resample_propagate!(rng, pc::ParticleContainer[, randcat = resample_systematic,
ref = nothing; weights = getweights(pc)])
Resample and propagate the particles in `pc`.
Expand All @@ -176,8 +176,9 @@ of the particle `weights`. For Particle Gibbs sampling, one can provide a refere
`ref` that is ensured to survive the resampling step.
"""
function resample_propagate!(
rng::Random.AbstractRNG,
pc::ParticleContainer,
randcat = resample,
randcat = resample_systematic,
ref::Union{Particle, Nothing} = nothing;
weights = getweights(pc)
)
Expand All @@ -187,7 +188,7 @@ function resample_propagate!(
# sample ancestor indices
n = length(pc)
nresamples = ref === nothing ? n : n - 1
indx = randcat(weights, nresamples)
indx = randcat(rng, weights, nresamples)

# count number of children for each particle
num_children = zeros(Int, n)
Expand Down Expand Up @@ -230,6 +231,7 @@ function resample_propagate!(
end

function resample_propagate!(
rng::Random.AbstractRNG,
pc::ParticleContainer,
resampler::ResampleWithESSThreshold,
ref::Union{Particle,Nothing} = nothing;
Expand All @@ -239,7 +241,7 @@ function resample_propagate!(
ess = inv(sum(abs2, weights))

if ess resampler.threshold * length(pc)
resample_propagate!(pc, resampler.resampler, ref; weights = weights)
resample_propagate!(rng, pc, resampler.resampler, ref; weights = weights)
end

pc
Expand Down Expand Up @@ -292,7 +294,7 @@ function reweight!(pc::ParticleContainer)
end

"""
sweep!(pc::ParticleContainer, resampler)
sweep!(rng, pc::ParticleContainer, resampler)
Perform a particle sweep and return an unbiased estimate of the log evidence.
Expand All @@ -303,11 +305,11 @@ The resampling steps use the given `resampler`.
Del Moral, P., Doucet, A., & Jasra, A. (2006). Sequential monte carlo samplers.
Journal of the Royal Statistical Society: Series B (Statistical Methodology), 68(3), 411-436.
"""
function sweep!(pc::ParticleContainer, resampler)
function sweep!(rng::Random.AbstractRNG, pc::ParticleContainer, resampler)
# Initial step:

# Resample and propagate particles.
resample_propagate!(pc, resampler)
resample_propagate!(rng, pc, resampler)

# Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic
# weights.
Expand All @@ -317,7 +319,7 @@ function sweep!(pc::ParticleContainer, resampler)
logZ0 = logZ(pc)

# Reweight the particles by including the first observation ``y₁``.
isdone = reweight!(pc)
isdone = reweight!(rng, pc)

# Compute the normalizing constant ``Z₁`` after reweighting.
logZ1 = logZ(pc)
Expand All @@ -328,14 +330,14 @@ function sweep!(pc::ParticleContainer, resampler)
# For observations ``y₂, …, yₜ``:
while !isdone
# Resample and propagate particles.
resample_propagate!(pc, resampler)
resample_propagate!(rng, pc, resampler)

# Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic
# weights.
logZ0 = logZ(pc)

# Reweight the particles by including the next observation ``yₜ``.
isdone = reweight!(pc)
isdone = reweight!(rng, pc)

# Compute the normalizing constant ``Z₁`` after reweighting.
logZ1 = logZ(pc)
Expand Down
49 changes: 27 additions & 22 deletions src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,33 @@ struct ResampleWithESSThreshold{R, T<:Real}
threshold::T
end

function ResampleWithESSThreshold(resampler = resample)
function ResampleWithESSThreshold(resampler = resample_systematic)
ResampleWithESSThreshold(resampler, 0.5)
end

# More stable, faster version of rand(Categorical)
function randcat(p::AbstractVector{<:Real})
function randcat(rng::Random.AbstractRNG, p::AbstractVector{<:Real})
T = eltype(p)
r = rand(T)
r = rand(rng, T)
cp = p[1]
s = 1
for j in eachindex(p)
r -= p[j]
if r <= zero(T)
s = j
break
end
n = length(p)
while cp <= r && s < n
@inbounds cp += p[s += 1]
end
return s
end

function resample_multinomial(
rng::Random.AbstractRNG,
w::AbstractVector{<:Real},
num_particles::Integer = length(w),
)
return rand(Distributions.sampler(Distributions.Categorical(w)), num_particles)
return rand(rng, Distributions.sampler(Distributions.Categorical(w)), num_particles)
end

function resample_residual(
rng::Random.AbstractRNG,
w::AbstractVector{<:Real},
num_particles::Integer = length(weights),
)
Expand All @@ -57,19 +57,19 @@ function resample_residual(
end
residuals[j] = x - floor_x
end

# sampling from residuals
if i <= num_particles
residuals ./= sum(residuals)
rand!(Distributions.Categorical(residuals), view(indices, i:num_particles))
rand!(rng, Distributions.Categorical(residuals), view(indices, i:num_particles))
end

return indices
end


"""
resample_stratified(weights, n)
resample_stratified(rng, weights, n)
Return a vector of `n` samples `x₁`, ..., `xₙ` from the numbers 1, ..., `length(weights)`,
generated by stratified resampling.
Expand All @@ -80,7 +80,11 @@ are selected according to the multinomial distribution defined by the normalized
i.e., `xᵢ = j` if and only if
``uᵢ \\in [\\sum_{s=1}^{j-1} weights_{s}, \\sum_{s=1}^{j} weights_{s})``.
"""
function resample_stratified(weights::AbstractVector{<:Real}, n::Integer = length(weights))
function resample_stratified(
rng::Random.AbstractRNG,
weights::AbstractVector{<:Real},
n::Integer = length(weights),
)
# check input
m = length(weights)
m > 0 || error("weight vector is empty")
Expand All @@ -93,7 +97,7 @@ function resample_stratified(weights::AbstractVector{<:Real}, n::Integer = lengt
sample = 1
@inbounds for i in 1:n
# sample next `u` (scaled by `n`)
u = oftype(v, i - 1 + rand())
u = oftype(v, i - 1 + rand(rng))

# as long as we have not found the next sample
while v < u
Expand All @@ -114,7 +118,7 @@ function resample_stratified(weights::AbstractVector{<:Real}, n::Integer = lengt
end

"""
resample_systematic(weights, n)
resample_systematic(rng, weights, n)
Return a vector of `n` samples `x₁`, ..., `xₙ` from the numbers 1, ..., `length(weights)`,
generated by systematic resampling.
Expand All @@ -125,14 +129,18 @@ numbers `u₁`, ..., `uₙ` where ``uₖ = (u + k − 1) / n``. Based on these n
normalized `weights`, i.e., `xᵢ = j` if and only if
``uᵢ \\in [\\sum_{s=1}^{j-1} weights_{s}, \\sum_{s=1}^{j} weights_{s})``.
"""
function resample_systematic(weights::AbstractVector{<:Real}, n::Integer = length(weights))
function resample_systematic(
rng::Random.AbstractRNG,
weights::AbstractVector{<:Real},
n::Integer = length(weights),
)
# check input
m = length(weights)
m > 0 || error("weight vector is empty")

# pre-calculations
@inbounds v = n * weights[1]
u = oftype(v, rand())
u = oftype(v, rand(rng))

# find all samples
samples = Array{Int}(undef, n)
Expand All @@ -158,6 +166,3 @@ function resample_systematic(weights::AbstractVector{<:Real}, n::Integer = lengt

return samples
end

# Default resampling scheme
const resample = resample_systematic
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Expand Down
2 changes: 1 addition & 1 deletion test/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
@test AdvancedPS.logZ(pc) log(sum(exp, 2 .* logps))

# Resample and propagate particles.
AdvancedPS.resample_propagate!(pc)
AdvancedPS.resample_propagate!(Random.GLOBAL_RNG, pc)
@test pc.logWs == zeros(3)
@test AdvancedPS.getweights(pc) == fill(1/3, 3)
@test all(AdvancedPS.getweight(pc, i) == 1/3 for i in 1:3)
Expand Down
15 changes: 7 additions & 8 deletions test/resampling.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
@testset "resampling.jl" begin
D = [0.3, 0.4, 0.3]
num_samples = Int(1e6)
rng = Random.GLOBAL_RNG

resSystematic = AdvancedPS.resample_systematic(D, num_samples )
resStratified = AdvancedPS.resample_stratified(D, num_samples )
resMultinomial= AdvancedPS.resample_multinomial(D, num_samples )
resResidual = AdvancedPS.resample_residual(D, num_samples )
AdvancedPS.resample(D)
resSystematic2= AdvancedPS.resample(D, num_samples )
resSystematic = AdvancedPS.resample_systematic(rng, D, num_samples)
resStratified = AdvancedPS.resample_stratified(rng, D, num_samples)
resMultinomial= AdvancedPS.resample_multinomial(rng, D, num_samples)
resResidual = AdvancedPS.resample_residual(rng, D, num_samples)
AdvancedPS.resample_systematic(rng, D)

@test sum(resSystematic .== 2) (num_samples * 0.4) atol=1e-3*num_samples
@test sum(resSystematic2 .== 2) (num_samples * 0.4) atol=1e-3*num_samples
@test sum(resStratified .== 2) (num_samples * 0.4) atol=1e-3*num_samples
@test sum(resMultinomial .== 2) (num_samples * 0.4) atol=1e-2*num_samples
@test sum(resResidual .== 2) (num_samples * 0.4) atol=1e-2*num_samples
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using AdvancedPS
using Libtask
using Random
using Test

@testset "AdvancedPS.jl" begin
Expand Down

0 comments on commit 4555653

Please sign in to comment.