Skip to content

Commit

Permalink
Merge pull request #338 from biaslab/half_normal
Browse files Browse the repository at this point in the history
Add Half-normal node
  • Loading branch information
bvdmitri authored Jul 12, 2023
2 parents 8f43259 + a234d1b commit 7b81d41
Show file tree
Hide file tree
Showing 12 changed files with 147 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/ReactiveMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ include("nodes/bifm.jl")
include("nodes/bifm_helper.jl")
include("nodes/probit.jl")
include("nodes/poisson.jl")
include("nodes/half_normal.jl")

include("nodes/flow/flow.jl")
include("nodes/delta/delta.jl")
Expand Down
22 changes: 22 additions & 0 deletions src/distributions/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,28 @@ function compute_logscale(new_dist::GammaDistributionsFamily, left_dist::GammaDi
return loggamma(ay) - loggamma(ax) - loggamma(az) + ax * log(bx) + az * log(bz) - ay * log(by)
end

prod_analytical_rule(::Type{<:Truncated{<:Normal}}, ::Type{<:GammaDistributionsFamily}) = ProdAnalyticalRuleAvailable()
prod_analytical_rule(::Type{<:GammaDistributionsFamily}, ::Type{<:Truncated{<:Normal}}) = ProdAnalyticalRuleAvailable()

prod(::ProdAnalytical, left::GammaDistributionsFamily, right::Truncated{<:Normal}) = prod(ProdAnalytical(), right, left)

function prod(::ProdAnalytical, left::Truncated{<:Normal}, right::GammaDistributionsFamily)
@assert (left.lower zero(left.lower) && isinf(left.upper)) "Truncated{Normal} * Gamma only implemented for Truncated{Normal}(0, Inf)"

samples = rand(MersenneTwister(123), left, 1000)
zeronum = zero(eltype(samples))

sx, xlogx, tw = mapreduce(.+, samples; init = (zeronum, zeronum, zeronum)) do sample
w = pdf(right, sample)
return (w * sample, w * log(sample), w)
end

statistics = Distributions.GammaStats(sx, xlogx, tw)
fit = Distributions.fit_mle(Gamma, statistics, alpha0 = shape(right))

return convert(typeof(right), fit)
end

## Friendly functions

function logpdf_sample_friendly(dist::GammaDistributionsFamily)
Expand Down
12 changes: 12 additions & 0 deletions src/distributions/gamma_inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,15 @@ function mean(::typeof(inv), dist::GammaInverse)
θ = scale(dist)
return α / θ
end

prod_analytical_rule(::Type{<:Truncated{<:Normal}}, ::Type{<:GammaInverse}) = ProdAnalyticalRuleAvailable()
prod_analytical_rule(::Type{<:GammaInverse}, ::Type{<:Truncated{<:Normal}}) = ProdAnalyticalRuleAvailable()

prod(::ProdAnalytical, left::GammaInverse, right::Truncated{<:Normal}) = prod(ProdAnalytical(), right, left)

function prod(::ProdAnalytical, left::Truncated{<:Normal}, right::GammaInverse)
@assert (left.lower zero(left.lower) && isinf(left.upper)) "Truncated{Normal} * GammaInverse only implemented for Truncated{Normal}(0, Inf)"
α, θ = shape(right), scale(right)
Γ = prod(ProdAnalytical(), left, Gamma(α, inv(θ)))
return InverseGamma(shape(Γ), inv(scale(Γ)))
end
5 changes: 5 additions & 0 deletions src/distributions/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ Base.isapprox(left::JointNormal, right::JointNormal; kwargs...) = isapprox(left.
"""An alias for the [`JointNormal`](@ref)."""
const JointGaussian = JointNormal

# Half-Normal related
function convert_paramfloattype(::Type{T}, distribution::Truncated{<:Normal}) where {T}
return Truncated(convert_paramfloattype(T, distribution.untruncated), convert(T, distribution.lower), convert(T, distribution.upper))
end

# Variate forms promotion

promote_variate_type(::Type{Univariate}, ::Type{F}) where {F <: UnivariateNormalDistributionsFamily} = F
Expand Down
10 changes: 10 additions & 0 deletions src/nodes/half_normal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
export HalfNormal

struct HalfNormal end

@node HalfNormal Stochastic [out, (v, aliases = [var, σ²])]

@average_energy HalfNormal (q_out::Any, q_v::Any) = begin
out_mean, out_var = mean_var(q_out)
return (log/ 2) + mean(log, q_v) + mean(inv, q_v) * (out_mean^2 + out_var)) / 2
end
4 changes: 4 additions & 0 deletions src/rules/half_normal/out.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
@rule HalfNormal(:out, Marginalisation) (q_v::PointMass,) = begin
mean_v = mean(q_v)
return Truncated(Normal(zero(eltype(q_v)), sqrt(mean_v)), zero(eltype(q_v)), typemax(float(mean_v)))
end
2 changes: 2 additions & 0 deletions src/rules/prototypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,5 @@ include("delta/unscented/marginals.jl")
include("delta/cvi/in.jl")
include("delta/cvi/out.jl")
include("delta/cvi/marginals.jl")

include("half_normal/out.jl")
11 changes: 11 additions & 0 deletions test/distributions/test_gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,17 @@ import ReactiveMP: xtlog
@test prod(ProdAnalytical(), GammaShapeRate(1, 2), GammaShapeScale(1, 2)) == GammaShapeRate(1, 5 / 2)
@test prod(ProdAnalytical(), GammaShapeRate(2, 2), GammaShapeScale(1, 2)) == GammaShapeRate(2, 5 / 2)
@test prod(ProdAnalytical(), GammaShapeRate(2, 2), GammaShapeScale(2, 2)) == GammaShapeRate(3, 5 / 2)

@test_throws AssertionError prod(ProdAnalytical(), GammaShapeRate(1, 1), Truncated(Normal(0.0, 1.0), -1.0, 1.0))
@test_throws AssertionError prod(ProdAnalytical(), Truncated(Normal(0.0, 1.0), -1.0, 1.0), GammaShapeRate(1, 1))
@test_throws AssertionError prod(ProdAnalytical(), GammaShapeScale(1, 1), Truncated(Normal(0.0, 1.0), -1.0, 1.0))
@test_throws AssertionError prod(ProdAnalytical(), Truncated(Normal(0.0, 1.0), -1.0, 1.0), GammaShapeScale(1, 1))

# TODO: these tests should check also check the actual result
@test prod(ProdAnalytical(), GammaShapeRate(1, 1), Truncated(Normal(0.0, 1.0), 0.0, Inf)) isa GammaShapeRate
@test prod(ProdAnalytical(), Truncated(Normal(0.0, 1.0), 0.0, Inf), GammaShapeRate(1, 1)) isa GammaShapeRate
@test prod(ProdAnalytical(), GammaShapeScale(1, 1), Truncated(Normal(0.0, 1.0), 0.0, Inf)) isa GammaShapeScale
@test prod(ProdAnalytical(), Truncated(Normal(0.0, 1.0), 0.0, Inf), GammaShapeScale(1, 1)) isa GammaShapeScale
end
end

Expand Down
7 changes: 7 additions & 0 deletions test/distributions/test_gamma_inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ using Random
@test prod(ProdAnalytical(), GammaInverse(3.0, 2.0), GammaInverse(2.0, 1.0)) GammaInverse(6.0, 3.0)
@test prod(ProdAnalytical(), GammaInverse(7.0, 1.0), GammaInverse(0.1, 4.5)) GammaInverse(8.1, 5.5)
@test prod(ProdAnalytical(), GammaInverse(1.0, 3.0), GammaInverse(0.2, 0.4)) GammaInverse(2.2, 3.4)

@test_throws AssertionError prod(ProdAnalytical(), GammaInverse(1, 1), Truncated(Normal(0.0, 1.0), -1.0, 1.0))
@test_throws AssertionError prod(ProdAnalytical(), Truncated(Normal(0.0, 1.0), -1.0, 1.0), GammaInverse(1, 1))

# TODO: these tests should check also check the actual result
@test prod(ProdAnalytical(), GammaInverse(1, 1), Truncated(Normal(0.0, 1.0), 0.0, Inf)) isa GammaInverse
@test prod(ProdAnalytical(), Truncated(Normal(0.0, 1.0), 0.0, Inf), GammaInverse(1, 1)) isa GammaInverse
end

# log(θ) - digamma(α)
Expand Down
6 changes: 3 additions & 3 deletions test/nodes/test_gamma_inverse.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
module InverseWishartNodeTest
module InverseGammaNodeTest

using Test
using ReactiveMP
using Random

import ReactiveMP: make_node

@testset "InverseWishartNode" begin
@testset "InverseGammaNode" begin
@testset "Creation" begin
node = make_node(GammaInverse)
@test functionalform(node) === GammaInverse
Expand All @@ -16,7 +16,7 @@ import ReactiveMP: make_node
@test localmarginalnames(node) === (:out_α_θ,)
@test metadata(node) === nothing

node = make_node(InverseWishart, FactorNodeCreationOptions(nothing, 1, nothing))
node = make_node(GammaInverse, FactorNodeCreationOptions(nothing, 1, nothing))
@test metadata(node) === 1
end

Expand Down
51 changes: 51 additions & 0 deletions test/nodes/test_half_normal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
module HalfNormalNodeTest

using Test
using ReactiveMP
using Random

import ReactiveMP: make_node

@testset "HalfNormalNode" begin
@testset "Creation" begin
node = make_node(HalfNormal)
@test functionalform(node) === HalfNormal
@test sdtype(node) === Stochastic()
@test name.(interfaces(node)) === (:out, :v)
@test factorisation(node) === ((1, 2),)
@test localmarginalnames(node) === (:out_v,)
@test metadata(node) === nothing

node = make_node(HalfNormal, FactorNodeCreationOptions(nothing, 1, nothing))
@test metadata(node) === 1
end

@testset "AverageEnergy" begin
begin
q_out = GammaShapeRate(2.0, 1.0)
q_v = PointMass(2.0)

marginals = (Marginal(q_out, false, false, nothing), Marginal(q_v, false, false, nothing))

@test score(AverageEnergy(), HalfNormal, Val{(:out, :v)}(), marginals, nothing) 2.072364942925
end
begin
q_out = GammaShapeScale(2.0, 1.0)
q_v = PointMass(2.0)

marginals = (Marginal(q_out, false, false, nothing), Marginal(q_v, false, false, nothing))

@test score(AverageEnergy(), HalfNormal, Val{(:out, :v)}(), marginals, nothing) 2.072364942925
end

begin
q_out = GammaInverse(3.0, 1.0)
q_v = PointMass(2.0)

marginals = (Marginal(q_out, false, false, nothing), Marginal(q_v, false, false, nothing))

@test score(AverageEnergy(), HalfNormal, Val{(:out, :v)}(), marginals, nothing) 0.6973649429247
end
end # testset: AverageEnergy
end # testset
end # module
19 changes: 19 additions & 0 deletions test/rules/half_normal/test_out.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module RulesHalfNormalOutTest

using Test
using ReactiveMP
using Random
using Distributions

import ReactiveMP: @test_rules

@testset "rules:HalfNormal:out" begin
@testset "Variational Message Passing: (q_v::Any)" begin
@test_rules [check_type_promotion = true] HalfNormal(:out, Marginalisation) [
(input = (q_v = PointMass(1.0),), output = Truncated(Normal(0.0, 1.0), 0.0, Inf)),
(input = (q_v = PointMass(100.0),), output = Truncated(Normal(0.0, 10.0), 0.0, Inf)),
(input = (q_v = PointMass(100),), output = Truncated(Normal(0.0, 10.0), 0.0, Inf))
]
end
end # testset
end # module

0 comments on commit 7b81d41

Please sign in to comment.