diff --git a/Project.toml b/Project.toml index ca42c9b..1b61a18 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "MLUtils" uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" authors = ["Carlo Lucibello and contributors"] -version = "0.3.1" +version = "0.4.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" FLoops = "cc61a311-1640-44b5-9fba-1b764f453329" @@ -20,6 +21,7 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" [compat] ChainRulesCore = "1.0" +Compat = "4.2" DataAPI = "1.0" DelimitedFiles = "1.0" FLoops = "0.2" diff --git a/src/MLUtils.jl b/src/MLUtils.jl index c4d0dec..8261ab9 100644 --- a/src/MLUtils.jl +++ b/src/MLUtils.jl @@ -21,7 +21,8 @@ import NNlib @traitdef IsTable{X} @traitimpl IsTable{X} <- Tables.istable(X) - + +using Compat: stack include("observation.jl") export numobs, @@ -75,7 +76,7 @@ export batch, rand_like, randn_like, rpad_constant, - stack, + stack, # in Base since julia v1.9 unbatch, unsqueeze, unstack, diff --git a/src/deprecations.jl b/src/deprecations.jl index 797a14f..224602f 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -1,5 +1,4 @@ # Deprecated in v0.2 -@deprecate stack(x, dims) stack(x; dims=dims) @deprecate unstack(x, dims) unstack(x; dims=dims) @deprecate unsqueeze(x::AbstractArray, dims::Int) unsqueeze(x; dims=dims) @deprecate unsqueeze(dims::Int) unsqueeze(dims=dims) diff --git a/src/utils.jl b/src/utils.jl index fd254b7..6df0b6d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,6 +5,7 @@ Return `x` reshaped into an array one dimensionality higher than `x`, where `dims` indicates in which dimension `x` is extended. +`dims` can be an integer between 1 and `ndims(x)+1`. See also [`flatten`](@ref), [`stack`](@ref). @@ -33,8 +34,9 @@ julia> unsqueeze(xs, dims=1) [1, 2] [3, 4] [5, 6] ``` """ -function unsqueeze(x::AbstractArray; dims::Int) - sz = ntuple(i -> i < dims ? size(x, i) : i == dims ? 1 : size(x, i - 1), ndims(x) + 1) +function unsqueeze(x::AbstractArray{T,N}; dims::Int) where {T, N} + @assert 1 <= dims <= N + 1 + sz = ntuple(i -> i < dims ? size(x, i) : i == dims ? 1 : size(x, i - 1), N + 1) return reshape(x, sz) end @@ -55,51 +57,6 @@ _unsqueeze(x, dims) = unsqueeze(x; dims) Base.show_function(io::IO, u::Base.Fix2{typeof(_unsqueeze)}, ::Bool) = print(io, "unsqueeze(dims=", u.x, ")") -""" - stack(xs; dims) - -Concatenate the given array of arrays `xs` into a single array along the -given dimension `dims`. - -See also [`stack`](@ref) and [`batch`](@ref). - -# Examples - -```jldoctest -julia> xs = [[1, 2], [3, 4], [5, 6]] -3-element Vector{Vector{Int64}}: - [1, 2] - [3, 4] - [5, 6] - -julia> stack(xs, dims=1) -3×2 Matrix{Int64}: - 1 2 - 3 4 - 5 6 - -julia> stack(xs, dims=2) -2×3 Matrix{Int64}: - 1 3 5 - 2 4 6 - -julia> stack(xs, dims=3) -2×1×3 Array{Int64, 3}: -[:, :, 1] = - 1 - 2 - -[:, :, 2] = - 3 - 4 - -[:, :, 3] = - 5 - 6 -``` -""" -stack(xs; dims::Int) = cat(unsqueeze.(xs; dims)...; dims) - """ unstack(xs; dims) @@ -329,17 +286,7 @@ end batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i) -function batch(xs::AbstractArray{<:AbstractArray}) - # Don't use stack(xs, dims=N+1), it is much slower. - # Here we do reduce(vcat, xs) along with some reshapes. - szxs = size(xs) - @assert length(xs) > 0 "Minimum batch size is 1." - szx = size(xs[1]) - @assert all(x -> size(x) == szx, xs) "All arrays must be of the same size." - vxs = vec(vec.(xs)) - y = reduce(vcat, vxs) - return reshape(y, szx..., szxs...) -end +batch(xs::AbstractArray{<:AbstractArray}) = stack(xs) function batch(xs::Vector{<:Tuple}) @assert length(xs) > 0 "Input should be non-empty" diff --git a/test/test_utils.jl b/test/test_utils.jl index 669ad96..7f334f3 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -1,4 +1,17 @@ +""" +Test gradients through zygote. + +# Arguments + +- `f`: function to test +- `xs`: inputs to `f` + +# Keyword Arguments +Keyword arguments are passed to `rrule`. + +- `fkwargs`: keyword arguments to `f` +""" function test_zygote(f, xs...; kws...) config = ZygoteRuleConfig() test_rrule(config, f, xs...; kws..., rrule_f = rrule_via_ad) diff --git a/test/utils.jl b/test/utils.jl index d5b80bb..a8418c6 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -6,19 +6,30 @@ @test @inferred(unsqueeze(x; dims=4)) == reshape(x, 2, 3, 2, 1) @test unsqueeze(dims=2)(x) == unsqueeze(x, dims=2) + + @test_throws AssertionError unsqueeze(rand(2,2), dims=4) end @testset "stack and unstack" begin x = randn(3,3) stacked = stack([x, x], dims=2) @test size(stacked) == (3,2,3) - @test_broken @inferred(stack([x, x], dims=2)) == stacked + @test @inferred(stack([x, x], dims=2)) == stacked stacked_array=[ 8 9 3 5; 9 6 6 9; 9 1 7 2; 7 4 10 6 ] unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]] @test unstack(stacked_array, dims=2) == unstacked_array @test stack(unstacked_array, dims=2) == stacked_array @test stack(unstack(stacked_array, dims=1), dims=1) == stacked_array + + for d in (1,2,3) + test_zygote(stack, [x,2x], fkwargs=(; dims=d), check_inferred=false) + end + + # Issue #121 + a = [[1] for i in 1:10000] + @test size(stack(a, dims=1)) == (10000, 1) + @test size(stack(a, dims=2)) == (1, 10000) end @testset "batch and unbatch" begin