diff --git a/Project.toml b/Project.toml index 81c08f78..c7ddbd00 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,12 @@ name = "AdvancedPS" uuid = "576499cb-2369-40b2-a588-c64705576edc" authors = ["TuringLang"] -version = "0.1.0" +version = "0.2.0" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [compat] diff --git a/src/AdvancedPS.jl b/src/AdvancedPS.jl index cb098ec6..4805a767 100644 --- a/src/AdvancedPS.jl +++ b/src/AdvancedPS.jl @@ -2,6 +2,7 @@ module AdvancedPS import Distributions import Libtask +import Random import StatsFuns include("resampling.jl") diff --git a/src/container.jl b/src/container.jl index 2a4d13d5..ac20d477 100644 --- a/src/container.jl +++ b/src/container.jl @@ -166,8 +166,8 @@ function effectiveSampleSize(pc::ParticleContainer) end """ - resample_propagate!(pc::ParticleContainer[, randcat = resample, ref = nothing; - weights = getweights(pc)]) + resample_propagate!(rng, pc::ParticleContainer[, randcat = resample_systematic, + ref = nothing; weights = getweights(pc)]) Resample and propagate the particles in `pc`. @@ -176,8 +176,9 @@ of the particle `weights`. For Particle Gibbs sampling, one can provide a refere `ref` that is ensured to survive the resampling step. """ function resample_propagate!( + rng::Random.AbstractRNG, pc::ParticleContainer, - randcat = resample, + randcat = resample_systematic, ref::Union{Particle, Nothing} = nothing; weights = getweights(pc) ) @@ -187,7 +188,7 @@ function resample_propagate!( # sample ancestor indices n = length(pc) nresamples = ref === nothing ? n : n - 1 - indx = randcat(weights, nresamples) + indx = randcat(rng, weights, nresamples) # count number of children for each particle num_children = zeros(Int, n) @@ -230,6 +231,7 @@ function resample_propagate!( end function resample_propagate!( + rng::Random.AbstractRNG, pc::ParticleContainer, resampler::ResampleWithESSThreshold, ref::Union{Particle,Nothing} = nothing; @@ -239,7 +241,7 @@ function resample_propagate!( ess = inv(sum(abs2, weights)) if ess ≤ resampler.threshold * length(pc) - resample_propagate!(pc, resampler.resampler, ref; weights = weights) + resample_propagate!(rng, pc, resampler.resampler, ref; weights = weights) end pc @@ -292,7 +294,7 @@ function reweight!(pc::ParticleContainer) end """ - sweep!(pc::ParticleContainer, resampler) + sweep!(rng, pc::ParticleContainer, resampler) Perform a particle sweep and return an unbiased estimate of the log evidence. @@ -303,11 +305,11 @@ The resampling steps use the given `resampler`. Del Moral, P., Doucet, A., & Jasra, A. (2006). Sequential monte carlo samplers. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 68(3), 411-436. """ -function sweep!(pc::ParticleContainer, resampler) +function sweep!(rng::Random.AbstractRNG, pc::ParticleContainer, resampler) # Initial step: # Resample and propagate particles. - resample_propagate!(pc, resampler) + resample_propagate!(rng, pc, resampler) # Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic # weights. @@ -317,7 +319,7 @@ function sweep!(pc::ParticleContainer, resampler) logZ0 = logZ(pc) # Reweight the particles by including the first observation ``y₁``. - isdone = reweight!(pc) + isdone = reweight!(rng, pc) # Compute the normalizing constant ``Z₁`` after reweighting. logZ1 = logZ(pc) @@ -328,14 +330,14 @@ function sweep!(pc::ParticleContainer, resampler) # For observations ``y₂, …, yₜ``: while !isdone # Resample and propagate particles. - resample_propagate!(pc, resampler) + resample_propagate!(rng, pc, resampler) # Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic # weights. logZ0 = logZ(pc) # Reweight the particles by including the next observation ``yₜ``. - isdone = reweight!(pc) + isdone = reweight!(rng, pc) # Compute the normalizing constant ``Z₁`` after reweighting. logZ1 = logZ(pc) diff --git a/src/resampling.jl b/src/resampling.jl index 964ef5ee..4bb21046 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -12,33 +12,33 @@ struct ResampleWithESSThreshold{R, T<:Real} threshold::T end -function ResampleWithESSThreshold(resampler = resample) +function ResampleWithESSThreshold(resampler = resample_systematic) ResampleWithESSThreshold(resampler, 0.5) end # More stable, faster version of rand(Categorical) -function randcat(p::AbstractVector{<:Real}) +function randcat(rng::Random.AbstractRNG, p::AbstractVector{<:Real}) T = eltype(p) - r = rand(T) + r = rand(rng, T) + cp = p[1] s = 1 - for j in eachindex(p) - r -= p[j] - if r <= zero(T) - s = j - break - end + n = length(p) + while cp <= r && s < n + @inbounds cp += p[s += 1] end return s end function resample_multinomial( + rng::Random.AbstractRNG, w::AbstractVector{<:Real}, num_particles::Integer = length(w), ) - return rand(Distributions.sampler(Distributions.Categorical(w)), num_particles) + return rand(rng, Distributions.sampler(Distributions.Categorical(w)), num_particles) end function resample_residual( + rng::Random.AbstractRNG, w::AbstractVector{<:Real}, num_particles::Integer = length(weights), ) @@ -57,19 +57,19 @@ function resample_residual( end residuals[j] = x - floor_x end - + # sampling from residuals if i <= num_particles residuals ./= sum(residuals) - rand!(Distributions.Categorical(residuals), view(indices, i:num_particles)) + rand!(rng, Distributions.Categorical(residuals), view(indices, i:num_particles)) end - + return indices end """ - resample_stratified(weights, n) + resample_stratified(rng, weights, n) Return a vector of `n` samples `x₁`, ..., `xₙ` from the numbers 1, ..., `length(weights)`, generated by stratified resampling. @@ -80,7 +80,11 @@ are selected according to the multinomial distribution defined by the normalized i.e., `xᵢ = j` if and only if ``uᵢ \\in [\\sum_{s=1}^{j-1} weights_{s}, \\sum_{s=1}^{j} weights_{s})``. """ -function resample_stratified(weights::AbstractVector{<:Real}, n::Integer = length(weights)) +function resample_stratified( + rng::Random.AbstractRNG, + weights::AbstractVector{<:Real}, + n::Integer = length(weights), +) # check input m = length(weights) m > 0 || error("weight vector is empty") @@ -93,7 +97,7 @@ function resample_stratified(weights::AbstractVector{<:Real}, n::Integer = lengt sample = 1 @inbounds for i in 1:n # sample next `u` (scaled by `n`) - u = oftype(v, i - 1 + rand()) + u = oftype(v, i - 1 + rand(rng)) # as long as we have not found the next sample while v < u @@ -114,7 +118,7 @@ function resample_stratified(weights::AbstractVector{<:Real}, n::Integer = lengt end """ - resample_systematic(weights, n) + resample_systematic(rng, weights, n) Return a vector of `n` samples `x₁`, ..., `xₙ` from the numbers 1, ..., `length(weights)`, generated by systematic resampling. @@ -125,14 +129,18 @@ numbers `u₁`, ..., `uₙ` where ``uₖ = (u + k − 1) / n``. Based on these n normalized `weights`, i.e., `xᵢ = j` if and only if ``uᵢ \\in [\\sum_{s=1}^{j-1} weights_{s}, \\sum_{s=1}^{j} weights_{s})``. """ -function resample_systematic(weights::AbstractVector{<:Real}, n::Integer = length(weights)) +function resample_systematic( + rng::Random.AbstractRNG, + weights::AbstractVector{<:Real}, + n::Integer = length(weights), +) # check input m = length(weights) m > 0 || error("weight vector is empty") # pre-calculations @inbounds v = n * weights[1] - u = oftype(v, rand()) + u = oftype(v, rand(rng)) # find all samples samples = Array{Int}(undef, n) @@ -158,6 +166,3 @@ function resample_systematic(weights::AbstractVector{<:Real}, n::Integer = lengt return samples end - -# Default resampling scheme -const resample = resample_systematic \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml index 87216cb8..3aacf35c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] diff --git a/test/container.jl b/test/container.jl index 6a4b67a2..2b2f83c8 100644 --- a/test/container.jl +++ b/test/container.jl @@ -47,7 +47,7 @@ @test AdvancedPS.logZ(pc) ≈ log(sum(exp, 2 .* logps)) # Resample and propagate particles. - AdvancedPS.resample_propagate!(pc) + AdvancedPS.resample_propagate!(Random.GLOBAL_RNG, pc) @test pc.logWs == zeros(3) @test AdvancedPS.getweights(pc) == fill(1/3, 3) @test all(AdvancedPS.getweight(pc, i) == 1/3 for i in 1:3) diff --git a/test/resampling.jl b/test/resampling.jl index ee1ff953..f72f07bf 100644 --- a/test/resampling.jl +++ b/test/resampling.jl @@ -1,17 +1,16 @@ @testset "resampling.jl" begin D = [0.3, 0.4, 0.3] num_samples = Int(1e6) + rng = Random.GLOBAL_RNG - resSystematic = AdvancedPS.resample_systematic(D, num_samples ) - resStratified = AdvancedPS.resample_stratified(D, num_samples ) - resMultinomial= AdvancedPS.resample_multinomial(D, num_samples ) - resResidual = AdvancedPS.resample_residual(D, num_samples ) - AdvancedPS.resample(D) - resSystematic2= AdvancedPS.resample(D, num_samples ) + resSystematic = AdvancedPS.resample_systematic(rng, D, num_samples) + resStratified = AdvancedPS.resample_stratified(rng, D, num_samples) + resMultinomial= AdvancedPS.resample_multinomial(rng, D, num_samples) + resResidual = AdvancedPS.resample_residual(rng, D, num_samples) + AdvancedPS.resample_systematic(rng, D) @test sum(resSystematic .== 2) ≈ (num_samples * 0.4) atol=1e-3*num_samples - @test sum(resSystematic2 .== 2) ≈ (num_samples * 0.4) atol=1e-3*num_samples @test sum(resStratified .== 2) ≈ (num_samples * 0.4) atol=1e-3*num_samples @test sum(resMultinomial .== 2) ≈ (num_samples * 0.4) atol=1e-2*num_samples @test sum(resResidual .== 2) ≈ (num_samples * 0.4) atol=1e-2*num_samples -end \ No newline at end of file +end diff --git a/test/runtests.jl b/test/runtests.jl index f01dd2c1..2ceb1147 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using AdvancedPS using Libtask +using Random using Test @testset "AdvancedPS.jl" begin