Skip to content

Commit

Permalink
Merge pull request #313 from simonbyrne/rgamma
Browse files Browse the repository at this point in the history
reconfigure samplers, new gamma sampler
  • Loading branch information
lindahua committed Nov 20, 2014
2 parents 4c55ac1 + ed7b00d commit f10f3dd
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 57 deletions.
22 changes: 15 additions & 7 deletions perf/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,23 +87,30 @@ benchmark_exponential() = (

## gamma

import Distributions: GammaRmathSampler, GammaMTSampler
import Distributions: GammaRmathSampler, GammaGDSampler, GammaGSSampler,
GammaMTSampler, GammaIPSampler

getname(::Type{GammaRmathSampler}) = "rmath"
getname(::Type{GammaGDSampler}) = "GD"
getname(::Type{GammaGSSampler}) = "GS"
getname(::Type{GammaMTSampler}) = "MT"
getname(::Type{GammaIPSampler}) = "IP"

benchmark_gamma() = (
make_procs(GammaRmathSampler, GammaMTSampler),
"(α, scale)", [(α, 1.0) for α in [0.5, 1.0, 2.0, 5.0, 20.0]])
benchmark_gamma_hi() = (
make_procs(GammaRmathSampler, GammaMTSampler, GammaGDSampler),
"Dist", [(Gamma(α, 1.0),) for α in [1.5, 2.0, 3.0, 5.0, 20.0]])

benchmark_gamma_lo() = (
make_procs(GammaRmathSampler, GammaGSSampler, GammaIPSampler),
"Dist", [(Gamma(α, 1.0),) for α in [0.1, 0.5, 0.9]])

### main

const dnames = ["categorical",
"binomial",
"poisson",
"exponential",
"gamma"]
"gamma_hi","gamma_lo"]

function printhelp()
println("Require exactly one argument. Usage:")
Expand Down Expand Up @@ -141,10 +148,11 @@ function do_benchmark(dname; verbose::Int=2)
dname == "binomial" ? benchmark_binomial() :
dname == "poisson" ? benchmark_poisson() :
dname == "exponential" ? benchmark_exponential() :
dname == "gamma" ? benchmark_gamma() :
dname == "gamma_hi" ? benchmark_gamma_hi() :
dname == "gamma_lo" ? benchmark_gamma_lo() :
error("benchmarking function for $dname has not been implemented.")

r = run(procs, cfgs; duration=0.2, verbose=verbose)
r = run(procs, cfgs; duration=0.5, verbose=verbose)
println()
show(r; unit=:mps, cfghead=cfghead)
end
Expand Down
2 changes: 1 addition & 1 deletion src/Distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,11 @@ include("functionals.jl")
include("genericfit.jl")

# specific samplers and distributions
include("samplers.jl")
include("univariates.jl")
include("empirical.jl")
include("multivariates.jl")
include("matrixvariates.jl")
include("samplers.jl")

# others
include("truncate.jl")
Expand Down
137 changes: 94 additions & 43 deletions src/samplers/gamma.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,24 @@

immutable GammaRmathSampler <: Sampleable{Univariate,Continuous}
α::Float64
scale::Float64
d::Gamma
end

rand(s::GammaRmathSampler) =
ccall((:rgamma, "libRmath-julia"), Float64, (Float64, Float64), s.α, s.scale)

rand(s::GammaRmathSampler) =
ccall((:rgamma, "libRmath-julia"), Float64, (Float64, Float64), s.d.shape, s.d.scale)


# "Generating gamma variates by a modified rejection technique"
# J.H. Ahrens, U. Dieter
# Communications of the ACM, Vol 25(1), 1982, pp 47-54
# doi:10.1145/358315.358390

# suitable for scale >= 1.0
# suitable for shape >= 1.0

immutable GammaGDSampler <: Sampleable{Univariate,Continuous}
a::Float64
s2::Float64
s::Float64
i2s::Float64
d::Float64
q0::Float64
b::Float64
Expand All @@ -28,10 +27,13 @@ immutable GammaGDSampler <: Sampleable{Univariate,Continuous}
scale::Float64
end

function GammaGDSampler(a::Float64,scale::Float64)
function GammaGDSampler(g::Gamma)
a = g.shape

# Step 1
s2 = a-0.5
s = sqrt(s2)
i2s = 0.5/s
d = 5.656854249492381 - 12.0s # 4*sqrt(2) - 12s

# Step 4
Expand Down Expand Up @@ -61,7 +63,7 @@ function GammaGDSampler(a::Float64,scale::Float64)
c = 0.1515/s
end

GammaGDSampler(a,s2,s,d,q0,b,σ,c,scale)
GammaGDSampler(a,s2,s,i2s,d,q0,b,σ,c,g.scale)
end

function rand(s::GammaGDSampler)
Expand All @@ -77,7 +79,7 @@ function rand(s::GammaGDSampler)
# Step 5
if x > 0.0
# Step 6
v = t/(2.0*s.s)
v = t*s.i2s
if abs(v) > 0.25
q = s.q0 - s.s*t + 0.25*t*t + 2.0*s.s2*log1p(v)
else
Expand Down Expand Up @@ -107,7 +109,7 @@ function rand(s::GammaGDSampler)
t < -0.718_744_837_717_19 && @goto step8

# Step 10
v = t/(2.0*s.s)
v = t*s.i2s
if abs(v) > 0.25
q = s.q0 - s.s*t + 0.25*t*t + 2.0*s.s2*log1p(v)
else
Expand All @@ -131,61 +133,110 @@ function rand(s::GammaGDSampler)
return x*x*s.scale
end

