Skip to content

Commit

Permalink
Merge pull request #19 from TuringLang/resampling_rng
Browse files Browse the repository at this point in the history
Add RNG to resampling methods
  • Loading branch information
yebai authored Dec 5, 2020
2 parents c87cba7 + fbea67e commit 487a943
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 43 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
name = "AdvancedPS"
uuid = "576499cb-2369-40b2-a588-c64705576edc"
authors = ["TuringLang"]
version = "0.1.0"
version = "0.2.0"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
Expand Down
1 change: 1 addition & 0 deletions src/AdvancedPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module AdvancedPS

import Distributions
import Libtask
import Random
import StatsFuns

include("resampling.jl")
Expand Down
24 changes: 13 additions & 11 deletions src/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ function effectiveSampleSize(pc::ParticleContainer)
end

"""
resample_propagate!(pc::ParticleContainer[, randcat = resample, ref = nothing;
weights = getweights(pc)])
resample_propagate!(rng, pc::ParticleContainer[, randcat = resample_systematic,
ref = nothing; weights = getweights(pc)])
Resample and propagate the particles in `pc`.
Expand All @@ -176,8 +176,9 @@ of the particle `weights`. For Particle Gibbs sampling, one can provide a refere
`ref` that is ensured to survive the resampling step.
"""
function resample_propagate!(
rng::Random.AbstractRNG,
pc::ParticleContainer,
randcat = resample,
randcat = resample_systematic,
ref::Union{Particle, Nothing} = nothing;
weights = getweights(pc)
)
Expand All @@ -187,7 +188,7 @@ function resample_propagate!(
# sample ancestor indices
n = length(pc)
nresamples = ref === nothing ? n : n - 1
indx = randcat(weights, nresamples)
indx = randcat(rng, weights, nresamples)

# count number of children for each particle
num_children = zeros(Int, n)
Expand Down Expand Up @@ -230,6 +231,7 @@ function resample_propagate!(
end

function resample_propagate!(
rng::Random.AbstractRNG,
pc::ParticleContainer,
resampler::ResampleWithESSThreshold,
ref::Union{Particle,Nothing} = nothing;
Expand All @@ -239,7 +241,7 @@ function resample_propagate!(
ess = inv(sum(abs2, weights))

if ess resampler.threshold * length(pc)
resample_propagate!(pc, resampler.resampler, ref; weights = weights)
resample_propagate!(rng, pc, resampler.resampler, ref; weights = weights)
end

pc
Expand Down Expand Up @@ -292,7 +294,7 @@ function reweight!(pc::ParticleContainer)
end

"""
sweep!(pc::ParticleContainer, resampler)
sweep!(rng, pc::ParticleContainer, resampler)
Perform a particle sweep and return an unbiased estimate of the log evidence.
Expand All @@ -303,11 +305,11 @@ The resampling steps use the given `resampler`.
Del Moral, P., Doucet, A., & Jasra, A. (2006). Sequential monte carlo samplers.
Journal of the Royal Statistical Society: Series B (Statistical Methodology), 68(3), 411-436.
"""
function sweep!(pc::ParticleContainer, resampler)
function sweep!(rng::Random.AbstractRNG, pc::ParticleContainer, resampler)
# Initial step:

# Resample and propagate particles.
resample_propagate!(pc, resampler)
resample_propagate!(rng, pc, resampler)

# Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic
# weights.
Expand All @@ -317,7 +319,7 @@ function sweep!(pc::ParticleContainer, resampler)
logZ0 = logZ(pc)

# Reweight the particles by including the first observation ``y₁``.
isdone = reweight!(pc)
isdone = reweight!(rng, pc)

# Compute the normalizing constant ``Z₁`` after reweighting.
logZ1 = logZ(pc)
Expand All @@ -328,14 +330,14 @@ function sweep!(pc::ParticleContainer, resampler)
# For observations ``y₂, …, yₜ``:
while !isdone
# Resample and propagate particles.
resample_propagate!(pc, resampler)
resample_propagate!(rng, pc, resampler)

# Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic
# weights.
logZ0 = logZ(pc)

# Reweight the particles by including the next observation ``yₜ``.
isdone = reweight!(pc)
isdone = reweight!(rng, pc)

# Compute the normalizing constant ``Z₁`` after reweighting.
logZ1 = logZ(pc)
Expand Down
49 changes: 27 additions & 22 deletions src/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,33 @@ struct ResampleWithESSThreshold{R, T<:Real}
threshold::T
end

function ResampleWithESSThreshold(resampler = resample)
function ResampleWithESSThreshold(resampler = resample_systematic)
ResampleWithESSThreshold(resampler, 0.5)
end

# More stable, faster version of rand(Categorical)
function randcat(p::AbstractVector{<:Real})
function randcat(rng::Random.AbstractRNG, p::AbstractVector{<:Real})
T = eltype(p)
r = rand(T)
r = rand(rng, T)
cp = p[1]
s = 1
for j in eachindex(p)
r -= p[j]
if r <= zero(T)
s = j
break
end
n = length(p)
while cp <= r && s < n
@inbounds cp += p[s += 1]
end
return s
end

function resample_multinomial(
rng::Random.AbstractRNG,
w::AbstractVector{<:Real},
num_particles::Integer = length(w),
)
return rand(Distributions.sampler(Distributions.Categorical(w)), num_particles)
return rand(rng, Distributions.sampler(Distributions.Categorical(w)), num_particles)
end

function resample_residual(
rng::Random.AbstractRNG,
w::AbstractVector{<:Real},
num_particles::Integer = length(weights),
)
Expand All @@ -57,19 +57,19 @@ function resample_residual(
end
residuals[j] = x - floor_x
end

# sampling from residuals
if i <= num_particles
residuals ./= sum(residuals)
rand!(Distributions.Categorical(residuals), view(indices, i:num_particles))
rand!(rng, Distributions.Categorical(residuals), view(indices, i:num_particles))
end

return indices
end


"""
resample_stratified(weights, n)
resample_stratified(rng, weights, n)
Return a vector of `n` samples `x₁`, ..., `xₙ` from the numbers 1, ..., `length(weights)`,
generated by stratified resampling.
Expand All @@ -80,7 +80,11 @@ are selected according to the multinomial distribution defined by the normalized
i.e., `xᵢ = j` if and only if
``uᵢ \\in [\\sum_{s=1}^{j-1} weights_{s}, \\sum_{s=1}^{j} weights_{s})``.
"""
function resample_stratified(weights::AbstractVector{<:Real}, n::Integer = length(weights))
function resample_stratified(
rng::Random.AbstractRNG,
weights::AbstractVector{<:Real},
n::Integer = length(weights),
)
# check input
m = length(weights)
m > 0 || error("weight vector is empty")
Expand All @@ -93,7 +97,7 @@ function resample_stratified(weights::AbstractVector{<:Real}, n::Integer = lengt
sample = 1
@inbounds for i in 1:n
# sample next `u` (scaled by `n`)
u = oftype(v, i - 1 + rand())
u = oftype(v, i - 1 + rand(rng))

# as long as we have not found the next sample
while v < u
Expand All @@ -114,7 +118,7 @@ function resample_stratified(weights::AbstractVector{<:Real}, n::Integer = lengt
end

"""
resample_systematic(weights, n)
resample_systematic(rng, weights, n)
Return a vector of `n` samples `x₁`, ..., `xₙ` from the numbers 1, ..., `length(weights)`,
generated by systematic resampling.
Expand All @@ -125,14 +129,18 @@ numbers `u₁`, ..., `uₙ` where ``uₖ = (u + k − 1) / n``. Based on these n
normalized `weights`, i.e., `xᵢ = j` if and only if
``uᵢ \\in [\\sum_{s=1}^{j-1} weights_{s}, \\sum_{s=1}^{j} weights_{s})``.
"""
function resample_systematic(weights::AbstractVector{<:Real}, n::Integer = length(weights))
function resample_systematic(
rng::Random.AbstractRNG,
weights::AbstractVector{<:Real},
n::Integer = length(weights),
)
# check input
m = length(weights)
m > 0 || error("weight vector is empty")

# pre-calculations
@inbounds v = n * weights[1]
u = oftype(v, rand())
u = oftype(v, rand(rng))

# find all samples
samples = Array{Int}(undef, n)
Expand All @@ -158,6 +166,3 @@ function resample_systematic(weights::AbstractVector{<:Real}, n::Integer = lengt

return samples
end

# Default resampling scheme
const resample = resample_systematic
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Expand Down
2 changes: 1 addition & 1 deletion test/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
@test AdvancedPS.logZ(pc) log(sum(exp, 2 .* logps))

# Resample and propagate particles.
AdvancedPS.resample_propagate!(pc)
AdvancedPS.resample_propagate!(Random.GLOBAL_RNG, pc)
@test pc.logWs == zeros(3)
@test AdvancedPS.getweights(pc) == fill(1/3, 3)
@test all(AdvancedPS.getweight(pc, i) == 1/3 for i in 1:3)
Expand Down
15 changes: 7 additions & 8 deletions test/resampling.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
@testset "resampling.jl" begin
D = [0.3, 0.4, 0.3]
num_samples = Int(1e6)
rng = Random.GLOBAL_RNG

resSystematic = AdvancedPS.resample_systematic(D, num_samples )
resStratified = AdvancedPS.resample_stratified(D, num_samples )
resMultinomial= AdvancedPS.resample_multinomial(D, num_samples )
resResidual = AdvancedPS.resample_residual(D, num_samples )
AdvancedPS.resample(D)
resSystematic2= AdvancedPS.resample(D, num_samples )
resSystematic = AdvancedPS.resample_systematic(rng, D, num_samples)
resStratified = AdvancedPS.resample_stratified(rng, D, num_samples)
resMultinomial= AdvancedPS.resample_multinomial(rng, D, num_samples)
resResidual = AdvancedPS.resample_residual(rng, D, num_samples)
AdvancedPS.resample_systematic(rng, D)

@test sum(resSystematic .== 2) (num_samples * 0.4) atol=1e-3*num_samples
@test sum(resSystematic2 .== 2) (num_samples * 0.4) atol=1e-3*num_samples
@test sum(resStratified .== 2) (num_samples * 0.4) atol=1e-3*num_samples
@test sum(resMultinomial .== 2) (num_samples * 0.4) atol=1e-2*num_samples
@test sum(resResidual .== 2) (num_samples * 0.4) atol=1e-2*num_samples
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using AdvancedPS
using Libtask
using Random
using Test

@testset "AdvancedPS.jl" begin
Expand Down

2 comments on commit 487a943

@devmotion
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/36675

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.0 -m "<description of version>" 487a943192ad40a0ab2a2de3e4a0469203ac2b20
git push origin v0.2.0

Please sign in to comment.