diff --git a/src/outputsize.jl b/src/outputsize.jl
index 016dbcfe89..eaee34433f 100644
--- a/src/outputsize.jl
+++ b/src/outputsize.jl
@@ -50,18 +50,19 @@ using .NilNumber: Nil, nil
 """
     outputsize(m, inputsize::Tuple; padbatch=false)
 
-Calculate the output size of model `m` given the input size.
+Calculate the size of the output from model `m`, given the size of the input.
 Obeys `outputsize(m, size(x)) == size(m(x))` for valid input `x`.
-Keyword `padbatch=true` is equivalent to using `(inputsize..., 1)`, and
+
+Keyword `padbatch=true` is equivalent to using `(inputsize..., 1)`, and
 returns the final size including this extra batch dimension.
-This should be faster than calling `size(m(x))`. It uses a trivial number type,
-and thus should work out of the box for custom layers.
+This should be faster than calling `size(m(x))`. It uses a trivial number type,
+which should work out of the box for custom layers.
 If `m` is a `Tuple` or `Vector`, its elements are applied in sequence, like `Chain(m...)`.
 
 # Examples
-```jldoctest
+```julia-repl
 julia> using Flux: outputsize
 
 julia> outputsize(Dense(10, 4), (10,); padbatch=true)
 (4, 1)
 
@@ -79,32 +80,70 @@ julia> outputsize(m, (10, 10, 3, 64))
 (6, 6, 32, 64)
 
 julia> try outputsize(m, (10, 10, 7, 64)) catch e println(e) end
+┌ Error: layer Conv((3, 3), 3=>16), index 1 in Chain, gave an error with input of size (10, 10, 7, 64)
+└ @ Flux ~/.julia/dev/Flux/src/outputsize.jl:114
 DimensionMismatch("Input channels must match! (7 vs. 3)")
 
-julia> outputsize([Dense(10, 4), Dense(4, 2)], (10, 1))
+julia> outputsize([Dense(10, 4), Dense(4, 2)], (10, 1))  # Vector of layers becomes a Chain
 (2, 1)
+```
+"""
+function outputsize(m, inputsizes::Tuple...; padbatch=false)
+  x = nil_input(padbatch, inputsizes...)
+  return size(m(x))
+end
 
-julia> using LinearAlgebra: norm
+nil_input(pad::Bool, s::Tuple{Vararg{Integer}}) = pad ? fill(nil, (s..., 1)) : fill(nil, s)
+nil_input(pad::Bool, multi::Tuple{Vararg{Integer}}...) = nil_input.(pad, multi)
+nil_input(pad::Bool, tup::Tuple{Vararg{Tuple}}) = nil_input(pad, tup...)
+
+function outputsize(m::Chain, inputsizes::Tuple{Vararg{Integer}}...; padbatch=false)
+  x = nil_input(padbatch, inputsizes...)
+  for (i, lay) in enumerate(m.layers)
+    try
+      x = lay(x)
+    catch err
+      str = x isa AbstractArray ? "with input of size $(size(x))" : ""
+      @error "layer $lay, index $i in Chain, gave an error $str"
+      rethrow(err)
+    end
+  end
+  return size(x)
+end
 
-julia> f(x) = x ./ norm.(eachcol(x));
+"""
+    outputsize(m, x_size, y_size, ...; padbatch=false)
 
-julia> outputsize(f, (10, 1)) # manually specify batch size as 1
-(10, 1)
+For model or layer `m` accepting multiple arrays as input,
+this returns `size(m((x, y, ...)))` given `x_size = size(x)`, etc.
 
-julia> outputsize(f, (10,); padbatch=true) # no need to mention batch size
-(10, 1)
+# Examples
+```jldoctest
+julia> x, y = rand(Float32, 5, 64), rand(Float32, 7, 64);
+
+julia> par = Parallel(vcat, Dense(5, 9), Dense(7, 11));
+
+julia> Flux.outputsize(par, (5, 64), (7, 64))
+(20, 64)
+
+julia> m = Chain(par, Dense(20, 13), softmax);
+
+julia> Flux.outputsize(m, (5,), (7,); padbatch=true)
+(13, 1)
+
+julia> par(x, y) == par((x, y)) == Chain(par, identity)((x, y))
+true
 ```
+
+Notice that `Chain` only accepts multiple arrays as a tuple,
+while `Parallel` also accepts them as multiple arguments;
+`outputsize` always supplies the tuple.
 """
-function outputsize(m, inputsize::Tuple; padbatch=false)
-  inputsize = padbatch ? (inputsize..., 1) : inputsize
-
-  return size(m(fill(nil, inputsize)))
-end
+outputsize
 
 ## make tuples and vectors be like Chains
 
-outputsize(m::Tuple, inputsize::Tuple; padbatch=false) = outputsize(Chain(m...), inputsize; padbatch=padbatch)
-outputsize(m::AbstractVector, inputsize::Tuple; padbatch=false) = outputsize(Chain(m...), inputsize; padbatch=padbatch)
+outputsize(m::Tuple, input::Tuple...; padbatch=false) = outputsize(Chain(m...), input...; padbatch=padbatch)
+outputsize(m::AbstractVector, input::Tuple...; padbatch=false) = outputsize(Chain(m...), input...; padbatch=padbatch)
 
 ## bypass statistics in normalization layers
diff --git a/test/outputsize.jl b/test/outputsize.jl
index b3e531f0f4..5b25c8bbd8 100644
--- a/test/outputsize.jl
+++ b/test/outputsize.jl
@@ -36,6 +36,20 @@
   @test outputsize(m, (10, 10, 3, 1)) == (10, 10, 19, 1)
 end
 
+@testset "multiple inputs" begin
+  m = Parallel(vcat, Dense(2, 4, relu), Dense(3, 6, relu))
+  @test outputsize(m, (2,), (3,)) == (10,)
+  @test outputsize(m, ((2,), (3,))) == (10,)
+  @test outputsize(m, (2,), (3,); padbatch=true) == (10, 1)
+  @test outputsize(m, (2, 7), (3, 7)) == (10, 7)
+
+  m = Chain(m, Dense(10, 13, tanh), softmax)
+  @test outputsize(m, (2,), (3,)) == (13,)
+  @test outputsize(m, ((2,), (3,))) == (13,)
+  @test outputsize(m, (2,), (3,); padbatch=true) == (13, 1)
+  @test outputsize(m, (2, 7), (3, 7)) == (13, 7)
+end
+
 @testset "activations" begin
   @testset for f in [celu, elu, gelu, hardsigmoid, hardtanh,
                      leakyrelu, lisht, logcosh, logσ, mish,
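
For reviewers trying this out, a minimal usage sketch of the calling conventions the new methods accept, mirroring the testset above (the layer sizes are just the ones used in the tests):

```julia
using Flux
using Flux: outputsize

# A two-input model, as in the new "multiple inputs" testset.
par = Parallel(vcat, Dense(2, 4, relu), Dense(3, 6, relu))

outputsize(par, (2, 7), (3, 7))             # varargs of size tuples -> (10, 7)
outputsize(par, ((2, 7), (3, 7)))           # one tuple of tuples, same result
outputsize(par, (2,), (3,); padbatch=true)  # batch dim appended -> (10, 1)
```

In each case `nil_input` turns the size tuples into `nil`-filled arrays, so the model propagates shapes with the trivial `Nil` number type rather than real data.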