feat: Added the MultinomialDirichlet evidential distribution.
DoktorMike committed Jul 11, 2022
commit 7436f16
Showing 6 changed files with 277 additions and 111 deletions.
using EvidentialFlux
using Flux
using UnicodePlots

"""
    gendata(n)

Generate a toy two-class data set with `n` observations per class.

Returns a tuple `(X, y)` where `X` is a `2 × 2n` `Float32` matrix whose first
`n` columns are standard-normal samples and whose last `n` columns are
standard-normal samples shifted by `2`. `y` is a `2 × 2n` indicator matrix:
row 1 is `1` for the first `n` columns, row 2 is its complement.
"""
function gendata(n)
    # Sample directly in Float32 instead of converting Float64 draws afterwards.
    x1 = randn(Float32, 2, n)
    x2 = randn(Float32, 2, n) .+ 2
    labels = vcat(ones(Float32, n), zeros(Float32, n))
    # hcat gives a (2n, 2) matrix; the adjoint turns it into (classes, batch).
    return hcat(x1, x2), hcat(labels, 1 .- labels)'
end
# ---------------------------------------------------------------------------
# Example: fit a Dirichlet evidential classifier on the toy data and inspect it
# ---------------------------------------------------------------------------
n = 200
X, y = gendata(n)

# See the data: first class in green, second class in red
plt = scatterplot(X[1, 1:n], X[2, 1:n], color = :green, width = 80, height = 30)
scatterplot!(plt, X[1, (n+1):(n+n)], X[2, (n+1):(n+n)], color = :red)

# NOTE(review): the hidden Dense layer has no activation, so the network is
# affine up to the final softplus in DIR — confirm this is intended.
m = Chain(Dense(2 => 30), DIR(30 => 2))
opt = Flux.Optimise.AdamW(0.01)
# Named `ps` so the plot handle above is not overwritten by the parameters.
ps = Flux.params(m)

epochs = 500
trnlosses = zeros(epochs)
for e in 1:epochs
    local trnloss = 0
    grads = Flux.gradient(ps) do
        α = m(X)
        # The epoch index drives the annealing of the regularizer in dirloss.
        trnloss = dirloss(y, α, e)
        trnloss
    end
    trnlosses[e] = trnloss
    Flux.Optimise.update!(opt, ps, grads)
end
scatterplot(1:epochs, trnlosses, width = 80, height = 30)

# Dirichlet parameters, normalised class probabilities and uncertainty
α̂ = m(X)
p̂ = α̂ ./ sum(α̂, dims = 1)
u = uncertainty(α̂)

# Uncertainty surface over the input plane.
# NOTE(review): the lambda feeds the model `vcat(y, x)`, i.e. the arguments in
# swapped order — confirm this matches the axis convention of `contourplot`.
contourplot(-5:.01:5, -5:.01:5, (x, y) -> uncertainty(m(vcat(y,x)))[1])
Expand Up @@ -8,9 +8,11 @@ using SpecialFunctions

export NIG
export DIR

export nigloss
export dirloss

export uncertainty
Expand Up @@ -26,47 +26,84 @@ The same holds true for the `bias` vector.
- `bias`: Whether to include a trainable bias vector.
# Layer holding a weight matrix `W`, a bias `b` and an activation `σ`.
# The inner constructor lets `Flux.create_bias` build the bias from `b`, sized
# to the number of rows of `W`.
struct NIG{F,M<:AbstractMatrix,B}
    W::M
    b::B
    σ::F
    function NIG(W::M, b = true, σ::F = NNlib.softplus) where {M<:AbstractMatrix,F}
        b = Flux.create_bias(W, b, size(W, 1))
        return new{F,M,typeof(b)}(W, b, σ)
    end
end

# Convenience constructor: `NIG(in => out, σ; init, bias)`.
# The weight matrix has `4 * out` rows because the layer emits four parameter
# groups (γ, ν, α, β) per output.
function NIG((in, out)::Pair{<:Integer,<:Integer}, σ = NNlib.softplus;
             init = Flux.glorot_uniform, bias = true)
    return NIG(init(out * 4, in), bias, σ)
end

Flux.@functor NIG

# Forward pass for vector or matrix input.
# The affine output is split row-wise into four equally sized groups:
#   γ — passed through unchanged,
#   ν — σ applied,
#   α — σ applied and shifted by 1,
#   β — σ applied.
# The groups are stacked back in that order, so the result has the same number
# of rows as `a.W`.
function (a::NIG)(x::AbstractVecOrMat)
    # `Int(...)` throws if the row count of `W` is not a multiple of four.
    nout = Int(size(a.W, 1) / 4)
    o = a.W * x .+ a.b
    γ = o[1:nout, :]
    ν = a.σ.(o[(nout+1):(nout*2), :])
    α = a.σ.(o[(nout*2+1):(nout*3), :]) .+ 1
    β = a.σ.(o[(nout*3+1):(nout*4), :])
    return vcat(γ, ν, α, β)
