From 448e865900420055573e442702135807eaddefdd Mon Sep 17 00:00:00 2001
From: Michael Green
Date: Mon, 17 Jun 2024 13:26:21 +0200
Subject: [PATCH] feat: Implemented Mean-variance network and cleaned up the dependencies.

---
 Project.toml          |  7 +----
 src/EvidentialFlux.jl |  3 +++
 src/dense.jl          | 60 +++++++++++++++++++++++++++++++++++--
 src/losses.jl         | 26 ++++++++++++++++
 src/utils.jl          |  9 +++++-
 test/runtests.jl      | 69 ++++++++++++++++++++++++++++++++++++++++++-
 6 files changed, 163 insertions(+), 11 deletions(-)

diff --git a/Project.toml b/Project.toml
index 566e821..65cc869 100644
--- a/Project.toml
+++ b/Project.toml
@@ -4,16 +4,11 @@ authors = ["Michael Green and contributors"]
 version = "1.3.3"
 
 [deps]
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 
-[compat]
-Flux = "0.13"
-julia = "1.7"
-NNlib = "0.8"
-SpecialFunctions = "2.1"
-
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
diff --git a/src/EvidentialFlux.jl b/src/EvidentialFlux.jl
index 4f157bb..cf5cb59 100644
--- a/src/EvidentialFlux.jl
+++ b/src/EvidentialFlux.jl
@@ -9,11 +9,14 @@ using SpecialFunctions
+using ChainRulesCore
 include("dense.jl")
 export NIG
 export DIR
+export MVE
 
 include("losses.jl")
 export nigloss
 export nigloss2
 export dirloss
+export mveloss
 
 include("utils.jl")
 export uncertainty
diff --git a/src/dense.jl b/src/dense.jl
index 5672d69..e9ea72d 100644
--- a/src/dense.jl
+++ b/src/dense.jl
@@ -36,11 +36,11 @@ struct NIG{F, M <: AbstractMatrix, B}
 end
 
 function NIG((in, out)::Pair{<:Integer, <:Integer}, σ = NNlib.softplus;
-    init = Flux.glorot_uniform, bias = true)
+        init = Flux.glorot_uniform, bias = true)
     NIG(init(out * 4, in), bias, σ)
 end
 
-Flux.@functor NIG
+Flux.@layer NIG
 
 function (a::NIG)(x::AbstractVecOrMat)
     nout = Int(size(a.W, 1) / 4)
@@ -100,10 +100,64 @@ function DIR((in, out)::Pair{<:Integer, <:Integer}; init = Flux.glorot_uniform,
     DIR(init(out, in), bias)
 end
 
-Flux.@functor DIR
+Flux.@layer DIR
 
 function (a::DIR)(x::AbstractVecOrMat)
     NNlib.softplus.(a.W * x .+ a.b) .+ 1
 end
 
 (a::DIR)(x::AbstractArray) = reshape(a(reshape(x, size(x, 1), :)), :, size(x)[2:end]...)
+
+"""
+    MVE(in => out, σ=NNlib.softplus; bias=true, init=Flux.glorot_uniform)
+    MVE(W::AbstractMatrix, [bias, σ])
+
+Create a fully connected layer which implements the Mean-Variance Estimation (MVE) network, i.e. a
+network that predicts the mean and variance of a Normal distribution. The forward pass is given by:
+
+    y = W * x .+ bias
+
+The input `x` should be a vector of length `in`, or batch of vectors represented
+as an `in × N` matrix, or any array with `size(x,1) == in`.
+The output `y` will be a vector of length `out*2`, or a batch with
+`size(y) == (out*2, size(x)[2:end]...)`.
+The function `σ` is applied to every row/element of `y` except the first `out` ones, which keeps the predicted variances positive.
+Keyword `bias=false` will switch off trainable bias for the layer.
+The initialisation of the weight matrix is `W = init(out*2, in)`, calling the function
+given to keyword `init`, with default [`glorot_uniform`](@doc Flux.glorot_uniform).
+The weight matrix and/or the bias vector (of length `out*2`) may also be provided explicitly.
+Remember that in this case the number of rows in the weight matrix `W` MUST be a multiple of 2,
+and the same holds for the length of the `bias` vector.
+
+# Arguments:
+- `(in, out)`: number of input and output neurons
+- `σ`: The function used to guarantee positive-only variance outputs; defaults to the softplus function.
+- `init`: The function to use to initialise the weight matrix.
+- `bias`: Whether to include a trainable bias vector.
+"""
+struct MVE{F, M <: AbstractMatrix, B}
+    W::M
+    b::B
+    σ::F
+    function MVE(W::M, b = true, σ::F = NNlib.softplus) where {M <: AbstractMatrix, F}
+        b = Flux.create_bias(W, b, size(W, 1))
+        return new{F, M, typeof(b)}(W, b, σ)
+    end
+end
+
+function MVE(
+        (in, out)::Pair{<:Integer, <:Integer}, σ = NNlib.softplus; init = Flux.glorot_uniform, bias = true)
+    MVE(init(out * 2, in), bias, σ)
+end
+
+Flux.@layer MVE
+
+function (a::MVE)(x::AbstractVecOrMat)
+    nout = Int(size(a.W, 1) / 2)
+    o = a.W * x .+ a.b
+    μ = o[1:nout, :]
+    s = a.σ.(o[(nout + 1):(nout * 2), :])
+    return vcat(μ, s)
+end
+
+(a::MVE)(x::AbstractArray) = reshape(a(reshape(x, size(x, 1), :)), :, size(x)[2:end]...)
diff --git a/src/losses.jl b/src/losses.jl
index 9a45246..e6dff87 100644
--- a/src/losses.jl
+++ b/src/losses.jl
@@ -123,3 +123,29 @@ function dirloss(y, α, t)
     #sum(loss .+ λₜ .* reg, dims = 2)
     sum(loss .+ λₜ .* reg)
 end
+
+"""
+    mveloss(y, μ, σ)
+
+Calculates the Mean-Variance loss, i.e. the negative log likelihood of a Normal distribution (up to an additive constant).
+This loss should be used with the MVE network type.
+
+# Arguments:
+- `y`: targets
+- `μ`: the predicted mean
+- `σ`: the predicted variance
+"""
+mveloss(y, μ, σ) = (1 / 2) * (((y - μ) .^ 2) ./ σ + log.(σ))
+
+"""
+    mveloss(y, μ, σ, β)
+
+Calculates the Mean-Variance loss weighted by `σ^β`; the weighting factor is wrapped in `ignore_derivatives`, so no gradients flow through it.
+
+# Arguments:
+- `y`: targets
+- `μ`: the predicted mean
+- `σ`: the predicted variance
+- `β`: used to increase or decrease the effect of the predicted variance on the loss
+"""
+mveloss(y, μ, σ, β) = mveloss(y, μ, σ) .* ignore_derivatives(σ) .^ β
diff --git a/src/utils.jl b/src/utils.jl
index 4312125..b972117 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -102,11 +102,18 @@ function predict(::Type{<:NIG}, m, x)
     nout = Int(size(m[end].W)[1] / 4)
     ŷ = m(x)
     γ, ν, α, β = ŷ[1:nout, :], ŷ[(nout + 1):(2 * nout), :],
-    ŷ[(2 * nout + 1):(3 * nout), :], ŷ[(3 * nout + 1):(4 * nout), :]
+        ŷ[(2 * nout + 1):(3 * nout), :], ŷ[(3 * nout + 1):(4 * nout), :]
     #return γ, uncertainty(ν, α, β), uncertainty(α, β)
     γ, ν, α, β
 end
 
+function predict(::Type{<:MVE}, m, x)
+    nout = Int(size(m[end].W)[1] / 2)
+    ŷ = m(x)
+    μ, σ = ŷ[1:nout, :], ŷ[(nout + 1):(2 * nout), :]
+    μ, σ
+end
+
 function predict(::Type{<:DIR}, m, x)
     ŷ = m(x)
     ŷ
diff --git a/test/runtests.jl b/test/runtests.jl
index 8fbcf98..7b85075 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -14,7 +14,7 @@ using Test
     @test all(≥(1), ŷ)
 end
 
-@testset "EvidentialFlux.jl - Regression" begin
+@testset "EvidentialFlux.jl - NIG Regression" begin
 
     # Creating a model and a forward pass
     ninp, nout = 3, 5
@@ -86,3 +86,70 @@ end
     myuncert = uncertainty(ν, α, β)
     @test size(myuncert) == size(myloss)
 end
+
+@testset "EvidentialFlux.jl - MVE Regression" begin
+    # Creating a model and a forward pass
+
+    ninp, nout = 3, 5
+    m = MVE(ninp => nout)
+    x = randn(Float32, 3, 10)
+    ŷ = m(x)
+    @test size(ŷ) == (2 * nout, 10)
+    # The σ all have to be positive
+    @test ŷ[6:10, :] == abs.(ŷ[6:10, :])
+
+    # Testing backward pass
+    oldW = similar(m.W)
+    oldW .= m.W
+    loss(y, ŷ) = sum(abs, y - ŷ)
+    pars = Flux.params(m)
+    y = randn(Float32, nout, 10) # Target (fake)
+    grads = Flux.gradient(pars) do
+        ŷ = m(x)
+        γ = ŷ[1:nout, :]
+        loss(y, γ)
+    end
+    # Test that we can update the weights based on gradients
+    opt = Descent(0.1)
+    Flux.Optimise.update!(opt, pars, grads)
+    @test m.W != oldW
+
+    # Testing convergence
+    ninp, nout = 3, 1
+    m = MVE(ninp => nout)
+    x = Float32.(collect(1:0.1:10))
+    x = cat(x', x' .- 10, x' .+ 5, dims = 1)
+    # y = 1 * sin.(x[1, :]) .- 3 * sin.(x[2, :]) .+ 2 * cos.(x[3, :]) .+ randn(Float32, 91)
+    y = 1 * x[1, :] .- 3 * x[2, :] .+ 2 * x[3, :] .+ randn(Float32, 91)
+    #scatterplot(x[1, :], y, width = 90, height = 30)
+    pars = Flux.params(m)
+    opt = AdamW(0.005)
+    trnlosses = zeros(Float32, 1000)
+    for i in 1:1000
+        local trnloss
+        grads = Flux.gradient(pars) do
+            ŷ = m(x)
+            μ, σ = ŷ[1, :], ŷ[2, :]
+            trnloss = sum(mveloss(y, μ, σ))
+        end
+        trnlosses[i] = trnloss
+        # Test that we can update the weights based on gradients
+        Flux.Optimise.update!(opt, pars, grads)
+        #if i % 100 == 0
+        #    println("Epoch $i, Loss: $trnloss")
+        #end
+    end
+    #scatterplot(1:1000, trnlosses, width = 80)
+    @test trnlosses[10] > trnlosses[100] > trnlosses[300]
+
+    # Test the mveloss function
+    ninp, nout = 3, 5
+    m = MVE(ninp => nout)
+    x = randn(Float32, 3, 10)
+    y = randn(Float32, nout, 10) # Target (fake)
+    ŷ = m(x)
+    μ = ŷ[1:nout, :]
+    σ = ŷ[(nout + 1):(nout * 2), :]
+    myloss = mveloss(y, μ, σ)
+    @test size(myloss) == (nout, 10)
+end
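
Not part of the patch: a minimal usage sketch of how the new pieces (the MVE layer, mveloss, and the new predict method) are meant to fit together, written in the same implicit-parameter Flux style the tests above use. The toy data, layer sizes, learning rate, and epoch count are illustrative assumptions only.

    using Flux
    using EvidentialFlux

    # Toy regression data: 3 input features, 64 observations (made up for illustration).
    x = randn(Float32, 3, 64)
    y = sum(x; dims = 1) .+ 0.1f0 .* randn(Float32, 1, 64)

    m = MVE(3 => 1)       # forward pass returns vcat(μ, σ), here a 2×64 matrix
    pars = Flux.params(m)
    opt = AdamW(0.005)

    for epoch in 1:500
        grads = Flux.gradient(pars) do
            ŷ = m(x)
            μ, σ = ŷ[1:1, :], ŷ[2:2, :]
            sum(mveloss(y, μ, σ))   # or mveloss(y, μ, σ, β) to re-weight each term by σ^β
        end
        Flux.Optimise.update!(opt, pars, grads)
    end

    # predict indexes m[end] to find the MVE layer, so wrap the single layer in a Chain.
    μ̂, σ̂ = EvidentialFlux.predict(MVE, Chain(m), x)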