Merge pull request #7 from JuliaReach/schillic/flux
Conversion from Flux networks
schillic authored Aug 16, 2023
2 parents b5c4077 + 1fbd2cf commit 92d302c
Showing 9 changed files with 130 additions and 18 deletions.
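In short, after this change a `Flux.Chain` of `Dense` layers can be converted to a `FeedforwardNetwork` and back via `convert`. A minimal usage sketch, not part of the diff; the layer sizes are arbitrary, and loading Flux is what activates the conversions:

```julia
import Flux
using ControllerFormats

# a small Flux model with arbitrary layer sizes (illustrative only)
chain = Flux.Chain(Flux.Dense(2 => 3, Flux.relu), Flux.Dense(3 => 1))

# Flux -> ControllerFormats and back
net = convert(FeedforwardNetwork, chain)
chain2 = convert(Flux.Chain, net)
```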
17 changes: 17 additions & 0 deletions src/Architecture/ActivationFunction.jl
@@ -56,3 +56,20 @@ Hyperbolic tangent activation.
struct Tanh <: ActivationFunction end

(::Tanh)(x) = tanh.(x)

# constant instances of each activation function
const _id = Id()
const _relu = ReLU()
const _sigmoid = Sigmoid()
const _tanh = Tanh()

function load_Flux_activations()
    return quote
        activations_Flux = Dict(Flux.identity => _id,
                                _id => Flux.identity,
                                Flux.relu => _relu,
                                _relu => Flux.relu,
                                Flux.sigmoid => _sigmoid,
                                _sigmoid => Flux.sigmoid)
    end
end
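The table is deliberately bidirectional, so a single dictionary serves both conversion directions. A small illustration using the internal, non-exported names from this diff (`activations_Flux`, `_relu`), assuming Flux is loaded:

```julia
import Flux
using ControllerFormats
const Arch = ControllerFormats.Architecture  # internal module, shown for illustration

# Flux activation -> shared activation instance
Arch.activations_Flux[Flux.relu] === Arch._relu   # true

# shared activation instance -> Flux activation
Arch.activations_Flux[Arch._relu] === Flux.relu   # true
```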
4 changes: 4 additions & 0 deletions src/Architecture/Architecture.jl
@@ -5,6 +5,8 @@ Module containing data structures to represent controllers.
"""
module Architecture

using Requires

export AbstractNeuralNetwork, AbstractLayerOp,
       FeedforwardNetwork, DenseLayerOp,
       dim_in, dim_out, dim,
@@ -16,4 +18,6 @@ include("ActivationFunction.jl")
include("DenseLayerOp.jl")
include("FeedforwardNetwork.jl")

include("init.jl")

end # module
24 changes: 24 additions & 0 deletions src/Architecture/DenseLayerOp.jl
@@ -8,6 +8,10 @@ A dense layer operation is an affine map followed by an activation function.
- `weights` -- weight matrix
- `bias` -- bias vector
- `activation` -- activation function

### Notes

Conversion from a `Flux.Dense` is supported.
"""
struct DenseLayerOp{F,M<:AbstractMatrix,B} <: AbstractLayerOp
weights::M
@@ -42,3 +46,23 @@ end
dim_in(L::DenseLayerOp) = size(L.weights, 2)

dim_out(L::DenseLayerOp) = length(L.bias)

function load_Flux_convert_layer()
    return quote
        function Base.convert(::Type{DenseLayerOp}, layer::Flux.Dense)
            act = get(activations_Flux, layer.σ, nothing)
            if isnothing(act)
                throw(ArgumentError("unsupported activation function $(layer.σ)"))
            end
            return DenseLayerOp(layer.weight, layer.bias, act)
        end

        function Base.convert(::Type{Flux.Dense}, layer::DenseLayerOp)
            act = get(activations_Flux, layer.activation, nothing)
            if isnothing(act)
                throw(ArgumentError("unsupported activation function $(layer.activation)"))
            end
            return Flux.Dense(layer.weights, layer.bias, act)
        end
    end
end
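A hedged round-trip example at the layer level (size and activation chosen arbitrarily); unsupported activations such as `Flux.trelu` raise an `ArgumentError`:

```julia
import Flux
using ControllerFormats

d = Flux.Dense(2 => 3, Flux.sigmoid)         # illustrative Flux layer
op = convert(DenseLayerOp, d)
op.activation == Sigmoid()                   # true: the shared Sigmoid() instance is used

d2 = convert(Flux.Dense, op)
d2.weight == d.weight && d2.bias == d.bias   # true

# convert(DenseLayerOp, Flux.Dense(1 => 1, Flux.trelu))  # ArgumentError: unsupported activation
```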
16 changes: 16 additions & 0 deletions src/Architecture/FeedforwardNetwork.jl
@@ -12,6 +12,8 @@ operations.
The field `layers` contains the layer operations, so the number of layers is
`length(layers) + 1`.

Conversion from a `Flux.Chain` is supported.
"""
struct FeedforwardNetwork{L} <: AbstractNeuralNetwork
layers::L
@@ -47,3 +49,17 @@ end
dim_in(N::FeedforwardNetwork) = dim_in(first(N.layers))

dim_out(N::FeedforwardNetwork) = dim_out(last(N.layers))

function load_Flux_convert_network()
    return quote
        function Base.convert(::Type{FeedforwardNetwork}, chain::Flux.Chain)
            layers = [convert(DenseLayerOp, layer) for layer in chain.layers]
            return FeedforwardNetwork(layers)
        end

        function Base.convert(::Type{Flux.Chain}, net::FeedforwardNetwork)
            layers = [convert(Flux.Dense, layer) for layer in net.layers]
            return Flux.Chain(layers)
        end
    end
end
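Note that the network conversion only covers chains of `Flux.Dense` layers; any other layer type has no `DenseLayerOp` conversion, so the comprehension fails with a `MethodError`. A sketch with illustrative layers:

```julia
import Flux
using ControllerFormats

c_ok = Flux.Chain(Flux.Dense(4 => 8, Flux.relu), Flux.Dense(8 => 2))
net = convert(FeedforwardNetwork, c_ok)
dim_in(net), dim_out(net)              # (4, 2)

c_conv = Flux.Chain(Flux.Conv((3, 3), 1 => 2))
# convert(FeedforwardNetwork, c_conv)  # MethodError: Conv layers have no DenseLayerOp conversion
```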
8 changes: 8 additions & 0 deletions src/Architecture/init.jl
@@ -0,0 +1,8 @@
# optional dependencies
function __init__()
    @require Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" begin
        eval(load_Flux_activations())
        eval(load_Flux_convert_layer())
        eval(load_Flux_convert_network())
    end
end
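With `Requires.@require`, Flux stays an optional dependency: the conversion methods above are only evaluated once Flux is loaded in the user's session (or immediately during package initialization, if Flux was loaded first). A usage sketch:

```julia
using ControllerFormats
# at this point no Flux-related `convert` methods exist

import Flux
# loading Flux triggers the @require block, which evaluates the quoted
# definitions and makes the Flux <-> ControllerFormats conversions available
```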
28 changes: 11 additions & 17 deletions src/FileFormats/available_activations.jl
@@ -1,22 +1,16 @@
-# always use the same instance of each activation function
-const _id = Id()
-const _relu = ReLU()
-const _sigmoid = Sigmoid()
-const _tanh = Tanh()
-
 const available_activations = Dict(
     # Id
-    "Id" => _id,
-    "linear" => _id,
-    "Linear" => _id,
-    "Affine" => _id,
+    "Id" => Architecture._id,
+    "linear" => Architecture._id,
+    "Linear" => Architecture._id,
+    "Affine" => Architecture._id,
     # ReLU
-    "relu" => _relu,
-    "ReLU" => _relu,
+    "relu" => Architecture._relu,
+    "ReLU" => Architecture._relu,
     # Sigmoid
-    "sigmoid" => _sigmoid,
-    "Sigmoid" => _sigmoid,
-    "σ" => _sigmoid,
+    "sigmoid" => Architecture._sigmoid,
+    "Sigmoid" => Architecture._sigmoid,
+    "σ" => Architecture._sigmoid,
     # Tanh
-    "tanh" => _tanh,
-    "Tanh" => _tanh)
+    "tanh" => Architecture._tanh,
+    "Tanh" => Architecture._tanh)
44 changes: 44 additions & 0 deletions test/Architecture/Flux.jl
@@ -0,0 +1,44 @@
import Flux

l1 = Flux.Dense(1, 2, Flux.relu)
l1.weight .= 1, 2
l1.bias .= 3, 4

l2 = Flux.Dense(2, 3, Flux.sigmoid)
l2.weight .= [1 2; 3 4; 5 6]

l3 = Flux.Dense(3, 1)
l3.weight .= [1 2 3;]

l_unsupported = Flux.Dense(1 => 1, Flux.trelu)

c = Flux.Chain(l1, l2, l3)

activations = [ReLU(), Sigmoid(), Id()]

# `==` is not defined for Flux types
function compare_Flux_layer(l1, l2)
    return l1.weight == l2.weight && l1.bias == l2.bias && l1.σ == l2.σ
end

# layer conversion
for (i, l) in enumerate(c.layers)
    op = convert(DenseLayerOp, l)
    @test op.weights == l.weight
    @test op.bias == l.bias
    @test op.activation == activations[i]

    l_back = convert(Flux.Dense, op)
    @test compare_Flux_layer(l, l_back)
end
@test_throws ArgumentError convert(DenseLayerOp, l_unsupported)

# network conversion
net = convert(FeedforwardNetwork, c)
c_back = convert(Flux.Chain, net)
@test length(net.layers) == length(c)
for (i, l) in enumerate(c.layers)
    @test net.layers[i] == convert(DenseLayerOp, l)

    @test compare_Flux_layer(l, c_back.layers[i])
end
2 changes: 2 additions & 0 deletions test/Project.toml
@@ -1,11 +1,13 @@
[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
ONNX = "d0dd6a25-fac6-55c0-abf7-829e0c774d20"
ReachabilityBase = "379f33d0-9447-4353-bd03-d664070e549f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"

[compat]
Flux = "0.13 - 0.14"
MAT = "0.10"
ONNX = "0.2"
ReachabilityBase = "0.1.1 - 0.2"
5 changes: 4 additions & 1 deletion test/runtests.jl
@@ -2,7 +2,7 @@ using Test, ControllerFormats

using ControllerFormats: dim

-import MAT, ONNX, YAML
+import Flux, MAT, ONNX, YAML

@testset "Architecture" begin
@testset "DenseLayerOp" begin
@@ -11,6 +11,9 @@ import MAT, ONNX, YAML
@testset "FeedforwardNetwork" begin
include("Architecture/FeedforwardNetwork.jl")
end
@testset "Flux bridge" begin
include("Architecture/Flux.jl")
end
end

@testset "FileFormats" begin