end

# Higher-rank input: collapse all trailing dimensions into one batch dimension,
# run the matrix method, then restore the trailing dimensions on the output.
function (a::NIG)(x::AbstractArray)
    flattened = reshape(x, size(x, 1), :)
    output = a(flattened)
    return reshape(output, :, size(x)[2:end]...)
end

#function predict(m::NIG, x::AbstractVecOrMat)
# nout = Int(size(m.W, 1) / 4)
# o = m.W * x .+ m.b
# γ = o[1:nout, :]
# ν = o[(nout+1):(nout*2), :]
# ν = m.σ.(ν)
# α = o[(nout*2+1):(nout*3), :]
# α = m.σ.(α) .+ 1
# β = o[(nout*3+1):(nout*4), :]
# β = m.σ.(β)
# return γ, uncertainty(ν, α, β), uncertainty(α, β)
DIR(in => out; bias=true, init=Flux.glorot_uniform)
DIR(W::AbstractMatrix, [bias])
A linear layer with a softplus activation function at the end to implement the
Dirichlet evidential distribution. In this layer the number of output nodes
should correspond to the number of classes you wish to model. This layer should
be used to model a Multinomial likelihood with a Dirichlet prior. Thus the
posterior is also a Dirichlet distribution. Moreover the type II maximum
likelihood, i.e., the marginal likelihood is a Dirichlet-Multinomial
distribution. Create a fully connected layer which implements the Dirichlet
Evidential distribution whose forward pass is simply given by:
    y = softplus.(W * x .+ bias) .+ 1
The input `x` should be a vector of length `in`, or batch of vectors represented
as an `in × N` matrix, or any array with `size(x,1) == in`.
The output `y` will be a vector of length `out`, or a batch with
`size(y) == (out, size(x)[2:end]...)`.
The function `softplus` is applied element-wise to `W * x .+ bias` and `1` is added to the result.
Keyword `bias=false` will switch off trainable bias for the layer.
The initialisation of the weight matrix is `W = init(out, in)`, calling the function
given to keyword `init`, with default [`glorot_uniform`](@doc Flux.glorot_uniform).
The weight matrix and/or the bias vector (of length `out`) may also be provided explicitly.
# Arguments:
- `(in, out)`: number of input and output neurons
- `init`: The function to use to initialise the weight matrix.
- `bias`: Whether to include a trainable bias vector.
# Layer holding a weight matrix `W` and a bias `b`.
# The inner constructor lets `Flux.create_bias` build the bias from `b`, sized
# to the number of rows of `W`.
struct DIR{M<:AbstractMatrix,B}
    W::M
    b::B
    function DIR(W::M, b = true) where {M<:AbstractMatrix}
        b = Flux.create_bias(W, b, size(W, 1))
        return new{M,typeof(b)}(W, b)
    end
end

# Convenience constructor: `DIR(in => out; init, bias)` builds an `out × in`
# weight matrix with `init` and forwards `bias` to the matrix constructor.
function DIR((in, out)::Pair{<:Integer,<:Integer};
             init = Flux.glorot_uniform, bias = true)
    return DIR(init(out, in), bias)
end

Flux.@functor DIR

# Forward pass for vector or matrix input: softplus of the affine map, shifted
# by 1.
function (a::DIR)(x::AbstractVecOrMat)
    return NNlib.softplus.(a.W * x .+ a.b) .+ 1
end

# Higher-rank input: collapse all trailing dimensions into one batch dimension,
# run the matrix method, then restore the trailing dimensions on the output.
function (a::DIR)(x::AbstractArray)
    flattened = reshape(x, size(x, 1), :)
    output = a(flattened)
    return reshape(output, :, size(x)[2:end]...)
end
Expand Up @@ -15,21 +15,72 @@ function: μ and σ.
- `ϵ`: the threshold for the regularizer (default: 0.0001)
# Element-wise Normal-Inverse-Gamma evidential loss.
#
# - `y`: targets
# - `γ`, `ν`, `α`, `β`: parameters of the Normal-Inverse-Gamma distribution
# - `λ`: weight of the regularizer (default: 1)
# - `ϵ`: the threshold for the regularizer (default: 0.0001)
#
# Returns `nll + λ * (reg - ϵ)` evaluated element-wise; no reduction is applied.
function nigloss(y, γ, ν, α, β, λ = 1, ϵ = 1e-4)
    lnΓ = SpecialFunctions.loggamma

    # NLL: negative log likelihood of the Normal-Inverse-Gamma distribution
    Ω = 2 .* β .* (1 .+ ν)
    nll = 0.5 .* log.(π ./ ν) .-
          α .* log.(Ω) .+
          (α .+ 0.5) .* log.(ν .* (y .- γ) .^ 2 .+ Ω) .+
          lnΓ.(α) .-
          lnΓ.(α .+ 0.5)

    # REG: regularizer based on the absolute error of the prediction
    abserr = abs.(y .- γ)
    reg = abserr .* (2 .* ν .+ α)

    # Combine negative log likelihood and regularizer
    return nll .+ λ .* (reg .- ϵ)
