From c874376913cab0690d1684afd2dab18e6c391bac Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sat, 15 Oct 2022 13:55:42 +0200 Subject: [PATCH 1/6] improvements to stack --- src/utils.jl | 35 ++++++++++++++++++++++++++++++----- test/test_utils.jl | 13 +++++++++++++ test/utils.jl | 12 ++++++++++++ 3 files changed, 55 insertions(+), 5 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index fd254b7..aaad225 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,6 +5,7 @@ Return `x` reshaped into an array one dimensionality higher than `x`, where `dims` indicates in which dimension `x` is extended. +`dims` can be an integer between 1 and `ndims(x)+1`. See also [`flatten`](@ref), [`stack`](@ref). @@ -33,8 +34,9 @@ julia> unsqueeze(xs, dims=1) [1, 2] [3, 4] [5, 6] ``` """ -function unsqueeze(x::AbstractArray; dims::Int) - sz = ntuple(i -> i < dims ? size(x, i) : i == dims ? 1 : size(x, i - 1), ndims(x) + 1) +function unsqueeze(x::AbstractArray{T,N}; dims::Int) where {T, N} + # @assert 1 <= dims <= N + 1 + sz = ntuple(i -> i < dims ? size(x, i) : i == dims ? 1 : size(x, i - 1), N + 1) return reshape(x, sz) end @@ -59,9 +61,11 @@ Base.show_function(io::IO, u::Base.Fix2{typeof(_unsqueeze)}, ::Bool) = print(io, stack(xs; dims) Concatenate the given array of arrays `xs` into a single array along the -given dimension `dims`. +given dimension `dims`. All arrays need to be of the same size. +The number of dimension in the final arrays is one more than the number +of dimensions in the input arrays. -See also [`stack`](@ref) and [`batch`](@ref). +See also [`unsqueeze`](@ref), [`unstack`](@ref) and [`batch`](@ref). # Examples @@ -98,7 +102,28 @@ julia> stack(xs, dims=3) 6 ``` """ -stack(xs; dims::Int) = cat(unsqueeze.(xs; dims)...; dims) +function stack(xs; dims::Int) + N = ndims(xs[1]) + if dims <= N + vs = unsqueeze.(xs; dims) + else + vs = xs + end + if dims == 1 + return reduce(vcat, vs) + elseif dims === 2 + return reduce(hcat, vs) + else + return reduce((x, y) -> cat(x, y; dims=dims), vs) + end +end + +function rrule(::typeof(stack), xs; dims::Int) + function stack_pullback(Δ) + return (NoTangent(), unstack(unthunk(Δ); dims=dims)) + end + return stack(xs; dims=dims), stack_pullback +end """ unstack(xs; dims) diff --git a/test/test_utils.jl b/test/test_utils.jl index 669ad96..7f334f3 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -1,4 +1,17 @@ +""" +Test gradients through zygote. + +# Arguments + +- `f`: function to test +- `xs`: inputs to `f` + +# Keyword Arguments +Keyword arguments are passed to `rrule`. + +- `fkwargs`: keyword arguments to `f` +""" function test_zygote(f, xs...; kws...) config = ZygoteRuleConfig() test_rrule(config, f, xs...; kws..., rrule_f = rrule_via_ad) diff --git a/test/utils.jl b/test/utils.jl index d5b80bb..c2343e9 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -6,6 +6,8 @@ @test @inferred(unsqueeze(x; dims=4)) == reshape(x, 2, 3, 2, 1) @test unsqueeze(dims=2)(x) == unsqueeze(x, dims=2) + + @test_throws AssertionError unsqueeze(rand(2,2), dims=4) end @testset "stack and unstack" begin @@ -19,6 +21,16 @@ end @test unstack(stacked_array, dims=2) == unstacked_array @test stack(unstacked_array, dims=2) == stacked_array @test stack(unstack(stacked_array, dims=1), dims=1) == stacked_array + + for d in (1,2,3) + test_zygote(stack, [x,2x], fkwargs=(; dims=d), check_inferred=false) + end + + # Issue #121 + a = [[1] for i in 1:10000] + @test size(stack(a, dims=1)) == (10000, 1) + @test size(stack(a, dims=2)) == (1, 10000) + @test size(stack(a, dims=3)) == (1, 1, 10000) end @testset "batch and unbatch" begin From 639942225b0c5c5a1edcc02b5776cd728bfd48ab Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sat, 15 Oct 2022 13:57:57 +0200 Subject: [PATCH 2/6] cleanup --- src/utils.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index aaad225..5909405 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -35,7 +35,7 @@ julia> unsqueeze(xs, dims=1) ``` """ function unsqueeze(x::AbstractArray{T,N}; dims::Int) where {T, N} - # @assert 1 <= dims <= N + 1 + @assert 1 <= dims <= N + 1 sz = ntuple(i -> i < dims ? size(x, i) : i == dims ? 1 : size(x, i - 1), N + 1) return reshape(x, sz) end @@ -61,9 +61,7 @@ Base.show_function(io::IO, u::Base.Fix2{typeof(_unsqueeze)}, ::Bool) = print(io, stack(xs; dims) Concatenate the given array of arrays `xs` into a single array along the -given dimension `dims`. All arrays need to be of the same size. -The number of dimension in the final arrays is one more than the number -of dimensions in the input arrays. +new dimension `dims`. All arrays need to be of the same size. See also [`unsqueeze`](@ref), [`unstack`](@ref) and [`batch`](@ref). From 14da8cb0684186bee37dffa1ff5c39f3ea9f0e6a Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sat, 15 Oct 2022 17:41:46 +0200 Subject: [PATCH 3/6] use Base definition of stack --- Project.toml | 1 + src/MLUtils.jl | 8 ++++-- src/deprecations.jl | 1 - src/utils.jl | 66 --------------------------------------------- test/utils.jl | 3 +-- 5 files changed, 8 insertions(+), 71 deletions(-) diff --git a/Project.toml b/Project.toml index ca42c9b..d63852d 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.3.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" FLoops = "cc61a311-1640-44b5-9fba-1b764f453329" diff --git a/src/MLUtils.jl b/src/MLUtils.jl index c4d0dec..df65156 100644 --- a/src/MLUtils.jl +++ b/src/MLUtils.jl @@ -22,7 +22,11 @@ import NNlib @traitdef IsTable{X} @traitimpl IsTable{X} <- Tables.istable(X) - +if VERSION < v"1.9.0-DEV.1163" + import Compat: stack +else + import Base: stack +end include("observation.jl") export numobs, getobs, @@ -75,7 +79,7 @@ export batch, rand_like, randn_like, rpad_constant, - stack, + stack, # in Base since julia v1.9 unbatch, unsqueeze, unstack, diff --git a/src/deprecations.jl b/src/deprecations.jl index 797a14f..224602f 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -1,5 +1,4 @@ # Deprecated in v0.2 -@deprecate stack(x, dims) stack(x; dims=dims) @deprecate unstack(x, dims) unstack(x; dims=dims) @deprecate unsqueeze(x::AbstractArray, dims::Int) unsqueeze(x; dims=dims) @deprecate unsqueeze(dims::Int) unsqueeze(dims=dims) diff --git a/src/utils.jl b/src/utils.jl index 5909405..e4348c1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -57,72 +57,6 @@ _unsqueeze(x, dims) = unsqueeze(x; dims) Base.show_function(io::IO, u::Base.Fix2{typeof(_unsqueeze)}, ::Bool) = print(io, "unsqueeze(dims=", u.x, ")") -""" - stack(xs; dims) - -Concatenate the given array of arrays `xs` into a single array along the -new dimension `dims`. All arrays need to be of the same size. - -See also [`unsqueeze`](@ref), [`unstack`](@ref) and [`batch`](@ref). - -# Examples - -```jldoctest -julia> xs = [[1, 2], [3, 4], [5, 6]] -3-element Vector{Vector{Int64}}: - [1, 2] - [3, 4] - [5, 6] - -julia> stack(xs, dims=1) -3×2 Matrix{Int64}: - 1 2 - 3 4 - 5 6 - -julia> stack(xs, dims=2) -2×3 Matrix{Int64}: - 1 3 5 - 2 4 6 - -julia> stack(xs, dims=3) -2×1×3 Array{Int64, 3}: -[:, :, 1] = - 1 - 2 - -[:, :, 2] = - 3 - 4 - -[:, :, 3] = - 5 - 6 -``` -""" -function stack(xs; dims::Int) - N = ndims(xs[1]) - if dims <= N - vs = unsqueeze.(xs; dims) - else - vs = xs - end - if dims == 1 - return reduce(vcat, vs) - elseif dims === 2 - return reduce(hcat, vs) - else - return reduce((x, y) -> cat(x, y; dims=dims), vs) - end -end - -function rrule(::typeof(stack), xs; dims::Int) - function stack_pullback(Δ) - return (NoTangent(), unstack(unthunk(Δ); dims=dims)) - end - return stack(xs; dims=dims), stack_pullback -end - """ unstack(xs; dims) diff --git a/test/utils.jl b/test/utils.jl index c2343e9..a8418c6 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -14,7 +14,7 @@ end x = randn(3,3) stacked = stack([x, x], dims=2) @test size(stacked) == (3,2,3) - @test_broken @inferred(stack([x, x], dims=2)) == stacked + @test @inferred(stack([x, x], dims=2)) == stacked stacked_array=[ 8 9 3 5; 9 6 6 9; 9 1 7 2; 7 4 10 6 ] unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]] @@ -30,7 +30,6 @@ end a = [[1] for i in 1:10000] @test size(stack(a, dims=1)) == (10000, 1) @test size(stack(a, dims=2)) == (1, 10000) - @test size(stack(a, dims=3)) == (1, 1, 10000) end @testset "batch and unbatch" begin From 9c023610f77c78e1719e2f175ad3534f597abc44 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 11 Nov 2022 06:52:19 +0100 Subject: [PATCH 4/6] v0.4 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d63852d..dca36b7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLUtils" uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" authors = ["Carlo Lucibello and contributors"] -version = "0.3.1" +version = "0.4.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From cb85fb7cff05041fcad0e1514b872df35b198f39 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 11 Nov 2022 07:07:10 +0100 Subject: [PATCH 5/6] use stack in batch --- src/utils.jl | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index e4348c1..6df0b6d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -286,17 +286,7 @@ end batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i) -function batch(xs::AbstractArray{<:AbstractArray}) - # Don't use stack(xs, dims=N+1), it is much slower. - # Here we do reduce(vcat, xs) along with some reshapes. - szxs = size(xs) - @assert length(xs) > 0 "Minimum batch size is 1." - szx = size(xs[1]) - @assert all(x -> size(x) == szx, xs) "All arrays must be of the same size." - vxs = vec(vec.(xs)) - y = reduce(vcat, vxs) - return reshape(y, szx..., szxs...) -end +batch(xs::AbstractArray{<:AbstractArray}) = stack(xs) function batch(xs::Vector{<:Tuple}) @assert length(xs) > 0 "Input should be non-empty" From 3fa64eebe93039ad343a314c8031fb411ba31f75 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 11 Nov 2022 19:42:10 +0100 Subject: [PATCH 6/6] Compat bound --- Project.toml | 1 + src/MLUtils.jl | 9 +++------ 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index dca36b7..1b61a18 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" [compat] ChainRulesCore = "1.0" +Compat = "4.2" DataAPI = "1.0" DelimitedFiles = "1.0" FLoops = "0.2" diff --git a/src/MLUtils.jl b/src/MLUtils.jl index df65156..8261ab9 100644 --- a/src/MLUtils.jl +++ b/src/MLUtils.jl @@ -21,12 +21,9 @@ import NNlib @traitdef IsTable{X} @traitimpl IsTable{X} <- Tables.istable(X) - -if VERSION < v"1.9.0-DEV.1163" - import Compat: stack -else - import Base: stack -end + +using Compat: stack + include("observation.jl") export numobs, getobs,