From 531174ad036ca8fde9e4afa7e92235bb68861064 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 11 Nov 2022 07:07:10 +0100 Subject: [PATCH] use stack in batch --- src/utils.jl | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index e4348c1..6df0b6d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -286,17 +286,7 @@ end batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i) -function batch(xs::AbstractArray{<:AbstractArray}) - # Don't use stack(xs, dims=N+1), it is much slower. - # Here we do reduce(vcat, xs) along with some reshapes. - szxs = size(xs) - @assert length(xs) > 0 "Minimum batch size is 1." - szx = size(xs[1]) - @assert all(x -> size(x) == szx, xs) "All arrays must be of the same size." - vxs = vec(vec.(xs)) - y = reduce(vcat, vxs) - return reshape(y, szx..., szxs...) -end +batch(xs::AbstractArray{<:AbstractArray}) = stack(xs) function batch(xs::Vector{<:Tuple}) @assert length(xs) > 0 "Input should be non-empty"