end

# Column-wise KL-type regularizer for Dirichlet parameters.
#
# The α here is actually the α̃, i.e. the Dirichlet parameters after the
# evidence that is good has been scaled down (see `dirloss`).
# `α` is a matrix of size (K, B) or (O, B); the result has size (1, B).
#
# Per column this evaluates
#   lnΓ(∑α) - lnΓ(K) - ∑lnΓ(α) + ∑ (α - 1) * (ψ(α) - ψ(∑α)),
# which is the KL divergence from Dir(α) to the uniform Dirichlet Dir(1, …, 1).
function kl(α)
    ψ = SpecialFunctions.digamma
    lnΓ = SpecialFunctions.loggamma
    K = size(α, 1)
    # Column-wise sums over the class dimension
    ∑α = sum(α, dims = 1)
    ∑lnΓα = sum(lnΓ.(α), dims = 1)
    lognorm = lnΓ.(∑α) .- lnΓ(K) .- ∑lnΓα
    digammaterm = sum((α .- 1) .* (ψ.(α) .- ψ.(∑α)), dims = 1)
    return lognorm .+ digammaterm
end

dirloss(y, α, t)
Regularized version of a type II maximum likelihood for the Multinomial(p)
distribution where the parameter p, which follows a Dirichlet distribution, has
been integrated out.
# Arguments:
- `y`: the targets whose shape should be (O, B)
- `α`: the parameters of a Dirichlet distribution representing the belief in each class, whose shape should be (O, B)
- `t`: counter for the current epoch being evaluated
# Evidential classification loss with an annealed KL regularizer.
#
# - `y`: targets of shape (O, B)
# - `α`: Dirichlet parameters of shape (O, B)
# - `t`: counter for the current epoch; the regularizer weight grows linearly
#        with `t` and is capped at 1 from `t = 10` onwards
#
# Returns a scalar: the sum over the batch of data term plus weighted regularizer.
function dirloss(y, α, t)
    S = sum(α, dims = 1)
    p̂ = α ./ S
    # Main loss: squared error of the mean plus p̂(1 - p̂)/(S + 1), which is the
    # variance of each Dirichlet component; summed over classes.
    loss = (y .- p̂) .^ 2 .+ p̂ .* (1 .- p̂) ./ (S .+ 1)
    loss = sum(loss, dims = 1)
    # Regularizer weight, annealed over the first epochs
    λₜ = min(1.0, t / 10)
    # Keep only misleading evidence, i.e., penalize stuff that fit badly:
    # entries where y == 1 are reset to 1, the others keep their α.
    α̃ = @. y + (1 - y) * α
    reg = kl(α̃)
    # Total loss = likelihood + regularizer
    return sum(loss .+ λₜ .* reg)
end

#y = Flux.onehotbatch(rand(Categorical([0.2, 0.2, 0.2, 0.2, 0.2]), 10), 1:5)
#α = reshape(1:50, (5, 10))
#S = sum(α, dims = 1)
#p̂ = α ./ S
#α̂ = @. y + (1 - y) * α
Expand Up @@ -28,6 +28,24 @@ Given a ``\\text{N-}\\Gamma^{-1}(γ, υ, α, β)`` distribution we can calculate
# Element-wise β / (α - 1); for an Inverse-Gamma(α, β) variance this is its mean
# (finite only for α > 1).
uncertainty(α, β) = @. β / (α - 1)

Calculates the epistemic uncertainty associated with a MultinomialDirichlet model (DIR) layer.
- `α`: the concentration parameters of the Dirichlet distribution, whose shape should be (O, B)
function uncertainty(α)
    # Number of classes (rows) divided by the column-wise sum of α.
    nclasses = size(α, 1)
    strength = sum(α, dims = 1)
    return nclasses ./ strength
end

Calculates the total evidence of assigning each observation in α to the respective class for a DIR layer.
- `α`: the concentration parameters of the Dirichlet distribution, whose shape should be (O, B)
function evidence(α)
    # Evidence is the Dirichlet parameter with the constant offset of 1 removed.
    return α .- 1
end

evidence(ν, α)
Expand Down Expand Up @@ -59,3 +77,8 @@ function predict(::Type{<:NIG}, m, x)
#return γ, uncertainty(ν, α, β), uncertainty(α, β)
γ, ν, α, β

# Prediction for models dispatched on a DIR layer type: run the model `m` on
# `x` and return its output unchanged.
function predict(::Type{<:DIR}, m, x)
    α = m(x)
    return α
end

