Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improvements to stack #125

Merged
merged 6 commits into from
Nov 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "MLUtils"
uuid = "f1d291b0-491e-4a28-83b9-f70985020b54"
authors = ["Carlo Lucibello <carlo.lucibello@gmail.com> and contributors"]
version = "0.3.1"
version = "0.4.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should probably bound Compat = "4.2". (Can't bound ChainRules but that's ok.)

DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
Expand All @@ -20,6 +21,7 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"

[compat]
ChainRulesCore = "1.0"
Compat = "4.2"
DataAPI = "1.0"
DelimitedFiles = "1.0"
FLoops = "0.2"
Expand Down
5 changes: 3 additions & 2 deletions src/MLUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ import NNlib

@traitdef IsTable{X}
@traitimpl IsTable{X} <- Tables.istable(X)


using Compat: stack

include("observation.jl")
export numobs,
Expand Down Expand Up @@ -75,7 +76,7 @@ export batch,
rand_like,
randn_like,
rpad_constant,
stack,
stack, # in Base since julia v1.9
unbatch,
unsqueeze,
unstack,
Expand Down
1 change: 0 additions & 1 deletion src/deprecations.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Deprecated in v0.2
@deprecate stack(x, dims) stack(x; dims=dims)
@deprecate unstack(x, dims) unstack(x; dims=dims)
@deprecate unsqueeze(x::AbstractArray, dims::Int) unsqueeze(x; dims=dims)
@deprecate unsqueeze(dims::Int) unsqueeze(dims=dims)
Expand Down
63 changes: 5 additions & 58 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

Return `x` reshaped into an array one dimensionality higher than `x`,
where `dims` indicates in which dimension `x` is extended.
`dims` can be an integer between 1 and `ndims(x)+1`.

See also [`flatten`](@ref), [`stack`](@ref).

Expand Down Expand Up @@ -33,8 +34,9 @@ julia> unsqueeze(xs, dims=1)
[1, 2] [3, 4] [5, 6]
```
"""
function unsqueeze(x::AbstractArray; dims::Int)
sz = ntuple(i -> i < dims ? size(x, i) : i == dims ? 1 : size(x, i - 1), ndims(x) + 1)
function unsqueeze(x::AbstractArray{T,N}; dims::Int) where {T, N}
@assert 1 <= dims <= N + 1
sz = ntuple(i -> i < dims ? size(x, i) : i == dims ? 1 : size(x, i - 1), N + 1)
return reshape(x, sz)
end

Expand All @@ -55,51 +57,6 @@ _unsqueeze(x, dims) = unsqueeze(x; dims)

Base.show_function(io::IO, u::Base.Fix2{typeof(_unsqueeze)}, ::Bool) = print(io, "unsqueeze(dims=", u.x, ")")

"""
stack(xs; dims)

Concatenate the given array of arrays `xs` into a single array along the
given dimension `dims`.

See also [`stack`](@ref) and [`batch`](@ref).

# Examples

```jldoctest
julia> xs = [[1, 2], [3, 4], [5, 6]]
3-element Vector{Vector{Int64}}:
[1, 2]
[3, 4]
[5, 6]

julia> stack(xs, dims=1)
3×2 Matrix{Int64}:
1 2
3 4
5 6

julia> stack(xs, dims=2)
2×3 Matrix{Int64}:
1 3 5
2 4 6

julia> stack(xs, dims=3)
2×1×3 Array{Int64, 3}:
[:, :, 1] =
1
2

[:, :, 2] =
3
4

[:, :, 3] =
5
6
```
"""
stack(xs; dims::Int) = cat(unsqueeze.(xs; dims)...; dims)

"""
unstack(xs; dims)

Expand Down Expand Up @@ -329,17 +286,7 @@ end

batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i)

function batch(xs::AbstractArray{<:AbstractArray})
# Don't use stack(xs, dims=N+1), it is much slower.
# Here we do reduce(vcat, xs) along with some reshapes.
szxs = size(xs)
@assert length(xs) > 0 "Minimum batch size is 1."
szx = size(xs[1])
@assert all(x -> size(x) == szx, xs) "All arrays must be of the same size."
vxs = vec(vec.(xs))
y = reduce(vcat, vxs)
return reshape(y, szx..., szxs...)
end
batch(xs::AbstractArray{<:AbstractArray}) = stack(xs)

function batch(xs::Vector{<:Tuple})
@assert length(xs) > 0 "Input should be non-empty"
Expand Down
13 changes: 13 additions & 0 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@

"""
Test gradients through zygote.

# Arguments

- `f`: function to test
- `xs`: inputs to `f`

# Keyword Arguments
Keyword arguments are passed to `rrule`.

- `fkwargs`: keyword arguments to `f`
"""
function test_zygote(f, xs...; kws...)
config = ZygoteRuleConfig()
test_rrule(config, f, xs...; kws..., rrule_f = rrule_via_ad)
Expand Down
13 changes: 12 additions & 1 deletion test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,30 @@
@test @inferred(unsqueeze(x; dims=4)) == reshape(x, 2, 3, 2, 1)

@test unsqueeze(dims=2)(x) == unsqueeze(x, dims=2)

@test_throws AssertionError unsqueeze(rand(2,2), dims=4)
end

@testset "stack and unstack" begin
x = randn(3,3)
stacked = stack([x, x], dims=2)
@test size(stacked) == (3,2,3)
@test_broken @inferred(stack([x, x], dims=2)) == stacked
@test @inferred(stack([x, x], dims=2)) == stacked

stacked_array=[ 8 9 3 5; 9 6 6 9; 9 1 7 2; 7 4 10 6 ]
unstacked_array=[[8, 9, 9, 7], [9, 6, 1, 4], [3, 6, 7, 10], [5, 9, 2, 6]]
@test unstack(stacked_array, dims=2) == unstacked_array
@test stack(unstacked_array, dims=2) == stacked_array
@test stack(unstack(stacked_array, dims=1), dims=1) == stacked_array

for d in (1,2,3)
test_zygote(stack, [x,2x], fkwargs=(; dims=d), check_inferred=false)
end

# Issue #121
a = [[1] for i in 1:10000]
@test size(stack(a, dims=1)) == (10000, 1)
@test size(stack(a, dims=2)) == (1, 10000)
end

@testset "batch and unbatch" begin
Expand Down