Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Half-normal node #338

Merged
merged 7 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ReactiveMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ include("nodes/bifm.jl")
include("nodes/bifm_helper.jl")
include("nodes/probit.jl")
include("nodes/poisson.jl")
include("nodes/half_normal.jl")

include("nodes/flow/flow.jl")
include("nodes/delta/delta.jl")
Expand Down
22 changes: 22 additions & 0 deletions src/distributions/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,28 @@
return loggamma(ay) - loggamma(ax) - loggamma(az) + ax * log(bx) + az * log(bz) - ay * log(by)
end

prod_analytical_rule(::Type{<:Truncated{<:Normal}}, ::Type{<:GammaDistributionsFamily}) = ProdAnalyticalRuleAvailable()
prod_analytical_rule(::Type{<:GammaDistributionsFamily}, ::Type{<:Truncated{<:Normal}}) = ProdAnalyticalRuleAvailable()

Check warning on line 82 in src/distributions/gamma.jl

View check run for this annotation

Codecov / codecov/patch

src/distributions/gamma.jl#L81-L82

Added lines #L81 - L82 were not covered by tests

prod(::ProdAnalytical, left::GammaDistributionsFamily, right::Truncated{<:Normal}) = prod(ProdAnalytical(), right, left)

function prod(::ProdAnalytical, left::Truncated{<:Normal}, right::GammaDistributionsFamily)
@assert (left.lower ≈ zero(left.lower) && isinf(left.upper)) "Truncated{Normal} * Gamma only implemented for Truncated{Normal}(0, Inf)"

samples = rand(MersenneTwister(123), left, 1000)
zeronum = zero(eltype(samples))

sx, xlogx, tw = mapreduce(.+, samples; init = (zeronum, zeronum, zeronum)) do sample
w = pdf(right, sample)
return (w * sample, w * log(sample), w)
end

statistics = Distributions.GammaStats(sx, xlogx, tw)
fit = Distributions.fit_mle(Gamma, statistics, alpha0 = shape(right))

return convert(typeof(right), fit)
end

## Friendly functions

function logpdf_sample_friendly(dist::GammaDistributionsFamily)
Expand Down
12 changes: 12 additions & 0 deletions src/distributions/gamma_inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,15 @@
θ = scale(dist)
return α / θ
end

prod_analytical_rule(::Type{<:Truncated{<:Normal}}, ::Type{<:GammaInverse}) = ProdAnalyticalRuleAvailable()
prod_analytical_rule(::Type{<:GammaInverse}, ::Type{<:Truncated{<:Normal}}) = ProdAnalyticalRuleAvailable()

Check warning on line 31 in src/distributions/gamma_inverse.jl

View check run for this annotation

Codecov / codecov/patch

src/distributions/gamma_inverse.jl#L30-L31

Added lines #L30 - L31 were not covered by tests

prod(::ProdAnalytical, left::GammaInverse, right::Truncated{<:Normal}) = prod(ProdAnalytical(), right, left)

function prod(::ProdAnalytical, left::Truncated{<:Normal}, right::GammaInverse)
@assert (left.lower ≈ zero(left.lower) && isinf(left.upper)) "Truncated{Normal} * GammaInverse only implemented for Truncated{Normal}(0, Inf)"
α, θ = shape(right), scale(right)
Γ = prod(ProdAnalytical(), left, Gamma(α, inv(θ)))
return InverseGamma(shape(Γ), inv(scale(Γ)))
end
5 changes: 5 additions & 0 deletions src/distributions/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@
"""An alias for the [`JointNormal`](@ref)."""
const JointGaussian = JointNormal

# Half-Normal related
function convert_paramfloattype(::Type{T}, distribution::Truncated{<:Normal}) where {T}
return Truncated(convert_paramfloattype(T, distribution.untruncated), convert(T, distribution.lower), convert(T, distribution.upper))

Check warning on line 189 in src/distributions/normal.jl

View check run for this annotation

Codecov / codecov/patch

src/distributions/normal.jl#L188-L189

Added lines #L188 - L189 were not covered by tests
end

