From 69143df13dd5900d3f3364d2134e6c4f2f2e514e Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 12 Jan 2021 11:00:04 -0600 Subject: [PATCH 01/15] Add initial implementation of Parallel --- src/Flux.jl | 10 ++++++---- src/layers/basic.jl | 45 ++++++++++++++++++++++++++++++++++++++++++++ test/layers/basic.jl | 12 ++++++++++++ test/outputsize.jl | 3 +++ 4 files changed, 66 insertions(+), 4 deletions(-) diff --git a/src/Flux.jl b/src/Flux.jl index b7851138d3..ff067803be 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -11,10 +11,12 @@ using Zygote: Params, @adjoint, gradient, pullback, @nograd export gradient -export Chain, Dense, Maxout, RNN, LSTM, GRU, SamePad, Conv, CrossCor, ConvTranspose, - AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, - MeanPool, flatten, DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, - InstanceNorm, GroupNorm, SkipConnection, params, fmap, cpu, gpu, f32, f64, +export Chain, Dense, Maxout, SkipConnection, Parallel, flatten, + RNN, LSTM, GRU, + SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv, + AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, + Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm, + params, fmap, cpu, gpu, f32, f64, testmode!, trainmode! include("optimise/Optimise.jl") diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 5a37bcb504..1d7b1e26ce 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -253,3 +253,48 @@ end function Base.show(io::IO, b::SkipConnection) print(io, "SkipConnection(", b.layers, ", ", b.connection, ")") end + +""" + Parallel(connection, layers...) + +Create a new 'Parallel' layer that passes a single input array each path in +`layers`, combining the output of each path with `connection`. +`connection` should be a reducible operator (i.e. it can be passed to `Base.reduce`). + +# Example +```jldoctest +julia> model = Chain( + Dense(1, 1), + Parallel( + Dense(1, 1), + Dense(1, 3), + Chain( + Dense(1, 5), + Dense(5, 2), + ) + ), + Dense(6, 1) +) +julia> model(rand(1)) +Float32[0.27] +``` +""" +struct Parallel{F, T} + connection::F + layers::T +end + +Parallel(connection, layers...) = Parallel(connection, layers) + +Flux.@functor Parallel + +(m::Parallel)(x::AbstractArray) = mapreduce(f -> f(x), m.connection, m.layers) + +Base.getindex(m::Parallel, i::Integer) = m.layers[i] +Base.getindex(m::Parallel, i::AbstractArray) = Parallel(m.connection, m.layers[i]...) 
+ +function Base.show(io::IO, m::Parallel) + print(io, "Parallel(", m.connection, ", ") + join(io, m.layers, ", ") + print(io, ")") +end \ No newline at end of file diff --git a/test/layers/basic.jl b/test/layers/basic.jl index e1660812f0..c36ba68be6 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -106,4 +106,16 @@ import Flux: activations @test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4) end end + + @testset "Parallel" begin + @testset "zero sum" begin + input = randn(10, 10, 10, 10) + @test Parallel(+, x -> zeros(size(x)), identity)(input) == input + end + + @testset "concat size" begin + input = randn(10, 2) + @test size(Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), identity)(input)) == (10, 4) + end + end end \ No newline at end of file diff --git a/test/outputsize.jl b/test/outputsize.jl index dc8ad3023b..ba183a985c 100644 --- a/test/outputsize.jl +++ b/test/outputsize.jl @@ -28,6 +28,9 @@ m = SkipConnection(Conv((3, 3), 3 => 16; pad = 1), (mx, x) -> cat(mx, x; dims = 3)) @test outputsize(m, (10, 10, 3, 1)) == (10, 10, 19, 1) + + m = Parallel((mx, x) -> cat(mx, x; dims = 3), Conv((3, 3), 3 => 16; pad = 1), identity) + @test outputsize(m, (10, 10, 3, 1)) == (10, 10, 19, 1) end @testset "activations" begin From 95ce8e725cdccfd84ed427f6f0ea9aa6f05c9fcf Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 12 Jan 2021 11:08:49 -0600 Subject: [PATCH 02/15] Update docstring --- src/layers/basic.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 1d7b1e26ce..56402966c6 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -266,6 +266,7 @@ Create a new 'Parallel' layer that passes a single input array each path in julia> model = Chain( Dense(1, 1), Parallel( + vcat, Dense(1, 1), Dense(1, 3), Chain( From 9ba1885e859058a28166a0636dd925bbf991bf1c Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 12 Jan 2021 11:12:25 -0600 Subject: [PATCH 03/15] Update docstring 2 --- src/layers/basic.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 56402966c6..6ac4f61e94 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -261,7 +261,8 @@ Create a new 'Parallel' layer that passes a single input array each path in `layers`, combining the output of each path with `connection`. `connection` should be a reducible operator (i.e. it can be passed to `Base.reduce`). -# Example +# Examples + ```jldoctest julia> model = Chain( Dense(1, 1), @@ -275,7 +276,7 @@ julia> model = Chain( ) ), Dense(6, 1) -) +); julia> model(rand(1)) Float32[0.27] ``` From 327aebe8882a3afe44ef600f682783ced45935fb Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 12 Jan 2021 11:21:27 -0600 Subject: [PATCH 04/15] Make Parallel source consistent with rest of code Co-authored-by: Atiyo Ghosh --- src/layers/basic.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 6ac4f61e94..5722756771 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -288,7 +288,7 @@ end Parallel(connection, layers...) 
= Parallel(connection, layers) -Flux.@functor Parallel +@functor Parallel (m::Parallel)(x::AbstractArray) = mapreduce(f -> f(x), m.connection, m.layers) @@ -299,4 +299,4 @@ function Base.show(io::IO, m::Parallel) print(io, "Parallel(", m.connection, ", ") join(io, m.layers, ", ") print(io, ")") -end \ No newline at end of file +end From e3b95f56464c70e82f0243146244b8f59ab40851 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 12 Jan 2021 11:33:14 -0600 Subject: [PATCH 05/15] Add PR feedback on Parallel --- src/layers/basic.jl | 20 ++++++-------------- test/layers/basic.jl | 2 +- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 5722756771..f02ec996d1 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -257,26 +257,18 @@ end """ Parallel(connection, layers...) -Create a new 'Parallel' layer that passes a single input array each path in -`layers`, combining the output of each path with `connection`. -`connection` should be a reducible operator (i.e. it can be passed to `Base.reduce`). +Create a 'Parallel' layer that passes an input array to each path in +`layers`, reducing the output with `connection`. + +Equivalent to calling `reduce(connection, [l(x) for l in layers]...)`. # Examples ```jldoctest julia> model = Chain( Dense(1, 1), - Parallel( - vcat, - Dense(1, 1), - Dense(1, 3), - Chain( - Dense(1, 5), - Dense(5, 2), - ) - ), - Dense(6, 1) -); + Parallel(vcat, Dense(1, 1), Dense(1, 3), Chain(Dense(1, 5), Dense(5, 2))), + Dense(6, 1)); julia> model(rand(1)) Float32[0.27] ``` diff --git a/test/layers/basic.jl b/test/layers/basic.jl index c36ba68be6..f05b5b5100 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -118,4 +118,4 @@ import Flux: activations @test size(Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), identity)(input)) == (10, 4) end end -end \ No newline at end of file +end From f139398a9d1584b65fa69ad7a879e9213c4296ee Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 12 Jan 2021 11:40:44 -0600 Subject: [PATCH 06/15] Added multiple input arg version of Parallel --- src/layers/basic.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index f02ec996d1..08426655dc 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -269,8 +269,11 @@ julia> model = Chain( Dense(1, 1), Parallel(vcat, Dense(1, 1), Dense(1, 3), Chain(Dense(1, 5), Dense(5, 2))), Dense(6, 1)); -julia> model(rand(1)) -Float32[0.27] +julia> size(model(rand(1))) +(1,) +julia> model = Parallel(+, Dense(10, 2), Dense(5, 2)); +julia> size(model(rand(10), rand(5))) +(2,) ``` """ struct Parallel{F, T} @@ -283,6 +286,7 @@ Parallel(connection, layers...) = Parallel(connection, layers) @functor Parallel (m::Parallel)(x::AbstractArray) = mapreduce(f -> f(x), m.connection, m.layers) +(m::Parallel)(xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> f(x), m.connection, m.layers, xs) Base.getindex(m::Parallel, i::Integer) = m.layers[i] Base.getindex(m::Parallel, i::AbstractArray) = Parallel(m.connection, m.layers[i]...) 
From 3af5755bbef7c7ca71aaf5562547749167f8131d Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 12 Jan 2021 11:51:24 -0600 Subject: [PATCH 07/15] Fix docstring for Parallel --- src/layers/basic.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 08426655dc..f1e642de47 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -265,13 +265,14 @@ Equivalent to calling `reduce(connection, [l(x) for l in layers]...)`. # Examples ```jldoctest -julia> model = Chain( - Dense(1, 1), - Parallel(vcat, Dense(1, 1), Dense(1, 3), Chain(Dense(1, 5), Dense(5, 2))), - Dense(6, 1)); +julia> model = Chain(Dense(1, 1), Parallel(vcat, Dense(1, 1), Dense(1, 3), Chain(Dense(1, 5), Dense(5, 2))), Dense(6, 1)); + julia> size(model(rand(1))) (1,) -julia> model = Parallel(+, Dense(10, 2), Dense(5, 2)); + +julia> model = Parallel(+, Dense(10, 2), Dense(5, 2)) +Parallel(+, Dense(10, 2), Dense(5, 2)) + julia> size(model(rand(10), rand(5))) (2,) ``` From 97c0682941ae3da001f23486bbacdc53e57676f0 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 12 Jan 2021 17:31:55 -0600 Subject: [PATCH 08/15] Make complex docstring example numbers distinct Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/layers/basic.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index f1e642de47..5e13ffbc62 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -265,10 +265,10 @@ Equivalent to calling `reduce(connection, [l(x) for l in layers]...)`. # Examples ```jldoctest -julia> model = Chain(Dense(1, 1), Parallel(vcat, Dense(1, 1), Dense(1, 3), Chain(Dense(1, 5), Dense(5, 2))), Dense(6, 1)); +julia> model = Chain(Dense(3, 13, tanh), Parallel(vcat, Dense(13, 4), Chain(Dense(13, 7, tanh), Dense(7, 4))), Dense(8, 17)); -julia> size(model(rand(1))) -(1,) +julia> size(model(rand(3))) +(17,) julia> model = Parallel(+, Dense(10, 2), Dense(5, 2)) Parallel(+, Dense(10, 2), Dense(5, 2)) From 052e36d5ec501aa44d07118a8e18cc1052bb00a4 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 12 Jan 2021 17:36:02 -0600 Subject: [PATCH 09/15] Update docstring for multiple input case Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/layers/basic.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 5e13ffbc62..fdadf9ba9d 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -260,7 +260,8 @@ end Create a 'Parallel' layer that passes an input array to each path in `layers`, reducing the output with `connection`. -Equivalent to calling `reduce(connection, [l(x) for l in layers]...)`. +Called with one input `x`, this is equivalent to `reduce(connection, [l(x) for l in layers])`. +If called with multiple inputs, they are `zip`ped with the layers, thus `Parallel(+, f, g)(x, y) = f(x) + g(y)`. # Examples From 1c660e60837c7ec7a8d90c322fafeb75a59133c0 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Wed, 13 Jan 2021 08:55:12 -0600 Subject: [PATCH 10/15] Update src/layers/basic.jl Co-authored-by: Dhairya Gandhi --- src/layers/basic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index fdadf9ba9d..8974fef8e1 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -291,7 +291,7 @@ Parallel(connection, layers...) 
= Parallel(connection, layers) (m::Parallel)(xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> f(x), m.connection, m.layers, xs) Base.getindex(m::Parallel, i::Integer) = m.layers[i] -Base.getindex(m::Parallel, i::AbstractArray) = Parallel(m.connection, m.layers[i]...) +Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...) function Base.show(io::IO, m::Parallel) print(io, "Parallel(", m.connection, ", ") From b2c46d54dcb7a4246554ac1335b3373411de0f0d Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Wed, 13 Jan 2021 09:26:41 -0600 Subject: [PATCH 11/15] Update src/layers/basic.jl Co-authored-by: Dhairya Gandhi --- src/layers/basic.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 8974fef8e1..cac2387b7f 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -266,7 +266,9 @@ If called with multiple inputs, they are `zip`ped with the layers, thus `Paralle # Examples ```jldoctest -julia> model = Chain(Dense(3, 13, tanh), Parallel(vcat, Dense(13, 4), Chain(Dense(13, 7, tanh), Dense(7, 4))), Dense(8, 17)); +julia> model = Chain(Dense(3, 5), + Parallel(vcat, Dense(5, 4), Chain(Dense(5, 7), Dense(7, 4))), + Dense(8, 17)); julia> size(model(rand(3))) (17,) From 3409ffe24c29e811ca2aed82289e56b2add3cefc Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Wed, 13 Jan 2021 11:11:18 -0600 Subject: [PATCH 12/15] Added docs for Join and Split --- docs/src/models/advanced.md | 107 ++++++++++++++++++++++++++++++++++++ 1 file changed, 107 insertions(+) diff --git a/docs/src/models/advanced.md b/docs/src/models/advanced.md index 557128aea4..9bfc985ddd 100644 --- a/docs/src/models/advanced.md +++ b/docs/src/models/advanced.md @@ -70,3 +70,110 @@ by simply deleting it from `ps`: ps = params(m) delete!(ps, m[2].b) ``` + +## Custom multiple input or output layer + +Sometimes a model needs to receive several separate inputs at once or produce several separate outputs at once. In other words, there multiple paths within this high-level layer, each processing a different input or producing a different output. A simple example of this in machine learning literature is the [inception module](https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Szegedy_Rethinking_the_Inception_CVPR_2016_paper.pdf). + +Naively, we could have a struct that stores the weights of along each path and implement the joining/splitting in the forward pass function. But that would mean a new struct any time the operations along each path changes. Instead, this guide will show you how to construct a high-level layer (like [`Chain`](@ref)) that is made of multiple sub-layers for each path. + +### Multiple inputs: a custom `Join` layer + +Our custom `Join` layer will accept multiple inputs at once, pass each input through a separate path, then combine the results together. + +We start by defining a new struct, `Join`, that stores the different paths and a combine operation as its fields. +```julia +using Flux +using CUDA + +# custom join layer +struct Join{T, F} + combine::F + paths::T +end + +# allow Join(op, m1, m2, ...) as a constructor +Join(combine, paths...) = Join(combine, paths) +``` +Notice that we parameterized the type of the `paths` field. This is necessary for fast Julia code; in general, `T` might be a `Tuple` or `Vector`, but we don't need to pay attention to what it specifically is. The same goes for the `combine` field. 
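+
+As a concrete (and purely hypothetical) counter-example, a version of this struct with untyped fields would defeat type inference; the name `SlowJoin` below is made up for illustration and is not used anywhere else:
+```julia
+# Untyped fields are implicitly `Any`, so Julia cannot infer the calls through
+# `combine` or `paths`, and every forward pass falls back to dynamic dispatch.
+struct SlowJoin
+    combine
+    paths
+end
+```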
+ +The next step is to use [`Flux.@functor`](@ref) to make our struct behave like a Flux layer. This is important so that calling `params` on a `Join` returns the underlying weight arrays on each path. +```julia +Flux.@functor Join +``` + +Finally, we define the forward pass. For `Join`, this means applying each `path` in `paths` to each input array, then using `combine` to merge the results. +```julia +(m::Join)(xs::Tuple) = m.combine(map((f, x) -> f(x), m.paths, xs)) +(m::Join)(xs...) = m(xs) +``` + +Lastly, we can test our new layer. Thanks to the proper abstractions in Julia, our layer works on GPU arrays out of the box! +```julia +model = Chain( + Join(vcat, + Chain( + Dense(1, 5), + Dense(5, 1) + ), + Dense(1, 2), + Dense(1, 1), + ), + Dense(4, 1) + ) |> gpu + +xs = map(gpu, (rand(1), rand(1), rand(1))) + +model(xs) +# returns a single float vector with one value +``` + +### Multiple outputs: a custom `Split` layer + +Our custom `Split` layer will accept a single input, then pass the input through a separate path to produce multiple outputs. + +We start by following the same steps as the `Join` layer: define a struct, use [`Flux.@functor`](@ref), and define the forward pass. +```julia +using Flux +using CUDA + +# custom split layer +struct Split{T} + paths::T +end + +Split(paths...) = Split(paths) + +Flux.@functor Split + +(m::Split)(x::AbstractArray) = tuple(map(f -> f(x), m.paths)) +``` + +Now we can test to see that our `Split` does indeed produce multiple outputs. +```julia +model = Chain( + Dense(10, 5), + CustomSplit( + Dense(5, 1), + Dense(5, 3), + Dense(5, 2) + ) + ) |> gpu + +model(gpu(rand(10))) +# returns a tuple with three float vectors +``` + +A custom loss function for the multiple outputs may look like this: +```julia +using Statistics + +# assuming model returns the output of a Split +# x is a single input +# ys is a tuple of outputs +function loss(x, ys, model) + # rms over all the mse + ŷs = model(x) + return sqrt(mean(Flux.mse(y, ŷ) for (y, ŷ) in zip(ys, ŷs))) +end +``` From 4143dd8aba7e59c69feade5b85a012106463b407 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Wed, 13 Jan 2021 11:14:52 -0600 Subject: [PATCH 13/15] Add Parallel docstring to docs --- docs/src/models/layers.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index 0572748566..34dc1b57f5 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -49,6 +49,7 @@ But in contrast to the layers described in the other sections are not readily gr ```@docs Maxout SkipConnection +Parallel ``` ## Normalisation & Regularisation From 49a05b4f8585d41e9439f5e99729b1d6be18534c Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Wed, 13 Jan 2021 11:17:28 -0600 Subject: [PATCH 14/15] Add NEWS.md entry --- NEWS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/NEWS.md b/NEWS.md index 602ef43026..f8efaa5f1f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -8,6 +8,7 @@ * Removed kwarg only constructors for [`convolutional layers`](https://github.com/FluxML/Flux.jl/pull/1379). * Add [sparse initialization](https://github.com/FluxML/Flux.jl/pull/1454) as described in [Deep learning via Hessian-free optimization](https://dl.acm.org/doi/abs/10.5555/3104322.3104416). * Moved GPU CI to use buildkite instead of GitLab +* New [`Parallel` layer](https://github.com/FluxML/Flux.jl/pull/1462) adds inception module-like building blocks. 
* Other new features and bug fixes (see GitHub releases page) ## v0.11.2 From 377c977c60e507af622c42b0af09b037e0413a6b Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Wed, 13 Jan 2021 12:07:52 -0600 Subject: [PATCH 15/15] Add tests for vararg Parallel and updated docs --- docs/src/models/advanced.md | 28 +++++++++++++++++++++++++++- src/layers/basic.jl | 1 + test/layers/basic.jl | 5 +++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/docs/src/models/advanced.md b/docs/src/models/advanced.md index 9bfc985ddd..b1f02c66e9 100644 --- a/docs/src/models/advanced.md +++ b/docs/src/models/advanced.md @@ -79,7 +79,7 @@ Naively, we could have a struct that stores the weights of along each path and i ### Multiple inputs: a custom `Join` layer -Our custom `Join` layer will accept multiple inputs at once, pass each input through a separate path, then combine the results together. +Our custom `Join` layer will accept multiple inputs at once, pass each input through a separate path, then combine the results together. Note that this layer can already be constructed using [`Parallel`](@ref), but we will first walk through how do this manually. We start by defining a new struct, `Join`, that stores the different paths and a combine operation as its fields. ```julia @@ -128,6 +128,32 @@ model(xs) # returns a single float vector with one value ``` +#### Using `Parallel` + +Flux already provides [`Parallel`](@ref) that can offer the same functionality. In this case, `Join` is going to just be syntactic sugar for `Parallel`. +```julia +Join(combine, paths) = Parallel(combine, paths) +Join(combine, paths...) = Join(combine, paths) + +# use vararg/tuple version of Parallel forward pass +model = Chain( + Join(vcat, + Chain( + Dense(1, 5), + Dense(5, 1) + ), + Dense(1, 2), + Dense(1, 1), + ), + Dense(4, 1) + ) |> gpu + +xs = map(gpu, (rand(1), rand(1), rand(1))) + +model(xs) +# returns a single float vector with one value +``` + ### Multiple outputs: a custom `Split` layer Our custom `Split` layer will accept a single input, then pass the input through a separate path to produce multiple outputs. diff --git a/src/layers/basic.jl b/src/layers/basic.jl index cac2387b7f..6a4be2dcab 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -291,6 +291,7 @@ Parallel(connection, layers...) = Parallel(connection, layers) (m::Parallel)(x::AbstractArray) = mapreduce(f -> f(x), m.connection, m.layers) (m::Parallel)(xs::Vararg{<:AbstractArray}) = mapreduce((f, x) -> f(x), m.connection, m.layers, xs) +(m::Parallel)(xs::Tuple) = m(xs...) Base.getindex(m::Parallel, i::Integer) = m.layers[i] Base.getindex(m::Parallel, i::AbstractVector) = Parallel(m.connection, m.layers[i]...) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index f05b5b5100..d8771fb76c 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -117,5 +117,10 @@ import Flux: activations input = randn(10, 2) @test size(Parallel((a, b) -> cat(a, b; dims=2), Dense(10, 10), identity)(input)) == (10, 4) end + + @testset "vararg input" begin + inputs = randn(10), randn(5), randn(4) + @test size(Parallel(+, Dense(10, 2), Dense(5, 2), Dense(4, 2))(inputs)) == (2,) + end end end
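As a closing illustration (not part of the patches above): the NEWS entry and the new guide section both point to inception-style modules as the motivating use case for `Parallel`. The sketch below assembles such a block from pieces that already appear in this PR (`Parallel`, `Chain`, `Conv`, `MaxPool`, and a `cat`-based connection); the branch layout and channel counts are invented purely for illustration.

```julia
using Flux

# An inception-style block: four branches over the same input, concatenated
# along the channel dimension by the binary `cat` connection.
inception_block(in_ch) = Parallel((a, b) -> cat(a, b; dims = 3),
  Conv((1, 1), in_ch => 16),                                             # 1x1 branch
  Chain(Conv((1, 1), in_ch => 16), Conv((3, 3), 16 => 24; pad = 1)),     # 3x3 branch
  Chain(Conv((1, 1), in_ch => 8), Conv((5, 5), 8 => 8; pad = 2)),        # 5x5 branch
  Chain(MaxPool((3, 3); pad = 1, stride = 1), Conv((1, 1), in_ch => 8))) # pooling branch

m = inception_block(3)
size(m(rand(Float32, 32, 32, 3, 1)))  # (32, 32, 56, 1): 16 + 24 + 8 + 8 output channels
```

Because the connection is a binary function, the `mapreduce` in the forward pass concatenates the branch outputs pairwise, which gives the same result as concatenating them all at once along the channel dimension.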