-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #15 from TuringLang/resampling
Extract resampling methods
- Loading branch information
Showing
8 changed files
with
184 additions
and
179 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
module AdvancedPS | ||
|
||
# Write your package code here. | ||
import Distributions | ||
|
||
include("resampling.jl") | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
#### | ||
#### Resampling schemes for particle filters | ||
#### | ||
|
||
# Some references | ||
# - http://arxiv.org/pdf/1301.4019.pdf | ||
# - http://people.isy.liu.se/rt/schon/Publications/HolSG2006.pdf | ||
# Code adapted from: http://uk.mathworks.com/matlabcentral/fileexchange/24968-resampling-methods-for-particle-filtering | ||
|
||
struct ResampleWithESSThreshold{R, T<:Real} | ||
resampler::R | ||
threshold::T | ||
end | ||
|
||
function ResampleWithESSThreshold(resampler = resample) | ||
ResampleWithESSThreshold(resampler, 0.5) | ||
end | ||
|
||
# Default resampling scheme | ||
function resample(w::AbstractVector{<:Real}, num_particles::Integer=length(w)) | ||
return resample_systematic(w, num_particles) | ||
end | ||
|
||
# More stable, faster version of rand(Categorical) | ||
function randcat(p::AbstractVector{<:Real}) | ||
T = eltype(p) | ||
r = rand(T) | ||
s = 1 | ||
for j in eachindex(p) | ||
r -= p[j] | ||
if r <= zero(T) | ||
s = j | ||
break | ||
end | ||
end | ||
return s | ||
end | ||
|
||
function resample_multinomial(w::AbstractVector{<:Real}, num_particles::Integer) | ||
return rand(Distributions.sampler(Distributions.Categorical(w)), num_particles) | ||
end | ||
|
||
function resample_residual(w::AbstractVector{<:Real}, num_particles::Integer) | ||
# Pre-allocate array for resampled particles | ||
indices = Vector{Int}(undef, num_particles) | ||
|
||
# deterministic assignment | ||
residuals = similar(w) | ||
i = 1 | ||
@inbounds for j in 1:length(w) | ||
x = num_particles * w[j] | ||
floor_x = floor(Int, x) | ||
for k in 1:floor_x | ||
indices[i] = j | ||
i += 1 | ||
end | ||
residuals[j] = x - floor_x | ||
end | ||
|
||
# sampling from residuals | ||
if i <= num_particles | ||
residuals ./= sum(residuals) | ||
rand!(Distributions.Categorical(residuals), view(indices, i:num_particles)) | ||
end | ||
|
||
return indices | ||
end | ||
|
||
|
||
""" | ||
resample_stratified(weights, n) | ||
Return a vector of `n` samples `x₁`, ..., `xₙ` from the numbers 1, ..., `length(weights)`, | ||
generated by stratified resampling. | ||
In stratified resampling `n` ordered random numbers `u₁`, ..., `uₙ` are generated, where | ||
``uₖ \\sim U[(k - 1) / n, k / n)``. Based on these numbers the samples `x₁`, ..., `xₙ` | ||
are selected according to the multinomial distribution defined by the normalized `weights`, | ||
i.e., `xᵢ = j` if and only if | ||
``uᵢ \\in [\\sum_{s=1}^{j-1} weights_{s}, \\sum_{s=1}^{j} weights_{s})``. | ||
""" | ||
function resample_stratified(weights::AbstractVector{<:Real}, n::Integer) | ||
# check input | ||
m = length(weights) | ||
m > 0 || error("weight vector is empty") | ||
|
||
# pre-calculations | ||
@inbounds v = n * weights[1] | ||
|
||
# generate all samples | ||
samples = Array{Int}(undef, n) | ||
sample = 1 | ||
@inbounds for i in 1:n | ||
# sample next `u` (scaled by `n`) | ||
u = oftype(v, i - 1 + rand()) | ||
|
||
# as long as we have not found the next sample | ||
while v < u | ||
# increase and check the sample | ||
sample += 1 | ||
sample > m && | ||
error("sample could not be selected (are the weights normalized?)") | ||
|
||
# update the cumulative sum of weights (scaled by `n`) | ||
v += n * weights[sample] | ||
end | ||
|
||
# save the next sample | ||
samples[i] = sample | ||
end | ||
|
||
return samples | ||
end | ||
|
||
""" | ||
resample_systematic(weights, n) | ||
Return a vector of `n` samples `x₁`, ..., `xₙ` from the numbers 1, ..., `length(weights)`, | ||
generated by systematic resampling. | ||
In systematic resampling a random number ``u \\sim U[0, 1)`` is used to generate `n` ordered | ||
numbers `u₁`, ..., `uₙ` where ``uₖ = (u + k − 1) / n``. Based on these numbers the samples | ||
`x₁`, ..., `xₙ` are selected according to the multinomial distribution defined by the | ||
normalized `weights`, i.e., `xᵢ = j` if and only if | ||
``uᵢ \\in [\\sum_{s=1}^{j-1} weights_{s}, \\sum_{s=1}^{j} weights_{s})``. | ||
""" | ||
function resample_systematic(weights::AbstractVector{<:Real}, n::Integer) | ||
# check input | ||
m = length(weights) | ||
m > 0 || error("weight vector is empty") | ||
|
||
# pre-calculations | ||
@inbounds v = n * weights[1] | ||
u = oftype(v, rand()) | ||
|
||
# find all samples | ||
samples = Array{Int}(undef, n) | ||
sample = 1 | ||
@inbounds for i in 1:n | ||
# as long as we have not found the next sample | ||
while v < u | ||
# increase and check the sample | ||
sample += 1 | ||
sample > m && | ||
error("sample could not be selected (are the weights normalized?)") | ||
|
||
# update the cumulative sum of weights (scaled by `n`) | ||
v += n * weights[sample] | ||
end | ||
|
||
# save the next sample | ||
samples[i] = sample | ||
|
||
# update `u` | ||
u += one(u) | ||
end | ||
|
||
return samples | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
@testset "resampling.jl" begin | ||
D = [0.3, 0.4, 0.3] | ||
num_samples = Int(1e6) | ||
|
||
resSystematic = AdvancedPS.resample_systematic(D, num_samples ) | ||
resStratified = AdvancedPS.resample_stratified(D, num_samples ) | ||
resMultinomial= AdvancedPS.resample_multinomial(D, num_samples ) | ||
resResidual = AdvancedPS.resample_residual(D, num_samples ) | ||
AdvancedPS.resample(D) | ||
resSystematic2= AdvancedPS.resample(D, num_samples ) | ||
|
||
@test sum(resSystematic .== 2) ≈ (num_samples * 0.4) atol=1e-3*num_samples | ||
@test sum(resSystematic2 .== 2) ≈ (num_samples * 0.4) atol=1e-3*num_samples | ||
@test sum(resStratified .== 2) ≈ (num_samples * 0.4) atol=1e-3*num_samples | ||
@test sum(resMultinomial .== 2) ≈ (num_samples * 0.4) atol=1e-2*num_samples | ||
@test sum(resResidual .== 2) ≈ (num_samples * 0.4) atol=1e-2*num_samples | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters