feat: Added the MultinomialDirichlet evidential distribution.
DoktorMike committed Jul 11, 2022
1 parent a34d269 commit 7436f16
Showing 6 changed files with 277 additions and 111 deletions.
41 changes: 41 additions & 0 deletions examples/classification.jl
@@ -0,0 +1,41 @@
using EvidentialFlux
using Flux
using UnicodePlots


function gendata(n)
x1 = Float32.(randn(2, n))
x2 = Float32.(randn(2, n) .+ 2)
y1, y2 = Float32.(ones(n)), Float32.(zeros(n))
hcat(x1, x2), hcat(vcat(y1, y2), 1 .- vcat(y1, y2))'
end
n = 200
X, y = gendata(n)

# See the data
p = scatterplot(X[1, 1:n], X[2, 1:n], color = :green, width = 80, height = 30)
scatterplot!(p, X[1, (n+1):(n+n)], X[2, (n+1):(n+n)], color = :red)

m = Chain(Dense(2 => 30), DIR(30 => 2))
opt = Flux.Optimise.AdamW(0.01)
p = Flux.params(m)

epochs = 500
trnlosses = zeros(epochs)
for e in 1:epochs
local trnloss = 0
grads = Flux.gradient(p) do
α = m(X)
trnloss = dirloss(y, α, e)
trnloss
end
trnlosses[e] = trnloss
Flux.Optimise.update!(opt, p, grads)
end
scatterplot(1:epochs, trnlosses, width = 80, height = 30)

α̂ = m(X)
p̂ = α̂ ./ sum(α̂, dims = 1)
u = uncertainty(α̂)

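# Visualize the epistemic uncertainty over the input plane; it should be low near the
# two training clusters and grow away from them.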
contourplot(-5:.01:5, -5:.01:5, (x, y) -> uncertainty(m(vcat(y,x)))[1])
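
# Illustrative extension (an assumption, not part of this commit): hard class predictions
# can be recovered from the normalized concentrations, e.g. with Flux's onecold.
#   ŷ = Flux.onecold(p̂)    # index of the most probable class for each observation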
2 changes: 2 additions & 0 deletions src/EvidentialFlux.jl
@@ -8,9 +8,11 @@ using SpecialFunctions

include("dense.jl")
export NIG
export DIR

include("losses.jl")
export nigloss
export dirloss

include("utils.jl")
export uncertainty
99 changes: 68 additions & 31 deletions src/dense.jl
@@ -26,47 +26,84 @@ The same holds true for the `bias` vector.
- `bias`: Whether to include a trainable bias vector.
"""
struct NIG{F,M<:AbstractMatrix,B}
    W::M
    b::B
    σ::F
    function NIG(W::M, b = true, σ::F = NNlib.softplus) where {M<:AbstractMatrix,F}
        b = Flux.create_bias(W, b, size(W, 1))
        return new{F,M,typeof(b)}(W, b, σ)
    end
end

function NIG((in, out)::Pair{<:Integer,<:Integer}, σ = NNlib.softplus;
        init = Flux.glorot_uniform, bias = true)
    NIG(init(out * 4, in), bias, σ)
end

Flux.@functor NIG

function (a::NIG)(x::AbstractVecOrMat)
    nout = Int(size(a.W, 1) / 4)
    o = a.W * x .+ a.b
    γ = o[1:nout, :]
    ν = o[(nout+1):(nout*2), :]
    ν = a.σ.(ν)
    α = o[(nout*2+1):(nout*3), :]
    α = a.σ.(α) .+ 1
    β = o[(nout*3+1):(nout*4), :]
    β = a.σ.(β)
    return vcat(γ, ν, α, β)
end

(a::NIG)(x::AbstractArray) = reshape(a(reshape(x, size(x, 1), :)), :, size(x)[2:end]...)

#function predict(m::NIG, x::AbstractVecOrMat)
# nout = Int(size(m.W, 1) / 4)
# o = m.W * x .+ m.b
# γ = o[1:nout, :]
# ν = o[(nout+1):(nout*2), :]
# ν = m.σ.(ν)
# α = o[(nout*2+1):(nout*3), :]
# α = m.σ.(α) .+ 1
# β = o[(nout*3+1):(nout*4), :]
# β = m.σ.(β)
# return γ, uncertainty(ν, α, β), uncertainty(α, β)
#end
"""
DIR(in => out; bias=true, init=Flux.glorot_uniform)
DIR(W::AbstractMatrix, [bias])
A Linear layer with a softplys activation function in the end to implement the
Dirichlet evidential distribution. In this layer the number of output nodes
should correspond to the number of classes you wish to model. This layer should
be used to model a Multinomial likelihood with a Dirichlet prior. Thus the
posterior is also a Dirichlet distribution. Moreover the type II maximum
likelihood, i.e., the marginal likelihood is a Dirichlet-Multinomial
distribution. Create a fully connected layer which implements the Dirichlet
Evidential distribution whose forward pass is simply given by:
y = softplus.(W * x .+ bias)
The input `x` should be a vector of length `in`, or batch of vectors represented
as an `in × N` matrix, or any array with `size(x,1) == in`.
The out `y` will be a vector of length `out`, or a batch with
`size(y) == (out, size(x)[2:end]...)`
The output will have applied the function `softplus(y)` to each row/element of `y`.
Keyword `bias=false` will switch off trainable bias for the layer.
The initialisation of the weight matrix is `W = init(out, in)`, calling the function
given to keyword `init`, with default [`glorot_uniform`](@doc Flux.glorot_uniform).
The weight matrix and/or the bias vector (of length `out`) may also be provided explicitly.
# Arguments:
- `(in, out)`: number of input and output neurons
- `init`: The function to use to initialise the weight matrix.
- `bias`: Whether to include a trainable bias vector.
"""
struct DIR{M<:AbstractMatrix,B}
W::M
b::B
function DIR(W::M, b = true) where {M<:AbstractMatrix}
b = Flux.create_bias(W, b, size(W, 1))
return new{M,typeof(b)}(W, b)
end
end

function DIR((in, out)::Pair{<:Integer,<:Integer}; init = Flux.glorot_uniform, bias = true)
DIR(init(out, in), bias)
end

Flux.@functor DIR

function (a::DIR)(x::AbstractVecOrMat)
NNlib.softplus.(a.W * x .+ a.b) .+ 1
end

(a::DIR)(x::AbstractArray) = reshape(a(reshape(x, size(x, 1), :)), :, size(x)[2:end]...)
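
# A minimal usage sketch (illustrative only, not part of the library): a DIR layer with
# three inputs and four classes maps a (3, B) batch to a (4, B) matrix of Dirichlet
# concentrations, each entry strictly greater than 1.
#   layer = DIR(3 => 4)
#   x = randn(Float32, 3, 16)
#   α = layer(x)    # size(α) == (4, 16); all(α .> 1)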
85 changes: 68 additions & 17 deletions src/losses.jl
@@ -15,21 +15,72 @@ function: μ and σ.
- `ϵ`: the threshold for the regularizer (default: 0.0001)
"""
function nigloss(y, γ, ν, α, β, λ = 1, ϵ = 1e-4)
    # NLL: Calculate the negative log likelihood of the Normal-Inverse-Gamma distribution
    twoβλ = 2 * β .* (1 .+ ν)
    logγ = SpecialFunctions.loggamma
    nll = 0.5 * log.(π ./ ν) -
          α .* log.(twoβλ) +
          (α .+ 0.5) .* log.(ν .* (y - γ) .^ 2 + twoβλ) +
          logγ.(α) -
          logγ.(α .+ 0.5)

    # REG: Calculate regularizer based on absolute error of prediction
    error = abs.(y - γ)
    reg = error .* (2 * ν + α)

    # Combine negative log likelihood and regularizer
    loss = nll + λ .* (reg .- ϵ)
    loss
end

# The α here is actually α̃, in which the evidence supporting the correct class has been removed.
# α is a matrix of size (K, B), i.e., (O, B).
function kl(α)
ψ = SpecialFunctions.digamma
lnΓ = SpecialFunctions.loggamma
K = first(size(α))
# Actual computation
∑α = sum(α, dims = 1)
∑lnΓα = sum(lnΓ.(α), dims = 1)
A = lnΓ.(∑α) .- lnΓ(K) .- ∑lnΓα
B = sum((α .- 1) .* (ψ.(α) .- ψ.(∑α)), dims = 1)
kl = A + B
kl
end
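
# Quick sanity check (illustrative assumption): a uniform Dirichlet carries no misleading
# evidence, so its KL divergence from Dir(1, …, 1) is zero.
#   kl(ones(Float32, 3, 5))    # ≈ zeros(1, 5)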


"""
dirloss(y, α, t)
Regularized version of a type II maximum likelihood for the Multinomial(p)
distribution where the parameter p, which follows a Dirichlet distribution has
been integrated out.
# Arguments:
- `y`: the targets whose shape should be (O, B)
- `α`: the parameters of a Dirichlet distribution representing the belief in each class which shape should be (O, B)
- `t`: counter for the current epoch being evaluated
"""
function dirloss(y, α, t)
S = sum(α, dims = 1)
p̂ = α ./ S
# Main loss
loss = (y - p̂) .^ 2 .+ p̂ .* (1 .- p̂) ./ (S .+ 1)
loss = sum(loss, dims = 1)
# Regularizer
λₜ = min(1.0, t / 10)
# Keep only the misleading evidence, i.e., penalize evidence that fits the targets badly
α̂ = @. y + (1 - y) * α
reg = kl(α̂)
# Total loss = likelihood + regularizer
#sum(loss .+ λₜ .* reg, dims = 2)
sum(loss .+ λₜ .* reg)
end
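
# Illustrative usage sketch (assumes a DIR model `m`, one-hot targets `y` of shape (O, B)
# and the current epoch `e`):
#   α = m(X)              # Dirichlet concentrations, shape (O, B)
#   dirloss(y, α, e)      # scalar: sum-of-squares data fit plus annealed KL regularizer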

#y = Flux.onehotbatch(rand(Categorical([0.2, 0.2, 0.2, 0.2, 0.2]), 10), 1:5)
#α = reshape(1:50, (5, 10))
#S = sum(α, dims = 1)
#p̂ = α ./ S
#α̂ = @. y + (1 - y) * α
#kl(α̂)
23 changes: 23 additions & 0 deletions src/utils.jl
@@ -28,6 +28,24 @@ Given a ``\\text{N-}\\Gamma^{-1}(γ, υ, α, β)`` distribution we can calculate
"""
uncertainty(α, β) = @. β / (α - 1)

"""
uncertainty(α)
Calculates the epistemic uncertainty associated with a MultinomialDirichlet model (DIR) layer.
- `α`: the α parameter of the Dirichlet distribution which relates to it's concentrations and whose shape should be (O, B)
"""
uncertainty(α) = first(size(α)) ./ sum(α, dims = 1)
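
# Illustrative check (an assumption, not part of the commit): with no evidence, α = ones(K, B),
# the uncertainty is K / K = 1, its maximum; it shrinks as evidence accumulates.
#   uncertainty(ones(Float32, 3, 5))    # ≈ ones(1, 5)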

"""
evidence(α)
Calculates the total evidence of assigning each observation in α to the respective class for a DIR layer.
- `α`: the α parameter of the Dirichlet distribution which relates to it's concentrations and whose shape should be (O, B)
"""
evidence(α) = α .- 1
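
# Illustrative check: the total evidence per observation is the column sum of α .- 1,
# e.g. sum(evidence(α), dims = 1); it is zero for the uniform Dirichlet α = ones(K, B).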

"""
evidence(ν, α)
@@ -59,3 +77,8 @@ function predict(::Type{<:NIG}, m, x)
#return γ, uncertainty(ν, α, β), uncertainty(α, β)
γ, ν, α, β
end

function predict(::Type{<:DIR}, m, x)
α̂ = m(x)
end
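
# Illustrative end-to-end sketch for a DIR model (variable names are assumptions):
#   α̂ = predict(DIR, m, X)         # Dirichlet concentrations, shape (O, B)
#   p̂ = α̂ ./ sum(α̂, dims = 1)      # expected class probabilities
#   u = uncertainty(α̂)             # epistemic uncertainty per observation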