Skip to content

Commit

Permalink
fix sampling from GammaShapeRate
Browse files Browse the repository at this point in the history
  • Loading branch information
bvdmitri committed Jul 11, 2023
1 parent 5c9ca66 commit b77eeb1
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/distributions/gamma_shape_rate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ function Random.rand(rng::AbstractRNG, dist::GammaShapeRate)
return convert(eltype(dist), rand(rng, convert(GammaShapeScale, dist)))
end

# This method is identical to the next one, but we need it
# See issue https://github.com/biaslab/ReactiveMP.jl/issues/337
function Random.rand(rng::AbstractRNG, dist::GammaShapeRate, n::Int64)
return convert(AbstractArray{eltype(dist)}, rand(rng, convert(GammaShapeScale, dist), n))
end

function Random.rand(rng::AbstractRNG, dist::GammaShapeRate, n::Integer)
return convert(AbstractArray{eltype(dist)}, rand(rng, convert(GammaShapeScale, dist), n))
end
Expand Down
12 changes: 12 additions & 0 deletions test/distributions/test_gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Test
using ReactiveMP
using Random
using Distributions
using StableRNGs

import SpecialFunctions: loggamma
import ReactiveMP: xtlog
Expand Down Expand Up @@ -41,6 +42,17 @@ import ReactiveMP: xtlog
@test eltype(GammaShapeRate(1.0f0, 2.0f0)) === Float32
end

@testset "rand" begin
rng = StableRNG(42)
for shape in (1, 2), scale in (1, 2)
for dist in (GammaShapeScale(shape, scale), GammaShapeRate(shape, inv(scale)))
@test rand(rng, dist) > 0.0
@test all(>(0.0), rand(rng, dist, 10))
@test all(>(0.0), rand!(rng, dist, Vector{Float64}(undef, 10)))
end
end
end

@testset "vague" begin
vague(GammaShapeScale) == Gamma(1.0, ReactiveMP.huge)
vague(GammaShapeRate) == Gamma(1.0, ReactiveMP.tiny)
Expand Down

0 comments on commit b77eeb1

Please sign in to comment.