feat: Added a predict function that dispatches on the last layer of a Flux Chain.
DoktorMike committed Jun 17, 2022
1 parent 116d328 commit f81861a
Showing 5 changed files with 43 additions and 5 deletions.
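Taken together, the changes below make the package-level `predict` usable directly on a `Chain` whose final layer is an `NIG` head. A minimal usage sketch, assuming an `NIG(in => out)` constructor form and toy layer sizes that are not shown in this diff:

```julia
using Flux, EvidentialFlux

# Assumed constructor form: an NIG head emitting 4 outputs (γ, ν, α, β) per target.
m = Chain(Dense(1, 16, relu), NIG(16 => 1))

x = Float32.(randn(1, 100))      # 1 feature × 100 samples (Flux column convention)
γ, ν, α, β = predict(m, x)       # dispatches on typeof(m.layers[end])
eu = uncertainty(ν, α, β)        # epistemic uncertainty
au = uncertainty(α, β)           # aleatoric uncertainty
```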
6 changes: 6 additions & 0 deletions Project.toml
@@ -4,9 +4,15 @@ authors = ["Michael Green <micke.green@gmail.com> and contributors"]
version = "0.1.0"

[deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
MLFlowClient = "64a0f543-368b-4a9a-827a-e71edb2a0b83"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Wandb = "ad70616a-06c9-5745-b1f1-6a5f42545108"

[compat]
julia = "1"
7 changes: 4 additions & 3 deletions examples/regression.jl
@@ -23,7 +23,7 @@ Predicts the output of the model m on the input x.
function predict(m, x)
ŷ = m(x)
γ, ν, α, β = ŷ[1, :], ŷ[2, :], ŷ[3, :], ŷ[4, :]
γ, uncertainty(ν, α, β)
(pred=γ, eu=uncertainty(ν, α, β), au=uncertainty(α, β))
end

mae(y, ŷ) = Statistics.mean(abs.(y - ŷ))
@@ -52,7 +52,7 @@ end
# The convergence plot shows the loss function converging to a local minimum
scatterplot(1:epochs, trnlosses, width = 80)
# And the MAE corresponds to the noise we added in the target
ŷ, u = predict(m, x')
ŷ, u, au = predict(m, x')
println("MAE: $(mae(y, ŷ))")

# Correlation plot confirms the fit
@@ -69,8 +69,9 @@ scatterplot!(p, x, y, marker = :x, color = :blue)
## Out of sample predictions

x = Float32.(collect(0:0.1:3π));
ŷ, u = predict(m, x');
ŷ, u, au = predict(m, x');

p = scatterplot(x, sin.(x), width = 80, height = 30, marker = "o");
scatterplot!(p, x, ŷ, color = :red, marker = "x");
scatterplot!(p, x, u)
scatterplot!(p, x, au)
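In this example `u` and `au` are the epistemic and aleatoric uncertainties of the NIG head. Only the aleatoric formula appears in this commit (in `src/utils.jl` below); the epistemic form given here is the usual Normal Inverse Gamma posterior variance of the mean and is an assumption about what the three-argument `uncertainty` method computes:

$$
\text{aleatoric: } \mathbb{E}[\sigma^2] = \frac{\beta}{\alpha - 1},
\qquad
\text{epistemic: } \operatorname{Var}[\mu] = \frac{\beta}{\nu\,(\alpha - 1)}
$$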
3 changes: 1 addition & 2 deletions src/EvidentialFlux.jl
@@ -6,8 +6,6 @@ using SpecialFunctions

# Write your package code here.

hello() = println("Hello from EvidentialFlux!")

include("dense.jl")
export NIG

@@ -16,5 +14,6 @@ export nigloss

include("utils.jl")
export uncertainty
export predict

end
13 changes: 13 additions & 0 deletions src/dense.jl
@@ -57,3 +57,16 @@ end

(a::NIG)(x::AbstractArray) = reshape(a(reshape(x, size(x, 1), :)), :, size(x)[2:end]...)

#function predict(m::NIG, x::AbstractVecOrMat)
# nout = Int(size(m.W, 1) / 4)
# o = m.W * x .+ m.b
# γ = o[1:nout, :]
# ν = o[(nout+1):(nout*2), :]
# ν = m.σ.(ν)
# α = o[(nout*2+1):(nout*3), :]
# α = m.σ.(α) .+ 1
# β = o[(nout*3+1):(nout*4), :]
# β = m.σ.(β)
# return γ, uncertainty(ν, α, β), uncertainty(α, β)
#end

19 changes: 19 additions & 0 deletions src/utils.jl
@@ -22,3 +22,22 @@ Calculates the aleatoric uncertainty of the predictions from the Normal Inverse
"""
uncertainty(α, β) = @. β / (α - 1)

"""
    predict(m, x)

Returns the NIG parameters γ, ν, α and β of the predictions, from which the epistemic and aleatoric uncertainties can be computed.

# Arguments:
- `m`: the model, whose last layer must be a Normal Inverse Gamma (NIG) layer
- `x`: the input data, given as a vector or matrix
"""
predict(m, x) = predict(typeof(m.layers[end]), m, x)

function predict(::Type{<:NIG}, m, x)
#(pred = γ, eu = uncertainty(ν, α, β), au = uncertainty(α, β))
nout = Int(size(m[end].W)[1] / 4)
ŷ = m(x)
γ, ν, α, β = ŷ[1:nout, :], ŷ[(nout+1):(2*nout), :], ŷ[(2*nout+1):(3*nout), :], ŷ[(3*nout+1):(4*nout), :]
#return γ, uncertainty(ν, α, β), uncertainty(α, β)
γ, ν, α, β
end
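
Because the entry point routes on `typeof(m.layers[end])`, other output heads can opt into the same `predict` API by adding a method for their own layer type. A minimal sketch, assuming a purely hypothetical `MyHead` layer that is not part of this package:

```julia
using Flux, EvidentialFlux

# Hypothetical output head; only here to illustrate the dispatch hook.
struct MyHead
    W::Matrix{Float32}
    b::Vector{Float32}
end
(h::MyHead)(x) = h.W * x .+ h.b

# New method on the generic function exported by EvidentialFlux.
function EvidentialFlux.predict(::Type{<:MyHead}, m, x)
    # Interpret the head's raw outputs however that layer requires.
    return m(x)
end

m = Chain(Dense(2, 8, relu), MyHead(randn(Float32, 1, 8), zeros(Float32, 1)))
predict(m, rand(Float32, 2, 5))   # routes to the MyHead method
```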
