diff --git a/examples/classification.jl b/examples/classification.jl
new file mode 100644
index 0000000..8017ce2
--- /dev/null
+++ b/examples/classification.jl
@@ -0,0 +1,41 @@
+using EvidentialFlux
+using Flux
+using UnicodePlots
+
+
+function gendata(n)
+    x1 = Float32.(randn(2, n))
+    x2 = Float32.(randn(2, n) .+ 2)
+    y1, y2 = Float32.(ones(n)), Float32.(zeros(n))
+    hcat(x1, x2), hcat(vcat(y1, y2), 1 .- vcat(y1, y2))'
+end
+n = 200
+X, y = gendata(n)
+
+# See the data
+p = scatterplot(X[1, 1:n], X[2, 1:n], color = :green, width = 80, height = 30)
+scatterplot!(p, X[1, (n+1):(n+n)], X[2, (n+1):(n+n)], color = :red)
+
+m = Chain(Dense(2 => 30), DIR(30 => 2))
+opt = Flux.Optimise.AdamW(0.01)
+pars = Flux.params(m)
+
+epochs = 500
+trnlosses = zeros(epochs)
+for e in 1:epochs
+    local trnloss = 0
+    grads = Flux.gradient(pars) do
+        α = m(X)
+        trnloss = dirloss(y, α, e)
+        trnloss
+    end
+    trnlosses[e] = trnloss
+    Flux.Optimise.update!(opt, pars, grads)
+end
+scatterplot(1:epochs, trnlosses, width = 80, height = 30)
+
+α̂ = m(X)
+ŷ = α̂ ./ sum(α̂, dims = 1)
+u = uncertainty(α̂)
+
+contourplot(-5:.01:5, -5:.01:5, (x, y) -> uncertainty(m(vcat(y, x)))[1])
diff --git a/src/EvidentialFlux.jl b/src/EvidentialFlux.jl
index fac794d..392ad3a 100644
--- a/src/EvidentialFlux.jl
+++ b/src/EvidentialFlux.jl
@@ -8,9 +8,11 @@ using SpecialFunctions
 
 include("dense.jl")
 export NIG
+export DIR
 
 include("losses.jl")
 export nigloss
+export dirloss
 
 include("utils.jl")
 export uncertainty
diff --git a/src/dense.jl b/src/dense.jl
index feed1ba..d0f2e3f 100644
--- a/src/dense.jl
+++ b/src/dense.jl
@@ -26,47 +26,84 @@ The same holds true for the `bias` vector.
 - `bias`: Whether to include a trainable bias vector.
 """
 struct NIG{F,M<:AbstractMatrix,B}
-        W::M
-        b::B
-        σ::F
-        function NIG(W::M, b = true, σ::F = NNlib.softplus) where {M<:AbstractMatrix,F}
-                b = Flux.create_bias(W, b, size(W, 1))
-                return new{F,M,typeof(b)}(W, b, σ)
-        end
+    W::M
+    b::B
+    σ::F
+    function NIG(W::M, b = true, σ::F = NNlib.softplus) where {M<:AbstractMatrix,F}
+        b = Flux.create_bias(W, b, size(W, 1))
+        return new{F,M,typeof(b)}(W, b, σ)
+    end
 end
 
 function NIG((in, out)::Pair{<:Integer,<:Integer}, σ = NNlib.softplus;
-        init = Flux.glorot_uniform, bias = true)
-        NIG(init(out * 4, in), bias, σ)
+    init = Flux.glorot_uniform, bias = true)
+    NIG(init(out * 4, in), bias, σ)
 end
 
 Flux.@functor NIG
 
 function (a::NIG)(x::AbstractVecOrMat)
-        nout = Int(size(a.W, 1) / 4)
-        o = a.W * x .+ a.b
-        γ = o[1:nout, :]
-        ν = o[(nout+1):(nout*2), :]
-        ν = a.σ.(ν)
-        α = o[(nout*2+1):(nout*3), :]
-        α = a.σ.(α) .+ 1
-        β = o[(nout*3+1):(nout*4), :]
-        β = a.σ.(β)
-        return vcat(γ, ν, α, β)
+    nout = Int(size(a.W, 1) / 4)
+    o = a.W * x .+ a.b
+    γ = o[1:nout, :]
+    ν = o[(nout+1):(nout*2), :]
+    ν = a.σ.(ν)
+    α = o[(nout*2+1):(nout*3), :]
+    α = a.σ.(α) .+ 1
+    β = o[(nout*3+1):(nout*4), :]
+    β = a.σ.(β)
+    return vcat(γ, ν, α, β)
 end
 (a::NIG)(x::AbstractArray) = reshape(a(reshape(x, size(x, 1), :)), :, size(x)[2:end]...)
 
-#function predict(m::NIG, x::AbstractVecOrMat)
-#    nout = Int(size(m.W, 1) / 4)
-#    o = m.W * x .+ m.b
-#    γ = o[1:nout, :]
-#    ν = o[(nout+1):(nout*2), :]
-#    ν = m.σ.(ν)
-#    α = o[(nout*2+1):(nout*3), :]
-#    α = m.σ.(α) .+ 1
-#    β = o[(nout*3+1):(nout*4), :]
-#    β = m.σ.(β)
-#    return γ, uncertainty(ν, α, β), uncertainty(α, β)
-#end
+"""
+    DIR(in => out; bias=true, init=Flux.glorot_uniform)
+    DIR(W::AbstractMatrix, [bias])
+
+A linear layer followed by a softplus activation, implementing the Dirichlet
+evidential distribution. The number of output nodes should correspond to the
+number of classes you wish to model. The layer models a Multinomial likelihood
+with a Dirichlet prior, so the posterior is also a Dirichlet distribution.
+Moreover, the type II maximum likelihood, i.e., the marginal likelihood, is a
+Dirichlet-Multinomial distribution. This constructor creates a fully connected
+layer which implements the Dirichlet evidential distribution and whose forward
+pass is simply given by:
+
+    y = softplus.(W * x .+ bias) .+ 1
+
+The input `x` should be a vector of length `in`, a batch of vectors represented
+as an `in × N` matrix, or any array with `size(x, 1) == in`.
+The output `y` will be a vector of length `out`, or a batch with
+`size(y) == (out, size(x)[2:end]...)`. Every element of `y` is ≥ 1, as required
+for the concentration parameters of a Dirichlet distribution.
+Keyword `bias=false` will switch off the trainable bias for the layer.
+The initialisation of the weight matrix is `W = init(out, in)`, calling the function
+given to keyword `init`, with default [`glorot_uniform`](@doc Flux.glorot_uniform).
+The weight matrix and/or the bias vector (of length `out`) may also be provided explicitly.
+
+# Arguments:
+- `(in, out)`: number of input and output neurons
+- `init`: The function to use to initialise the weight matrix.
+- `bias`: Whether to include a trainable bias vector.
+"""
+struct DIR{M<:AbstractMatrix,B}
+    W::M
+    b::B
+    function DIR(W::M, b = true) where {M<:AbstractMatrix}
+        b = Flux.create_bias(W, b, size(W, 1))
+        return new{M,typeof(b)}(W, b)
+    end
+end
+
+function DIR((in, out)::Pair{<:Integer,<:Integer}; init = Flux.glorot_uniform, bias = true)
+    DIR(init(out, in), bias)
+end
+
+Flux.@functor DIR
+
+function (a::DIR)(x::AbstractVecOrMat)
+    NNlib.softplus.(a.W * x .+ a.b) .+ 1
+end
+(a::DIR)(x::AbstractArray) = reshape(a(reshape(x, size(x, 1), :)), :, size(x)[2:end]...)
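# Editor's note (not part of the patch): a minimal sketch of the DIR layer's
# forward pass with the changes above; the layer sizes below are arbitrary. The
# output holds Dirichlet concentrations, one column per observation, and the
# trailing `.+ 1` in the layer guarantees every α ≥ 1.
using EvidentialFlux

m = DIR(4 => 3)            # 4 input features, 3 classes
x = randn(Float32, 4, 16)  # a batch of 16 observations
α = m(x)                   # Dirichlet concentrations, size (3, 16)
size(α) == (3, 16) && all(≥(1), α)  # expected: true
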
diff --git a/src/losses.jl b/src/losses.jl
index 304b095..547171a 100644
--- a/src/losses.jl
+++ b/src/losses.jl
@@ -15,21 +15,72 @@ function: μ and σ.
 - `ϵ`: the threshold for the regularizer (default: 0.0001)
 """
 function nigloss(y, γ, ν, α, β, λ = 1, ϵ = 1e-4)
-        # NLL: Calculate the negative log likelihood of the Normal-Inverse-Gamma distribution
-        twoβλ = 2 * β .* (1 .+ ν)
-        logγ = SpecialFunctions.loggamma
-        nll = 0.5 * log.(π ./ ν) -
-              α .* log.(twoβλ) +
-              (α .+ 0.5) .* log.(ν .* (y - γ) .^ 2 + twoβλ) +
-              logγ.(α) -
-              logγ.(α .+ 0.5)
-        nll
-
-        # REG: Calculate regularizer based on absolute error of prediction
-        error = abs.(y - γ)
-        reg = error .* (2 * ν + α)
-
-        # Combine negative log likelihood and regularizer
-        loss = nll + λ .* (reg .- ϵ)
-        loss
+    # NLL: Calculate the negative log likelihood of the Normal-Inverse-Gamma distribution
+    twoβλ = 2 * β .* (1 .+ ν)
+    logγ = SpecialFunctions.loggamma
+    nll = 0.5 * log.(π ./ ν) -
+          α .* log.(twoβλ) +
+          (α .+ 0.5) .* log.(ν .* (y - γ) .^ 2 + twoβλ) +
+          logγ.(α) -
+          logγ.(α .+ 0.5)
+    nll
+
+    # REG: Calculate regularizer based on absolute error of prediction
+    error = abs.(y - γ)
+    reg = error .* (2 * ν + α)
+
+    # Combine negative log likelihood and regularizer
+    loss = nll + λ .* (reg .- ϵ)
+    loss
 end
+
+# The α here is actually α̃: the evidence for the correct class has been removed,
+# so only the misleading evidence is penalized. α is a matrix of size (K, B), i.e. (O, B).
+function kl(α)
+    ψ = SpecialFunctions.digamma
+    lnΓ = SpecialFunctions.loggamma
+    K = first(size(α))
+    # Actual computation
+    ∑α = sum(α, dims = 1)
+    ∑lnΓα = sum(lnΓ.(α), dims = 1)
+    A = lnΓ.(∑α) .- lnΓ(K) .- ∑lnΓα
+    B = sum((α .- 1) .* (ψ.(α) .- ψ.(∑α)), dims = 1)
+    kl = A + B
+    kl
+end
+
+
+"""
+    dirloss(y, α, t)
+
+Regularized version of the type II maximum likelihood of the Multinomial(p)
+distribution, where the parameter p, which follows a Dirichlet distribution,
+has been integrated out.
+
+# Arguments:
+- `y`: the targets, whose shape should be (O, B)
+- `α`: the parameters of a Dirichlet distribution representing the belief in each class, whose shape should be (O, B)
+- `t`: counter for the current epoch being evaluated
+"""
+function dirloss(y, α, t)
+    S = sum(α, dims = 1)
+    p̂ = α ./ S
+    # Main loss
+    loss = (y - p̂) .^ 2 .+ p̂ .* (1 .- p̂) ./ (S .+ 1)
+    loss = sum(loss, dims = 1)
+    # Regularizer
+    λₜ = min(1.0, t / 10)
+    # Keep only the misleading evidence, i.e., penalize evidence placed on the wrong classes
+    α̂ = @. y + (1 - y) * α
+    reg = kl(α̂)
+    # Total loss = likelihood + regularizer
+    #sum(loss .+ λₜ .* reg, dims = 2)
+    sum(loss .+ λₜ .* reg)
+end
+
+#y = Flux.onehotbatch(rand(Categorical([0.2, 0.2, 0.2, 0.2, 0.2]), 10), 1:5)
+#α = reshape(1:50, (5, 10))
+#S = sum(α, dims = 1)
+#p̂ = α ./ S
+#α̂ = @. y + (1 - y) * α
+#kl(α̂)
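# Editor's note (not part of the patch): a small numeric sketch of how dirloss
# behaves, assuming the definitions above. `y` is one-hot with shape (K, B), and
# the KL regularizer is annealed in over the first 10 epochs via λₜ = min(1, t / 10),
# so only the misleading evidence α̃ = y + (1 - y) ⊙ α is pushed towards Dir(1).
using EvidentialFlux

y = Float32[1 0; 0 1]        # two observations, two classes (one-hot columns)
αright = Float32[9 1; 1 9]   # evidence concentrated on the correct classes
αwrong = Float32[1 9; 9 1]   # the same evidence placed on the wrong classes
dirloss(y, αright, 20) < dirloss(y, αwrong, 20)  # expected: true
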
diff --git a/src/utils.jl b/src/utils.jl
index eee877b..e8be86e 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -28,6 +28,24 @@ Given a ``\\text{N-}\\Gamma^{-1}(γ, υ, α, β)`` distribution we can calculate
 """
 uncertainty(α, β) = @. β / (α - 1)
 
+"""
+    uncertainty(α)
+
+Calculates the epistemic uncertainty associated with a Multinomial-Dirichlet model (DIR) layer.
+
+- `α`: the α parameter of the Dirichlet distribution, which holds its concentrations and whose shape should be (O, B)
+"""
+uncertainty(α) = first(size(α)) ./ sum(α, dims = 1)
+
+"""
+    evidence(α)
+
+Calculates the total evidence for assigning each observation in α to its respective class for a DIR layer.
+
+- `α`: the α parameter of the Dirichlet distribution, which holds its concentrations and whose shape should be (O, B)
+"""
+evidence(α) = α .- 1
+
 """
     evidence(ν, α)
 
@@ -59,3 +77,8 @@ function predict(::Type{<:NIG}, m, x)
     #return γ, uncertainty(ν, α, β), uncertainty(α, β)
     γ, ν, α, β
 end
+
+function predict(::Type{<:DIR}, m, x)
+    ŷ = m(x)
+    ŷ
+end
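# Editor's note (not part of the patch): a quick sketch of the new Dirichlet
# helpers above. uncertainty(α) = K / S with S = Σₖ αₖ, so it equals 1 with no
# evidence and shrinks towards 0 as evidence accumulates; evidence(α) = α .- 1.
# This assumes `evidence` is exported like `uncertainty`; otherwise qualify it
# as `EvidentialFlux.evidence`.
using EvidentialFlux

α = Float32[1 10; 1 1; 1 1]  # column 1: no evidence, column 2: strong class-1 evidence
uncertainty(α)               # 1×2 Matrix: [1.0 0.25]  (= 3/3 and 3/12)
evidence(α)                  # α .- 1, i.e. all zeros in column 1
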
diff --git a/test/runtests.jl b/test/runtests.jl
index ae4f3ed..4ef81ee 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,76 +2,88 @@ using EvidentialFlux
 using Flux
 using Test
 
-@testset "EvidentialFlux.jl" begin
-        # Creating a model and a forward pass
+@testset "EvidentialFlux.jl - Classification" begin
+    # Creating a model and a forward pass
 
-        ninp, nout = 3, 5
-        m = NIG(ninp => nout)
-        x = randn(Float32, 3, 10)
+    ninp, nout = 3, 5
+    m = DIR(ninp => nout)
+    x = randn(Float32, 3, 10)
+    ŷ = m(x)
+    @test size(ŷ) == (5, 10)
+    # The α all have to be ≥ 1
+    @test all(≥(1), ŷ)
+end
+
+@testset "EvidentialFlux.jl - Regression" begin
+    # Creating a model and a forward pass
+
+    ninp, nout = 3, 5
+    m = NIG(ninp => nout)
+    x = randn(Float32, 3, 10)
+    ŷ = m(x)
+    @test size(ŷ) == (20, 10)
+    # The ν, α, and β all have to be positive
+    @test ŷ[6:20, :] == abs.(ŷ[6:20, :])
+    # The α all have to be ≥ 1
+    @test all(≥(1), ŷ[11:15, :])
+
+    # Testing backward pass
+    oldW = similar(m.W)
+    oldW .= m.W
+    loss(y, ŷ) = sum(abs, y - ŷ)
+    pars = Flux.params(m)
+    y = randn(Float32, nout, 10) # Target (fake)
+    grads = Flux.gradient(pars) do
         ŷ = m(x)
-        @test size(ŷ) == (20, 10)
-        # The ν, α, and β all have to be positive
-        @test ŷ[6:20, :] == abs.(ŷ[6:20, :])
-        # The α all have to be ≥ 1
-        @test all(≥(1), ŷ[11:15, :])
+        γ = ŷ[1:nout, :]
+        loss(y, γ)
+    end
+    # Test that we can update the weights based on gradients
+    opt = Descent(0.1)
+    Flux.Optimise.update!(opt, pars, grads)
+    @test m.W != oldW
 
-        # Testing backward pass
-        oldW = similar(m.W)
-        oldW .= m.W
-        loss(y, ŷ) = sum(abs, y - ŷ)
-        pars = Flux.params(m)
-        y = randn(Float32, nout, 10) # Target (fake)
+    # Testing convergence
+    ninp, nout = 3, 1
+    m = NIG(ninp => nout)
+    x = Float32.(collect(1:0.1:10))
+    x = cat(x', x' .- 10, x' .+ 5, dims = 1)
+    # y = 1 * sin.(x[1, :]) .- 3 * sin.(x[2, :]) .+ 2 * cos.(x[3, :]) .+ randn(Float32, 91)
+    y = 1 * x[1, :] .- 3 * x[2, :] .+ 2 * x[3, :] .+ randn(Float32, 91)
+    #scatterplot(x[1, :], y, width = 90, height = 30)
+    pars = Flux.params(m)
+    opt = AdamW(0.005)
+    trnlosses = zeros(Float32, 1000)
+    for i in 1:1000
+        local trnloss
         grads = Flux.gradient(pars) do
-                ŷ = m(x)
-                γ = ŷ[1:nout, :]
-                loss(y, γ)
+            ŷ = m(x)
+            γ, ν, α, β = ŷ[1, :], ŷ[2, :], ŷ[3, :], ŷ[4, :]
+            trnloss = sum(nigloss(y, γ, ν, α, β, 0.1, 1e-4))
         end
+        trnlosses[i] = trnloss
         # Test that we can update the weights based on gradients
-        opt = Descent(0.1)
         Flux.Optimise.update!(opt, pars, grads)
-        @test m.W != oldW
-
-        # Testing convergence
-        ninp, nout = 3, 1
-        m = NIG(ninp => nout)
-        x = Float32.(collect(1:0.1:10))
-        x = cat(x', x' .- 10, x' .+ 5, dims = 1)
-        # y = 1 * sin.(x[1, :]) .- 3 * sin.(x[2, :]) .+ 2 * cos.(x[3, :]) .+ randn(Float32, 91)
-        y = 1 * x[1, :] .- 3 * x[2, :] .+ 2 * x[3, :] .+ randn(Float32, 91)
-        #scatterplot(x[1, :], y, width = 90, height = 30)
-        pars = Flux.params(m)
-        opt = ADAMW(0.005)
-        trnlosses = zeros(Float32, 1000)
-        for i in 1:1000
-                local trnloss
-                grads = Flux.gradient(pars) do
-                        ŷ = m(x)
-                        γ, ν, α, β = ŷ[1, :], ŷ[2, :], ŷ[3, :], ŷ[4, :]
-                        trnloss = sum(nigloss(y, γ, ν, α, β, 0.1, 1e-4))
-                end
-                trnlosses[i] = trnloss
-                # Test that we can update the weights based on gradients
-                Flux.Optimise.update!(opt, pars, grads)
-                #if i % 100 == 0
-                #    println("Epoch $i, Loss: $trnloss")
-                #end
-        end
-        #scatterplot(1:1000, trnlosses, width = 80)
-        @test trnlosses[10] > trnlosses[100] > trnlosses[300]
+        #if i % 100 == 0
+        #    println("Epoch $i, Loss: $trnloss")
+        #end
+    end
+    #scatterplot(1:1000, trnlosses, width = 80)
+    @test trnlosses[10] > trnlosses[100] > trnlosses[300]
 
-        # Test the nigloss and uncertainty function
-        ninp, nout = 3, 5
-        m = NIG(ninp => nout)
-        x = randn(Float32, 3, 10)
-        y = randn(Float32, nout, 10) # Target (fake)
-        ŷ = m(x)
-        γ = ŷ[1:nout, :]
-        ν = ŷ[(nout+1):(nout*2), :]
-        α = ŷ[(nout*2+1):(nout*3), :]
-        β = ŷ[(nout*3+1):(nout*4), :]
-        myloss = nigloss(y, γ, ν, α, β, 0.1, 1e-4)
-        @test size(myloss) == (nout, 10)
-        myuncert = uncertainty(ν, α, β)
-        @test size(myuncert) == size(myloss)
+    # Test the nigloss and uncertainty function
+    ninp, nout = 3, 5
+    m = NIG(ninp => nout)
+    x = randn(Float32, 3, 10)
+    y = randn(Float32, nout, 10) # Target (fake)
+    ŷ = m(x)
+    γ = ŷ[1:nout, :]
+    ν = ŷ[(nout+1):(nout*2), :]
+    α = ŷ[(nout*2+1):(nout*3), :]
+    β = ŷ[(nout*3+1):(nout*4), :]
+    myloss = nigloss(y, γ, ν, α, β, 0.1, 1e-4)
+    @test size(myloss) == (nout, 10)
+    myuncert = uncertainty(ν, α, β)
+    @test size(myuncert) == size(myloss)
 end
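
# Editor's note (not part of the patch): a compact end-to-end sketch tying the new
# pieces together (DIR layer, dirloss, uncertainty), mirroring examples/classification.jl.
# The network width, learning rate, and epoch count below are arbitrary choices.
using EvidentialFlux, Flux

X = randn(Float32, 2, 64)
Y = Float32.(Flux.onehotbatch(rand(1:2, 64), 1:2))  # (2, 64) one-hot targets
model = Chain(Dense(2 => 16), DIR(16 => 2))
pars = Flux.params(model)
opt = Flux.Optimise.AdamW(0.01)
for epoch in 1:100
    grads = Flux.gradient(pars) do
        dirloss(Y, model(X), epoch)
    end
    Flux.Optimise.update!(opt, pars, grads)
end
α = model(X)
p̂ = α ./ sum(α, dims = 1)  # expected class probabilities
u = uncertainty(α)         # epistemic uncertainty per observation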