-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Speed-up repeat for AbstractArrays #20635
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -360,40 +360,68 @@ julia> repeat([1 2; 3 4], inner=(2, 1), outer=(1, 3)) | |
``` | ||
""" | ||
function repeat(A::AbstractArray; | ||
inner=ntuple(x->1, ndims(A)), | ||
outer=ntuple(x->1, ndims(A))) | ||
ndims_in = ndims(A) | ||
length_inner = length(inner) | ||
length_outer = length(outer) | ||
|
||
length_inner >= ndims_in || throw(ArgumentError("number of inner repetitions ($(length(inner))) cannot be less than number of dimensions of input ($(ndims(A)))")) | ||
length_outer >= ndims_in || throw(ArgumentError("number of outer repetitions ($(length(outer))) cannot be less than number of dimensions of input ($(ndims(A)))")) | ||
|
||
ndims_out = max(ndims_in, length_inner, length_outer) | ||
|
||
inner = vcat(collect(inner), ones(Int,ndims_out-length_inner)) | ||
outer = vcat(collect(outer), ones(Int,ndims_out-length_outer)) | ||
|
||
size_in = size(A) | ||
size_out = ntuple(i->inner[i]*size(A,i)*outer[i],ndims_out)::Dims | ||
inner_size_out = ntuple(i->inner[i]*size(A,i),ndims_out)::Dims | ||
|
||
indices_in = Vector{Int}(ndims_in) | ||
indices_out = Vector{Int}(ndims_out) | ||
inner=ntuple(n->1, Val{ndims(A)}), | ||
outer=ntuple(n->1, Val{ndims(A)})) | ||
return _repeat(A, rep_kw2tup(inner), rep_kw2tup(outer)) | ||
end | ||
|
||
length_out = prod(size_out) | ||
R = similar(A, size_out) | ||
# Normalize a repetition specification (as passed for the `inner`/`outer`
# keywords of `repeat`) to a tuple:
#   - an integer `n` becomes the 1-tuple `(n,)`;
#   - an array of integers becomes a tuple with one entry per element;
#   - a tuple is returned unchanged.
rep_kw2tup(n::Integer) = (n,)
# The trailing comma is required to build a tuple from a splat: `(v...)` without
# it is rejected by newer Julia parsers, while `(v...,)` is valid on 0.6 as well.
rep_kw2tup(v::AbstractArray{<:Integer}) = (v...,)
rep_kw2tup(t::Tuple) = t
|
||
# Entry point: compute the shapes needed for repeating `A` with inner
# repetitions `i` and outer repetitions `o`. Starts the `_rshps` recursion with
# two empty accumulators and `size(A)`; the terminal `_rshps` method returns the
# pair `(shp, shp_i)`.
rep_shapes(A, i, o) = _rshps((), (), size(A), i, o) | ||
|
||
# `_rshps(shp, shp_i, sz, i, o)` consumes one leading entry of `sz`, `i` and `o`
# per call (via `tail`) and appends one entry to each accumulator:
#   shp   gets  size * inner * outer  for that dimension,
#   shp_i gets  size * inner          for that dimension.
# A factor whose tuple is already exhausted is simply left out of the product.
#
# Terminal case: all three tuples are exhausted, return the accumulated shapes.
_rshps(shp, shp_i, ::Tuple{}, ::Tuple{}, ::Tuple{}) = (shp, shp_i) | ||
# Sizes and inner repetitions exhausted, outer repetitions remain:
# the dimension has extent `o[1]` in `shp` and extent 1 in `shp_i`.
@inline _rshps(shp, shp_i, ::Tuple{}, ::Tuple{}, o) = | ||
_rshps((shp..., o[1]), (shp_i..., 1), (), (), tail(o)) | ||
# Sizes and outer repetitions exhausted, inner repetitions remain:
# both accumulators get `i[1]`.
@inline _rshps(shp, shp_i, ::Tuple{}, i, ::Tuple{}) = (n = i[1]; | ||
_rshps((shp..., n), (shp_i..., n), (), tail(i), ())) | ||
# Sizes exhausted, both inner and outer repetitions remain.
@inline _rshps(shp, shp_i, ::Tuple{}, i, o) = (n = i[1]; | ||
_rshps((shp..., n * o[1]), (shp_i..., n), (), tail(i), tail(o))) | ||
# General case: a size, an inner and an outer entry are all available.
@inline _rshps(shp, shp_i, sz, i, o) = (n = sz[1] * i[1]; | ||
_rshps((shp..., n * o[1]), (shp_i..., n), tail(sz), tail(i), tail(o))) | ||
# Error cases: entries of `sz` remain but `i` and/or `o` ran out. `n` is the
# number of dimensions processed so far (`length(shp)`) and `N` adds the
# remaining `length(sz)`. When both `i` and `o` are exhausted the error is
# reported for "inner".
_rshps(shp, shp_i, sz, ::Tuple{}, ::Tuple{}) = | ||
(n = length(shp); N = n + length(sz); _reperr("inner", n, N)) | ||
_rshps(shp, shp_i, sz, ::Tuple{}, o) = | ||
(n = length(shp); N = n + length(sz); _reperr("inner", n, N)) | ||
_rshps(shp, shp_i, sz, i, ::Tuple{}) = | ||
(n = length(shp); N = n + length(sz); _reperr("outer", n, N)) | ||
# Throw an `ArgumentError` stating that the number `n` of `s` ("inner"/"outer")
# repetitions cannot be less than the number `N` of input dimensions.
_reperr(s, n, N) = throw(ArgumentError("number of " * s * " repetitions " * | ||
"($n) cannot be less than number of dimensions of input ($N)")) | ||
|
||
@propagate_inbounds function _repeat(A::AbstractArray, inner, outer) | ||
shape, inner_shape = rep_shapes(A, inner, outer) | ||
|
||
R = similar(A, shape) | ||
|
||
# fill the first inner block | ||
if all(x -> x == 1, inner) | ||
R[indices(A)...] = A | ||
else | ||
inner_indices = [1:n for n in inner] | ||
for c in CartesianRange(indices(A)) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Would this be any better as There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I tried that, the problem is that |
||
for i in 1:ndims(A) | ||
n = inner[i] | ||
inner_indices[i] = (1:n) + ((c.I[i] - 1) * n) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
end | ||
R[inner_indices...] = A[c.I...] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. rhs can be just |
||
end | ||
end | ||
|
||
for index_out in 1:length_out | ||
ind2sub!(indices_out, size_out, index_out) | ||
for t in 1:ndims_in | ||
# "Project" outer repetitions into inner repetitions | ||
indices_in[t] = mod1(indices_out[t], inner_size_out[t]) | ||
# Find inner repetitions using flooring division | ||
indices_in[t] = fld1(indices_in[t], inner[t]) | ||
# fill the outer blocks along each dimension | ||
if all(x -> x == 1, outer) | ||
return R | ||
end | ||
src_indices = [1:n for n in inner_shape] | ||
dest_indices = copy(src_indices) | ||
for i in 1:length(outer) | ||
B = view(R, src_indices...) | ||
for j in 2:outer[i] | ||
dest_indices[i] += inner_shape[i] | ||
R[dest_indices...] = B | ||
end | ||
index_in = sub2ind(size_in, indices_in...) | ||
R[index_out] = A[index_in] | ||
src_indices[i] = 1:dest_indices[i][end] | ||
copy!(dest_indices, src_indices) | ||
end | ||
|
||
return R | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Whoever made `ndims` inferable on 0.6, thanks!