Conversion from Flux networks #7

Merged (2 commits, Aug 16, 2023)

17 changes: 17 additions & 0 deletions src/Architecture/ActivationFunction.jl
@@ -56,3 +56,20 @@ Hyperbolic tangent activation.
struct Tanh <: ActivationFunction end

(::Tanh)(x) = tanh.(x)

# constant instances of each activation function
const _id = Id()
const _relu = ReLU()
const _sigmoid = Sigmoid()
const _tanh = Tanh()

function load_Flux_activations()
    return quote
        activations_Flux = Dict(Flux.identity => _id,
                                _id => Flux.identity,
                                Flux.relu => _relu,
                                _relu => Flux.relu,
                                Flux.sigmoid => _sigmoid,
                                _sigmoid => Flux.sigmoid)
    end
end
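
The dictionary above is deliberately bidirectional, so one table serves conversions in both directions; anything missing from it makes the `convert` methods further down fall through to `nothing` and raise an `ArgumentError`. A minimal lookup sketch, assuming Flux has been imported and that `activations_Flux` ends up as a non-exported global of the `Architecture` module, as the `eval` in `init.jl` suggests; this snippet is illustrative and not part of the diff:

import Flux
using ControllerFormats
# Flux activation -> shared package instance, and back again
act = ControllerFormats.Architecture.activations_Flux[Flux.relu]  # the shared ReLU() instance
σ = ControllerFormats.Architecture.activations_Flux[act]          # Flux.relu
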
4 changes: 4 additions & 0 deletions src/Architecture/Architecture.jl
@@ -5,6 +5,8 @@ Module containing data structures to represent controllers.
"""
module Architecture

using Requires

export AbstractNeuralNetwork, AbstractLayerOp,
FeedforwardNetwork, DenseLayerOp,
dim_in, dim_out, dim,
@@ -16,4 +18,6 @@ include("ActivationFunction.jl")
include("DenseLayerOp.jl")
include("FeedforwardNetwork.jl")

include("init.jl")

end # module
24 changes: 24 additions & 0 deletions src/Architecture/DenseLayerOp.jl
@@ -8,6 +8,10 @@ A dense layer operation is an affine map followed by an activation function.
- `weights` -- weight matrix
- `bias` -- bias vector
- `activation` -- activation function

### Notes

Conversion from a `Flux.Dense` is supported.
"""
struct DenseLayerOp{F,M<:AbstractMatrix,B} <: AbstractLayerOp
weights::M
@@ -42,3 +46,23 @@ end
dim_in(L::DenseLayerOp) = size(L.weights, 2)

dim_out(L::DenseLayerOp) = length(L.bias)

function load_Flux_convert_layer()
    return quote
        function Base.convert(::Type{DenseLayerOp}, layer::Flux.Dense)
            act = get(activations_Flux, layer.σ, nothing)
            if isnothing(act)
                throw(ArgumentError("unsupported activation function $(layer.σ)"))
            end
            return DenseLayerOp(layer.weight, layer.bias, act)
        end

        function Base.convert(::Type{Flux.Dense}, layer::DenseLayerOp)
            act = get(activations_Flux, layer.activation, nothing)
            if isnothing(act)
                throw(ArgumentError("unsupported activation function $(layer.activation)"))
            end
            return Flux.Dense(layer.weights, layer.bias, act)
        end
    end
end
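
With Flux imported, the two `convert` methods above let a single dense layer round-trip between the two representations. A brief usage sketch under that assumption (layer sizes and activation chosen for illustration):

import Flux
using ControllerFormats
l = Flux.Dense(2, 3, Flux.relu)    # Flux layer with a supported activation
op = convert(DenseLayerOp, l)      # weights and bias are reused; activation becomes ReLU()
l_back = convert(Flux.Dense, op)   # back to a Flux.Dense
op.weights == l.weight             # true
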
16 changes: 16 additions & 0 deletions src/Architecture/FeedforwardNetwork.jl
@@ -12,6 +12,8 @@ operations.

The field `layers` contains the layer operations, so the number of layers is
`length(layers) + 1`.

Conversion from a `Flux.Chain` is supported.
"""
struct FeedforwardNetwork{L} <: AbstractNeuralNetwork
layers::L
@@ -47,3 +49,17 @@ end
dim_in(N::FeedforwardNetwork) = dim_in(first(N.layers))

dim_out(N::FeedforwardNetwork) = dim_out(last(N.layers))

function load_Flux_convert_network()
    return quote
        function Base.convert(::Type{FeedforwardNetwork}, chain::Flux.Chain)
            layers = [convert(DenseLayerOp, layer) for layer in chain.layers]
            return FeedforwardNetwork(layers)
        end

        function Base.convert(::Type{Flux.Chain}, net::FeedforwardNetwork)
            layers = [convert(Flux.Dense, layer) for layer in net.layers]
            return Flux.Chain(layers)
        end
    end
end
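
The network-level conversion just maps the layer-level conversion over `chain.layers`, so any chain built purely from supported `Flux.Dense` layers round-trips. A short sketch, again assuming Flux is loaded (shapes chosen for illustration):

import Flux
using ControllerFormats
c = Flux.Chain(Flux.Dense(2, 3, Flux.relu), Flux.Dense(3, 1))
net = convert(FeedforwardNetwork, c)   # layers become a vector of DenseLayerOp
c_back = convert(Flux.Chain, net)      # and back to a Flux.Chain
dim_in(net), dim_out(net)              # (2, 1)
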
8 changes: 8 additions & 0 deletions src/Architecture/init.jl
@@ -0,0 +1,8 @@
# optional dependencies
function __init__()
    @require Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" begin
        eval(load_Flux_activations())
        eval(load_Flux_convert_layer())
        eval(load_Flux_convert_network())
    end
end
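
Because the bridge sits behind a `Requires.@require` hook, none of the Flux-related methods exist until Flux is imported in the same session; only then are the three quoted blocks evaluated inside `Architecture`. A sketch of the intended loading order (illustrative, not package documentation):

using ControllerFormats   # __init__ registers the @require hook
# at this point the Flux conversions are not yet defined
import Flux               # triggers the hook; the convert methods above become available
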
28 changes: 11 additions & 17 deletions src/FileFormats/available_activations.jl
@@ -1,22 +1,16 @@
-# always use the same instance of each activation function
-const _id = Id()
-const _relu = ReLU()
-const _sigmoid = Sigmoid()
-const _tanh = Tanh()
-
 const available_activations = Dict(
     # Id
-    "Id" => _id,
-    "linear" => _id,
-    "Linear" => _id,
-    "Affine" => _id,
+    "Id" => Architecture._id,
+    "linear" => Architecture._id,
+    "Linear" => Architecture._id,
+    "Affine" => Architecture._id,
     # ReLU
-    "relu" => _relu,
-    "ReLU" => _relu,
+    "relu" => Architecture._relu,
+    "ReLU" => Architecture._relu,
     # Sigmoid
-    "sigmoid" => _sigmoid,
-    "Sigmoid" => _sigmoid,
-    "σ" => _sigmoid,
+    "sigmoid" => Architecture._sigmoid,
+    "Sigmoid" => Architecture._sigmoid,
+    "σ" => Architecture._sigmoid,
     # Tanh
-    "tanh" => _tanh,
-    "Tanh" => _tanh)
+    "tanh" => Architecture._tanh,
+    "Tanh" => Architecture._tanh)
44 changes: 44 additions & 0 deletions test/Architecture/Flux.jl
@@ -0,0 +1,44 @@
import Flux

l1 = Flux.Dense(1, 2, Flux.relu)
l1.weight .= 1, 2
l1.bias .= 3, 4

l2 = Flux.Dense(2, 3, Flux.sigmoid)
l2.weight .= [1 2; 3 4; 5 6]

l3 = Flux.Dense(3, 1)
l3.weight .= [1 2 3;]

l_unsupported = Flux.Dense(1 => 1, Flux.trelu)

c = Flux.Chain(l1, l2, l3)

activations = [ReLU(), Sigmoid(), Id()]

# `==` is not defined for Flux types
function compare_Flux_layer(l1, l2)
    return l1.weight == l2.weight && l1.bias == l2.bias && l1.σ == l2.σ
end

# layer conversion
for (i, l) in enumerate(c.layers)
    op = convert(DenseLayerOp, l)
    @test op.weights == l.weight
    @test op.bias == l.bias
    @test op.activation == activations[i]

    l_back = convert(Flux.Dense, op)
    @test compare_Flux_layer(l, l_back)
end
@test_throws ArgumentError convert(DenseLayerOp, l_unsupported)

# network conversion
net = convert(FeedforwardNetwork, c)
c_back = convert(Flux.Chain, net)
@test length(net.layers) == length(c)
for (i, l) in enumerate(c.layers)
    @test net.layers[i] == convert(DenseLayerOp, l)

    @test compare_Flux_layer(l, c_back.layers[i])
end
2 changes: 2 additions & 0 deletions test/Project.toml
@@ -1,11 +1,13 @@
[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
ONNX = "d0dd6a25-fac6-55c0-abf7-829e0c774d20"
ReachabilityBase = "379f33d0-9447-4353-bd03-d664070e549f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"

[compat]
Flux = "0.13 - 0.14"
MAT = "0.10"
ONNX = "0.2"
ReachabilityBase = "0.1.1"
5 changes: 4 additions & 1 deletion test/runtests.jl
@@ -2,7 +2,7 @@ using Test, ControllerFormats

using ControllerFormats: dim

-import MAT, ONNX, YAML
+import Flux, MAT, ONNX, YAML

@testset "Architecture" begin
@testset "DenseLayerOp" begin
@@ -11,6 +11,9 @@ import MAT, ONNX, YAML
@testset "FeedforwardNetwork" begin
include("Architecture/FeedforwardNetwork.jl")
end
@testset "Flux bridge" begin
include("Architecture/Flux.jl")
end
end

@testset "FileFormats" begin