# Variate forms promotion

promote_variate_type(::Type{Univariate}, ::Type{F}) where {F <: UnivariateNormalDistributionsFamily} = F
Expand Down
10 changes: 10 additions & 0 deletions src/nodes/half_normal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
export HalfNormal

struct HalfNormal end

@node HalfNormal Stochastic [out, (v, aliases = [var, σ²])]

@average_energy HalfNormal (q_out::Any, q_v::Any) = begin
out_mean, out_var = mean_var(q_out)
return (log(π / 2) + mean(log, q_v) + mean(inv, q_v) * (out_mean^2 + out_var)) / 2
end
4 changes: 4 additions & 0 deletions src/rules/half_normal/out.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
@rule HalfNormal(:out, Marginalisation) (q_v::PointMass,) = begin
mean_v = mean(q_v)
return Truncated(Normal(zero(eltype(q_v)), sqrt(mean_v)), zero(eltype(q_v)), typemax(float(mean_v)))
end
2 changes: 2 additions & 0 deletions src/rules/prototypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,5 @@ include("delta/unscented/marginals.jl")
include("delta/cvi/in.jl")
include("delta/cvi/out.jl")
include("delta/cvi/marginals.jl")

include("half_normal/out.jl")
11 changes: 11 additions & 0 deletions test/distributions/test_gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,17 @@ import ReactiveMP: xtlog
@test prod(ProdAnalytical(), GammaShapeRate(1, 2), GammaShapeScale(1, 2)) == GammaShapeRate(1, 5 / 2)
@test prod(ProdAnalytical(), GammaShapeRate(2, 2), GammaShapeScale(1, 2)) == GammaShapeRate(2, 5 / 2)
@test prod(ProdAnalytical(), GammaShapeRate(2, 2), GammaShapeScale(2, 2)) == GammaShapeRate(3, 5 / 2)

@test_throws AssertionError prod(ProdAnalytical(), GammaShapeRate(1, 1), Truncated(Normal(0.0, 1.0), -1.0, 1.0))
@test_throws AssertionError prod(ProdAnalytical(), Truncated(Normal(0.0, 1.0), -1.0, 1.0), GammaShapeRate(1, 1))
@test_throws AssertionError prod(ProdAnalytical(), GammaShapeScale(1, 1), Truncated(Normal(0.0, 1.0), -1.0, 1.0))
@test_throws AssertionError prod(ProdAnalytical(), Truncated(Normal(0.0, 1.0), -1.0, 1.0), GammaShapeScale(1, 1))

# TODO: these tests should check also check the actual result
@test prod(ProdAnalytical(), GammaShapeRate(1, 1), Truncated(Normal(0.0, 1.0), 0.0, Inf)) isa GammaShapeRate
@test prod(ProdAnalytical(), Truncated(Normal(0.0, 1.0), 0.0, Inf), GammaShapeRate(1, 1)) isa GammaShapeRate
@test prod(ProdAnalytical(), GammaShapeScale(1, 1), Truncated(Normal(0.0, 1.0), 0.0, Inf)) isa GammaShapeScale
@test prod(ProdAnalytical(), Truncated(Normal(0.0, 1.0), 0.0, Inf), GammaShapeScale(1, 1)) isa GammaShapeScale
end
end

Expand Down
7 changes: 7 additions & 0 deletions test/distributions/test_gamma_inverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ using Random
@test prod(ProdAnalytical(), GammaInverse(3.0, 2.0), GammaInverse(2.0, 1.0)) ≈ GammaInverse(6.0, 3.0)
@test prod(ProdAnalytical(), GammaInverse(7.0, 1.0), GammaInverse(0.1, 4.5)) ≈ GammaInverse(8.1, 5.5)
@test prod(ProdAnalytical(), GammaInverse(1.0, 3.0), GammaInverse(0.2, 0.4)) ≈ GammaInverse(2.2, 3.4)

