Skip to content

Commit

Permalink
use stack in batch
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Nov 11, 2022
1 parent a9e60cb commit 531174a
Showing 1 changed file with 1 addition and 11 deletions.
12 changes: 1 addition & 11 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -286,17 +286,7 @@ end

batchindex(xs, i) = (reverse(Base.tail(reverse(axes(xs))))..., i)

function batch(xs::AbstractArray{<:AbstractArray})
# Don't use stack(xs, dims=N+1), it is much slower.
# Here we do reduce(vcat, xs) along with some reshapes.
szxs = size(xs)
@assert length(xs) > 0 "Minimum batch size is 1."
szx = size(xs[1])
@assert all(x -> size(x) == szx, xs) "All arrays must be of the same size."
vxs = vec(vec.(xs))
y = reduce(vcat, vxs)
return reshape(y, szx..., szxs...)
end
batch(xs::AbstractArray{<:AbstractArray}) = stack(xs)

function batch(xs::Vector{<:Tuple})
@assert length(xs) > 0 "Input should be non-empty"
Expand Down

0 comments on commit 531174a

Please sign in to comment.