# A simple method for generating gamma variables - Marsaglia and Tsang (2000)
# http://www.cparity.com/projects/AcmClassification/samples/358414.pdf
# Page 369
# basic simulation loop for pre-computed d and c
#
# "Computer methods for sampling from gamma, beta, poisson and bionomial distributions"
# J.H. Ahrens and U. Dieter
# Computing, 1974, Volume 12(3), pp 223-246
# doi:10.1007/BF02293108

# valid for 0 < shape <= 1
immutable GammaGSSampler <: Sampleable{Univariate,Continuous}
a::Float64
ia::Float64
b::Float64
scale::Float64
end

function GammaGSSampler(d::Gamma)
a = d.shape
ia = 1/d.shape
b = 1.0+0.36787944117144233*d.shape
GammaGSSampler(a,ia,b,d.scale)
end

function rand(s::GammaGSSampler)
while true
# step 1
p = s.b*rand()
e = Base.Random.randmtzig_exprnd()
if p <= 1.0
# step 2
x = exp(log(p)*s.ia)
e < x || return s.scale*x
else
# step 3
x = -log(s.ia*(s.b-p))
e < log(x)*(1.0-s.a) || return s.scale*x
end
end
end


# "A simple method for generating gamma variables"
# G. Marsaglia and W.W. Tsang
# ACM Transactions on Mathematical Software (TOMS), 2000, Volume 26(3), pp. 363-372
# doi:10.1145/358407.358414
# http://www.cparity.com/projects/AcmClassification/samples/358414.pdf

immutable GammaMTSampler <: Sampleable{Univariate,Continuous}
::Float64
d::Float64
c::Float64
κ::Float64
end

function GammaMTSampler::Float64, scale::Float64)
if α >= 1.0
= 1.0
d = α - 1/3
else
= 1.0 / α
d = α + 2/3
end
function GammaMTSampler(g::Gamma)
d = g.shape - 1/3
c = 1.0 / sqrt(9.0 * d)
κ = d * scale
GammaMTSampler(iα, d, c, κ)
κ = d * g.scale
GammaMTSampler(d, c, κ)
end

GammaMTSampler::Float64) = GammaMTSampler(α, 1.0)

function rand(s::GammaMTSampler)
d = s.d
c = s.c
= s.

v = 0.0
while true
x = randn()
v = 1.0 + c * x
v = 1.0 + s.c * x
while v <= 0.0
x = randn()
v = 1.0 + c * x
v = 1.0 + s.c * x
end
v *= (v * v)
u = rand()
x2 = x * x
if u < 1.0 - 0.331 * abs2(x2)
break
return v*s.κ
end
if log(u) < 0.5 * x2 + d * (1.0 - v + log(v))
break
if log(u) < 0.5 * x2 + s.d * (1.0 - v + log(v)) # logmxp1
return v*s.κ
end
end
v *= s.κ
if> 1.0
v *= (rand() ^ iα)
end
return v
end

# Inverse Power sampler
# uses the x*u^(1/a) trick from Marsaglia and Tsang (2000) for when shape < 1
immutable GammaIPSampler{S<:Sampleable{Univariate,Continuous}} <: Sampleable{Univariate,Continuous}
s::S #sampler for Gamma(1+shape,scale)
nia::Float64 #-1/scale
end

function GammaIPSampler{S<:Sampleable}(d::Gamma,::Type{S})
GammaIPSampler(Gamma(1.0+d.shape,d.scale), -1.0/d.shape)
end
GammaIPSampler(d::Gamma) = GammaIPSampler(d,GammaMTSampler)

function rand(s::GammaIPSampler)
x = rand(s.s)
e = Base.Random.randmtzig_exprnd()
x*exp(s.nia*e)
end

# function sampler(d::Gamma)
# if d.shape < 1.0
# # TODO: d.shape = 0.5 : use scaled chisq
# GammaIPSampler(d)
# elseif d.shape == 1.0
# Exponential(d.scale)
# else
# GammaGDSampler(d)
# end
# end

# rand(d::Gamma) = rand(sampler(d))
23 changes: 17 additions & 6 deletions test/samplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@ import Distributions:
PoissonADSampler,
PoissonCountSampler,
ExponentialSampler,
GammaMTSampler
GammaGDSampler,
GammaGSSampler,
GammaMTSampler,
GammaIPSampler

n_tsamples = 10^6

Expand Down Expand Up @@ -74,11 +77,19 @@ end


## Gamma samplers

for S in [GammaMTSampler]
for pa in [(1.0, 1.0), (2.0, 1.0), (3.0, 1.0), (0.5, 1.0),
(1.0, 2.0), (3.0, 2.0), (0.5, 2.0)]
test_samples(S(pa...), Gamma(pa...), n_tsamples)
# shape >= 1
for S in [GammaGDSampler, GammaMTSampler]
println(" testing $S")
for d in [Gamma(1.0, 1.0), Gamma(2.0, 1.0), Gamma(3.0, 1.0),
Gamma(1.0, 2.0), Gamma(3.0, 2.0), Gamma(100.0, 2.0)]
test_samples(S(d), d, n_tsamples)
end
end

# shape < 1
for S in [GammaGSSampler, GammaIPSampler]
println(" testing $S")
for d in [Gamma(0.1,1.0),Gamma(0.9,1.0)]
test_samples(S(d), d, n_tsamples)
end
end

0 comments on commit f10f3dd

Please sign in to comment.