@test_throws AssertionError prod(ProdAnalytical(), GammaInverse(1, 1), Truncated(Normal(0.0, 1.0), -1.0, 1.0))
@test_throws AssertionError prod(ProdAnalytical(), Truncated(Normal(0.0, 1.0), -1.0, 1.0), GammaInverse(1, 1))

# TODO: these tests should check also check the actual result
@test prod(ProdAnalytical(), GammaInverse(1, 1), Truncated(Normal(0.0, 1.0), 0.0, Inf)) isa GammaInverse
@test prod(ProdAnalytical(), Truncated(Normal(0.0, 1.0), 0.0, Inf), GammaInverse(1, 1)) isa GammaInverse
end

# log(θ) - digamma(α)
Expand Down
6 changes: 3 additions & 3 deletions test/nodes/test_gamma_inverse.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
module InverseWishartNodeTest
module InverseGammaNodeTest

using Test
using ReactiveMP
using Random

import ReactiveMP: make_node

@testset "InverseWishartNode" begin
@testset "InverseGammaNode" begin
@testset "Creation" begin
node = make_node(GammaInverse)
@test functionalform(node) === GammaInverse
Expand All @@ -16,7 +16,7 @@ import ReactiveMP: make_node
@test localmarginalnames(node) === (:out_α_θ,)
@test metadata(node) === nothing

node = make_node(InverseWishart, FactorNodeCreationOptions(nothing, 1, nothing))
node = make_node(GammaInverse, FactorNodeCreationOptions(nothing, 1, nothing))
@test metadata(node) === 1
end

Expand Down
51 changes: 51 additions & 0 deletions test/nodes/test_half_normal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
module HalfNormalNodeTest

using Test
using ReactiveMP
using Random

import ReactiveMP: make_node

@testset "HalfNormalNode" begin
@testset "Creation" begin
node = make_node(HalfNormal)
@test functionalform(node) === HalfNormal
@test sdtype(node) === Stochastic()
@test name.(interfaces(node)) === (:out, :v)
@test factorisation(node) === ((1, 2),)
@test localmarginalnames(node) === (:out_v,)
@test metadata(node) === nothing

node = make_node(HalfNormal, FactorNodeCreationOptions(nothing, 1, nothing))
@test metadata(node) === 1
end

@testset "AverageEnergy" begin
begin
q_out = GammaShapeRate(2.0, 1.0)
q_v = PointMass(2.0)

marginals = (Marginal(q_out, false, false, nothing), Marginal(q_v, false, false, nothing))

@test score(AverageEnergy(), HalfNormal, Val{(:out, :v)}(), marginals, nothing) ≈ 2.072364942925
end
begin
q_out = GammaShapeScale(2.0, 1.0)
q_v = PointMass(2.0)

marginals = (Marginal(q_out, false, false, nothing), Marginal(q_v, false, false, nothing))

@test score(AverageEnergy(), HalfNormal, Val{(:out, :v)}(), marginals, nothing) ≈ 2.072364942925
end

begin
q_out = GammaInverse(3.0, 1.0)
q_v = PointMass(2.0)

marginals = (Marginal(q_out, false, false, nothing), Marginal(q_v, false, false, nothing))

@test score(AverageEnergy(), HalfNormal, Val{(:out, :v)}(), marginals, nothing) ≈ 0.6973649429247
end
end # testset: AverageEnergy
end # testset
end # module
19 changes: 19 additions & 0 deletions test/rules/half_normal/test_out.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module RulesHalfNormalOutTest

using Test
using ReactiveMP
using Random
using Distributions

import ReactiveMP: @test_rules

@testset "rules:HalfNormal:out" begin
@testset "Variational Message Passing: (q_v::Any)" begin
@test_rules [check_type_promotion = true] HalfNormal(:out, Marginalisation) [
(input = (q_v = PointMass(1.0),), output = Truncated(Normal(0.0, 1.0), 0.0, Inf)),
(input = (q_v = PointMass(100.0),), output = Truncated(Normal(0.0, 10.0), 0.0, Inf)),
(input = (q_v = PointMass(100),), output = Truncated(Normal(0.0, 10.0), 0.0, Inf))
]
end
end # testset
end # module
Loading