feat: Added several tests for correctness.
DoktorMike committed May 5, 2022
1 parent c245be7 commit b4f057d
Showing 2 changed files with 97 additions and 18 deletions.
50 changes: 39 additions & 11 deletions src/dense.jl
@@ -1,30 +1,58 @@
"""
NIG(in => out, σ=NNlib.softplus; bias=true, init=Flux.glorot_uniform)
NIG(W::AbstractMatrix, [bias, σ])
Create a fully connected layer which implements the NormalInverseGamma evidential distribution,
whose forward pass is given by:
y = W * x .+ bias
The input `x` should be a vector of length `in`, or batch of vectors represented
as an `in × N` matrix, or any array with `size(x,1) == in`.
The output `y` will be a vector of length `out*4`, or a batch with
`size(y) == (out*4, size(x)[2:end]...)`
The function `σ` is applied element-wise to every row of `y` except the first `out` rows; the rows corresponding to `α` additionally have 1 added so that `α ≥ 1`.
Keyword `bias=false` will switch off trainable bias for the layer.
The initialisation of the weight matrix is `W = init(out*4, in)`, calling the function
given to keyword `init`, with default [`glorot_uniform`](@doc Flux.glorot_uniform).
The weight matrix and/or the bias vector (of length `out*4`) may also be provided explicitly.
Remember that in this case the number of rows in the weight matrix `W` MUST be a multiple of 4.
The same holds true for the `bias` vector.
# Arguments:
- `(in, out)`: number of input and output neurons
- `σ`: the function used to constrain the outputs to be positive; defaults to the softplus function.
- `init`: The function to use to initialise the weight matrix.
- `bias`: Whether to include a trainable bias vector.
"""
struct NIG{F,M<:AbstractMatrix,B}
W::M
b::B
σ::F
function NIG(W::M, b = true, σ::F = identity) where {M<:AbstractMatrix,F}
function NIG(W::M, b = true, σ::F = NNlib.softplus) where {M<:AbstractMatrix,F}
b = Flux.create_bias(W, b, size(W, 1))
return new{F,M,typeof(b)}(W, b, σ)
end
end

function NIG((in, out)::Pair{<:Integer,<:Integer}, σ = identity;
function NIG((in, out)::Pair{<:Integer,<:Integer}, σ = NNlib.softplus;
init = Flux.glorot_uniform, bias = true)
NIG(init(out * 4, in), bias, σ)
end

Flux.@functor NIG

function (a::NIG)(x::AbstractVecOrMat)
o = a.σ.(a.W * x .+ a.b)
μ = @view o[1:4, :]
ν = @view o[5:8, :]
ν .= NNlib.softplus.(ν)
α = @view o[9:12, :]
α .= NNlib.softplus.(α) .+ 1
β = @view o[13:16, :]
β .= NNlib.softplus.(β)
return o
nout = Int(size(a.W, 1) / 4)
o = a.W * x .+ a.b
γ = o[1:nout, :]
ν = o[(nout+1):(nout*2), :]
ν = a.σ.(ν)
α = o[(nout*2+1):(nout*3), :]
α = a.σ.(α) .+ 1
β = o[(nout*3+1):(nout*4), :]
β = a.σ.(β)
return vcat(γ, ν, α, β)
end

(a::NIG)(x::AbstractArray) = reshape(a(reshape(x, size(x, 1), :)), :, size(x)[2:end]...)
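
For orientation, a minimal usage sketch of the layer defined above, assuming the package exports `NIG` as the tests below do; shapes follow the docstring:

using EvidentialFlux

# A layer with 3 inputs and 5 evidential outputs returns 4*5 = 20 rows per sample:
# the first 5 rows are γ (unconstrained), followed by ν, α and β (made positive by σ).
m = NIG(3 => 5)
x = randn(Float32, 3, 10)   # batch of 10 samples
ŷ = m(x)                    # size(ŷ) == (20, 10)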
65 changes: 58 additions & 7 deletions test/runtests.jl
@@ -1,12 +1,63 @@
using EvidentialFlux
using Flux
using Test

@testset "EvidentialFlux.jl" begin
# Write your tests here.
m = NIG(3 => 4)
x = randn(Float32, 3, 10)
y = m(x)
@test size(y) == (16, 10)
@test y[5:16, :] == abs.(y[5:16, :])
#@test y[9:12, :] .> 1
# Creating a model and a forward pass

ninp, nout = 3, 5
m = NIG(ninp => nout)
x = randn(Float32, 3, 10)
ŷ = m(x)
@test size(ŷ) == (20, 10)
# The ν, α, and β all have to be positive
@test ŷ[6:20, :] == abs.(ŷ[6:20, :])
# The α all have to be ≥ 1
@test all(≥(1), ŷ[11:15, :])

# Testing backward pass
oldW = similar(m.W)
oldW .= m.W
loss(y, ŷ) = sum(abs, y - ŷ)
pars = Flux.params(m)
y = randn(Float32, nout, 10) # Target (fake)
grads = Flux.gradient(pars) do
ŷ = m(x)
γ = ŷ[1:nout, :]
loss(y, γ)
end
# Test that we can update the weights based on gradients
opt = Descent(0.1)
Flux.Optimise.update!(opt, pars, grads)
@test m.W != oldW

# Testing convergence
ninp, nout = 3, 1
m = NIG(ninp => nout)
x = Float32.(collect(1:0.1:10))
x = cat(x', x' .- 10, x' .+ 5, dims = 1)
# y = 1 * sin.(x[1, :]) .- 3 * sin.(x[2, :]) .+ 2 * cos.(x[3, :]) .+ randn(Float32, 91)
y = 1 * x[1, :] .- 3 * x[2, :] .+ 2 * x[3, :] .+ randn(Float32, 91)
#scatterplot(x[1, :], y, width = 90, height = 30)
#loss(y, ŷ) = sum(abs, y - ŷ)
pars = Flux.params(m)
opt = ADAMW(0.005)
trnlosses = zeros(Float32, 1000)
for i in 1:1000
local trnloss
grads = Flux.gradient(pars) do
ŷ = m(x)
γ = ŷ[1, :]
trnloss = loss(y, γ)
end
trnlosses[i] = trnloss
# Test that we can update the weights based on gradients
Flux.Optimise.update!(opt, pars, grads)
#if i % 100 == 0
# println("Epoch $i, Loss: $trnloss")
#end
end
#scatterplot(1:1000, trnlosses)
@test trnlosses[10] > trnlosses[100] > trnlosses[300]

end
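
A trained layer's stacked output can be split back into the four NIG parameter blocks; a sketch following the vcat(γ, ν, α, β) ordering of the forward pass in src/dense.jl:

using EvidentialFlux

nout = 5
m = NIG(3 => nout)
ŷ = m(randn(Float32, 3, 10))
γ = ŷ[1:nout, :]                 # unconstrained
ν = ŷ[(nout+1):(2*nout), :]      # positive (σ applied)
α = ŷ[(2*nout+1):(3*nout), :]    # ≥ 1 (σ applied, then 1 added)
β = ŷ[(3*nout+1):(4*nout), :]    # positive (σ applied)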
