feat: Implemented Mean-variance network and cleaned up the dependencies.
DoktorMike committed Jun 17, 2024
1 parent 3dfe5d7 commit 448e865
Showing 6 changed files with 162 additions and 11 deletions.
7 changes: 1 addition & 6 deletions Project.toml
@@ -4,16 +4,11 @@ authors = ["Michael Green <micke.green@gmail.com> and contributors"]
version = "1.3.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[compat]
Flux = "0.13"
julia = "1.7"
NNlib = "0.8"
SpecialFunctions = "2.1"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

2 changes: 2 additions & 0 deletions src/EvidentialFlux.jl
@@ -9,11 +9,13 @@ using SpecialFunctions
include("dense.jl")
export NIG
export DIR
export MVE

include("losses.jl")
export nigloss
export nigloss2
export dirloss
export mveloss

include("utils.jl")
export uncertainty
60 changes: 57 additions & 3 deletions src/dense.jl
@@ -36,11 +36,11 @@ struct NIG{F, M <: AbstractMatrix, B}
end

function NIG((in, out)::Pair{<:Integer, <:Integer}, σ = NNlib.softplus;
init = Flux.glorot_uniform, bias = true)
NIG(init(out * 4, in), bias, σ)
end

Flux.@functor NIG
Flux.@layer NIG

function (a::NIG)(x::AbstractVecOrMat)
nout = Int(size(a.W, 1) / 4)
@@ -100,10 +100,64 @@ function DIR((in, out)::Pair{<:Integer, <:Integer}; init = Flux.glorot_uniform,
DIR(init(out, in), bias)
end

Flux.@functor DIR
Flux.@layer DIR

function (a::DIR)(x::AbstractVecOrMat)
NNlib.softplus.(a.W * x .+ a.b) .+ 1
end

(a::DIR)(x::AbstractArray) = reshape(a(reshape(x, size(x, 1), :)), :, size(x)[2:end]...)

"""
MVE(in => out, σ=NNlib.softplus; bias=true, init=Flux.glorot_uniform)
MVE(W::AbstractMatrix, [bias, σ])
Create a fully connected layer which implements a Mean-Variance Estimation (MVE) network,
i.e. a layer that predicts the mean and variance of a Normal distribution. The forward pass
is given by:
y = W * x .+ bias
The input `x` should be a vector of length `in`, or a batch of vectors represented
as an `in × N` matrix, or any array with `size(x, 1) == in`.
The output `y` will be a vector of length `out*2`, or a batch with
`size(y) == (out*2, size(x)[2:end]...)`.
The function `σ` is applied to each row/element of `y` except the first `out` ones,
so that the predicted variances are strictly positive.
Keyword `bias=false` will switch off the trainable bias for the layer.
The initialisation of the weight matrix is `W = init(out*2, in)`, calling the function
given to keyword `init`, with default [`glorot_uniform`](@doc Flux.glorot_uniform).
The weight matrix and/or the bias vector (of length `out*2`) may also be provided explicitly.
Remember that in this case the number of rows in the weight matrix `W` MUST be a multiple of 2,
and the same holds for the length of the `bias` vector.
# Arguments:
- `(in, out)`: number of input and output neurons
- `σ`: The function used to ensure strictly positive variance outputs; defaults to the softplus function.
- `init`: The function to use to initialise the weight matrix.
- `bias`: Whether to include a trainable bias vector.
"""
struct MVE{F, M <: AbstractMatrix, B}
W::M
b::B
σ::F
function MVE(W::M, b = true, σ::F = NNlib.softplus) where {M <: AbstractMatrix, F}
b = Flux.create_bias(W, b, size(W, 1))
return new{F, M, typeof(b)}(W, b, σ)
end
end

function MVE(
(in, out)::Pair{<:Integer, <:Integer}, σ = NNlib.softplus; init = Flux.glorot_uniform, bias = true)
MVE(init(out * 2, in), bias, σ)
end

Flux.@layer MVE

function (a::MVE)(x::AbstractVecOrMat)
nout = Int(size(a.W, 1) / 2)
o = a.W * x .+ a.b
μ = o[1:nout, :]
s = a.σ.(o[(nout + 1):(nout * 2), :])
return vcat(μ, s)
end

(a::MVE)(x::AbstractArray) = reshape(a(reshape(x, size(x, 1), :)), :, size(x)[2:end]...)
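
A minimal usage sketch (not part of this commit; it assumes the `MVE` export above and a standard Flux workflow) showing how the layer stacks the unconstrained means on top of the softplus-transformed variances:

using Flux, EvidentialFlux

layer = MVE(3 => 2)          # 3 inputs, 2 outputs, so 2 * 2 = 4 output rows
x = randn(Float32, 3, 16)    # a batch of 16 samples
ŷ = layer(x)                 # size (4, 16)
μ = ŷ[1:2, :]                # predicted means (unconstrained)
σ = ŷ[3:4, :]                # predicted variances (made positive by softplus)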
26 changes: 26 additions & 0 deletions src/losses.jl
@@ -123,3 +123,29 @@ function dirloss(y, α, t)
#sum(loss .+ λₜ .* reg, dims = 2)
sum(loss .+ λₜ .* reg)
end

"""
mveloss(y, μ, σ)
Calculates the Mean-Variance loss for a Normal distribution, i.e. the negative log likelihood
up to an additive constant. The loss is computed elementwise, so reduce it (e.g. with `sum`)
before taking gradients. This loss should be used with the MVE network type.
# Arguments:
- `y`: targets
- `μ`: the predicted mean
- `σ`: the predicted variance
"""
mveloss(y, μ, σ) = (1 / 2) * (((y - μ) .^ 2) ./ σ + log.(σ))

"""
mveloss(y, μ, σ, β)
Calculates the Mean-Variance loss scaled elementwise by the predicted variance raised to the
power `β`. The scaling factor is excluded from the gradient computation via `ignore_derivatives`,
so `β` only re-weights the per-element losses; `β = 0` recovers `mveloss(y, μ, σ)`.
# Arguments:
- `y`: targets
- `μ`: the predicted mean
- `σ`: the predicted variance
- `β`: used to increase or decrease the effect of the predicted variance on the loss
"""
mveloss(y, μ, σ, β) = mveloss(y, μ, σ) .* ignore_derivatives(σ) .^ β
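
For illustration, a minimal sketch (not part of this commit; the value `β = 0.5f0` is purely an example) of how the two methods relate on the same MVE output:

using Flux, EvidentialFlux

nout = 2
m = MVE(3 => nout)
x = randn(Float32, 3, 16)
y = randn(Float32, nout, 16)             # fake targets
ŷ = m(x)
μ = ŷ[1:nout, :]
σ = ŷ[(nout + 1):(2 * nout), :]
plain = sum(mveloss(y, μ, σ))            # elementwise NLL reduced with sum
weighted = sum(mveloss(y, μ, σ, 0.5f0))  # scaled by σ^β, gradients through the scaling stopped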
9 changes: 8 additions & 1 deletion src/utils.jl
@@ -102,11 +102,18 @@ function predict(::Type{<:NIG}, m, x)
nout = Int(size(m[end].W)[1] / 4)
ŷ = m(x)
γ, ν, α, β = ŷ[1:nout, :], ŷ[(nout + 1):(2 * nout), :],
ŷ[(2 * nout + 1):(3 * nout), :], ŷ[(3 * nout + 1):(4 * nout), :]
#return γ, uncertainty(ν, α, β), uncertainty(α, β)
γ, ν, α, β
end

function predict(::Type{<:MVE}, m, x)
nout = Int(size(m[end].W)[1] / 2)
ŷ = m(x)
μ, σ = ŷ[1:nout, :], ŷ[(nout + 1):(2 * nout), :]
μ, σ
end
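
A minimal usage sketch (not part of this commit; it assumes `predict` is exported from utils.jl as for the existing NIG/DIR paths and uses Flux's `Chain`/`Dense`). Since `predict` reads `m[end].W`, the MVE head must be the last layer of an indexable model:

using Flux, EvidentialFlux

m = Chain(Dense(3 => 8, relu), MVE(8 => 1))
x = randn(Float32, 3, 32)
μ, σ = predict(MVE, m, x)    # each of size (1, 32)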

function predict(::Type{<:DIR}, m, x)
ŷ = m(x)
69 changes: 68 additions & 1 deletion test/runtests.jl
@@ -14,7 +14,7 @@ using Test
@test all(≥(1), ŷ)
end

@testset "EvidentialFlux.jl - Regression" begin
@testset "EvidentialFlux.jl - NIG Regression" begin
# Creating a model and a forward pass

ninp, nout = 3, 5
@@ -86,3 +86,70 @@ end
myuncert = uncertainty(ν, α, β)
@test size(myuncert) == size(myloss)
end

@testset "EvidentialFlux.jl - MVE Regression" begin
# Creating a model and a forward pass

ninp, nout = 3, 5
m = MVE(ninp => nout)
x = randn(Float32, 3, 10)
ŷ = m(x)
@test size(ŷ) == (2 * nout, 10)
# The σ all have to be positive
@test ŷ[6:10, :] == abs.(ŷ[6:10, :])

# Testing backward pass
oldW = similar(m.W)
oldW .= m.W
loss(y, ŷ) = sum(abs, y - ŷ)
pars = Flux.params(m)
y = randn(Float32, nout, 10) # Target (fake)
grads = Flux.gradient(pars) do
ŷ = m(x)
γ = ŷ[1:nout, :]
loss(y, γ)
end
# Test that we can update the weights based on gradients
opt = Descent(0.1)
Flux.Optimise.update!(opt, pars, grads)
@test m.W != oldW

# Testing convergence
ninp, nout = 3, 1
m = MVE(ninp => nout)
x = Float32.(collect(1:0.1:10))
x = cat(x', x' .- 10, x' .+ 5, dims = 1)
# y = 1 * sin.(x[1, :]) .- 3 * sin.(x[2, :]) .+ 2 * cos.(x[3, :]) .+ randn(Float32, 91)
y = 1 * x[1, :] .- 3 * x[2, :] .+ 2 * x[3, :] .+ randn(Float32, 91)
#scatterplot(x[1, :], y, width = 90, height = 30)
pars = Flux.params(m)
opt = AdamW(0.005)
trnlosses = zeros(Float32, 1000)
for i in 1:1000
local trnloss
grads = Flux.gradient(pars) do
ŷ = m(x)
μ, σ = ŷ[1, :], ŷ[2, :]
trnloss = sum(mveloss(y, μ, σ))
end
trnlosses[i] = trnloss
# Test that we can update the weights based on gradients
Flux.Optimise.update!(opt, pars, grads)
#if i % 100 == 0
# println("Epoch $i, Loss: $trnloss")
#end
end
#scatterplot(1:1000, trnlosses, width = 80)
@test trnlosses[10] > trnlosses[100] > trnlosses[300]

# Test the mveloss function
ninp, nout = 3, 5
m = MVE(ninp => nout)
x = randn(Float32, 3, 10)
y = randn(Float32, nout, 10) # Target (fake)
ŷ = m(x)
μ = ŷ[1:nout, :]
σ = ŷ[(nout + 1):(nout * 2), :]
myloss = mveloss(y, μ, σ)
@test size(myloss) == (nout, 10)
end
