Skip to content

Commit

Permalink
Loosen signature for repeat()
Browse files Browse the repository at this point in the history
Accept any AbstractArray as first argument, and any iterable
for inner and outer arguments. Use tuples by default rather
than arrays, and allow passing an empty collection to mean
no-op.
  • Loading branch information
nalimilan committed May 19, 2016
1 parent 807ec46 commit 93966d5
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 52 deletions.
57 changes: 44 additions & 13 deletions base/abstractarraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,31 +200,62 @@ function repmat(a::AbstractVector, m::Int)
return b
end

# Generalized repmat
function repeat{T}(A::AbstractArray{T};
inner::Array{Int} = ones(Int, ndims(A)),
outer::Array{Int} = ones(Int, ndims(A)))
"""
repeat(A::AbstractArray; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
Construct an array by repeating the entries of `A`. The i-th element of `inner` specifies
the number of times that the individual entries of the i-th dimension of `A` should be
repeated. The i-th element of `outer` specifies the number of times that a slice along the
i-th dimension of `A` should be repeated. If `inner` or `outer` are omitted, no repetition
is performed.
```jldoctest
julia> repeat(1:2, inner=2)
4-element Array{Int64,1}:
1
1
2
2
julia> repeat(1:2, outer=2)
4-element Array{Int64,1}:
1
2
1
2
julia> repeat([1 2; 3 4], inner=(2, 1), outer=(1, 3))
4×6 Array{Int64,2}:
1 2 1 2 1 2
1 2 1 2 1 2
3 4 3 4 3 4
3 4 3 4 3 4
```
"""
function repeat(A::AbstractArray;
inner=ntuple(x->1, ndims(A)),
outer=ntuple(x->1, ndims(A)))
ndims_in = ndims(A)
length_inner = length(inner)
length_outer = length(outer)
ndims_out = max(ndims_in, length_inner, length_outer)

if length_inner < ndims_in || length_outer < ndims_in
throw(ArgumentError("inner/outer repetitions must be set for all input dimensions"))
end
length_inner >= ndims_in || throw(ArgumentError("number of inner repetitions ($(length(inner))) cannot be less than number of dimensions of input ($(ndims(A)))"))
length_outer >= ndims_in || throw(ArgumentError("number of outer repetitions ($(length(outer))) cannot be less than number of dimensions of input ($(ndims(A)))"))

ndims_out = max(ndims_in, length_inner, length_outer)

inner = vcat(inner, ones(Int,ndims_out-length_inner))
outer = vcat(outer, ones(Int,ndims_out-length_outer))
inner = vcat(collect(inner), ones(Int,ndims_out-length_inner))
outer = vcat(collect(outer), ones(Int,ndims_out-length_outer))

size_in = size(A)
size_out = ntuple(i->inner[i]*size(A,i)*outer[i],ndims_out)::Dims
inner_size_out = ntuple(i->inner[i]*size(A,i),ndims_out)::Dims

indices_in = Array(Int, ndims_in)
indices_out = Array(Int, ndims_out)
indices_in = Vector{Int}(ndims_in)
indices_out = Vector{Int}(ndims_out)

length_out = prod(size_out)
R = Array(T, size_out)
R = similar(A, size_out)

for index_out in 1:length_out
ind2sub!(indices_out, size_out, index_out)
Expand Down
10 changes: 0 additions & 10 deletions base/docs/helpdb/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4300,16 +4300,6 @@ Return an iterator over all keys in a collection. `collect(keys(d))` returns an
"""
keys

"""
repeat(A, inner = Int[], outer = Int[])
Construct an array by repeating the entries of `A`. The i-th element of `inner` specifies
the number of times that the individual entries of the i-th dimension of `A` should be
repeated. The i-th element of `outer` specifies the number of times that a slice along the
i-th dimension of `A` should be repeated.
"""
repeat

"""
ReentrantLock()
Expand Down
27 changes: 25 additions & 2 deletions doc/stdlib/linalg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -979,11 +979,34 @@ Linear algebra functions in Julia are largely implemented by calling functions f
Construct a matrix by repeating the given matrix ``n`` times in dimension 1 and ``m`` times in dimension 2.

.. function:: repeat(A, inner = Int[], outer = Int[])
.. function:: repeat(A::AbstractArray; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))

.. Docstring generated from Julia source
Construct an array by repeating the entries of ``A``\ . The i-th element of ``inner`` specifies the number of times that the individual entries of the i-th dimension of ``A`` should be repeated. The i-th element of ``outer`` specifies the number of times that a slice along the i-th dimension of ``A`` should be repeated.
Construct an array by repeating the entries of ``A``\ . The i-th element of ``inner`` specifies the number of times that the individual entries of the i-th dimension of ``A`` should be repeated. The i-th element of ``outer`` specifies the number of times that a slice along the i-th dimension of ``A`` should be repeated. If ``inner`` or ``outer`` are omitted, no repetition is performed.

.. doctest::

julia> repeat(1:2, inner=2)
4-element Array{Int64,1}:
1
1
2
2

julia> repeat(1:2, outer=2)
4-element Array{Int64,1}:
1
2
1
2

julia> repeat([1 2; 3 4], inner=(2, 1), outer=(1, 3))
4×6 Array{Int64,2}:
1 2 1 2 1 2
1 2 1 2 1 2
3 4 3 4 3 4
3 4 3 4 3 4

.. function:: kron(A, B)

Expand Down
108 changes: 81 additions & 27 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -502,53 +502,86 @@ let
@test isequal(cumsum(A,2),A2)
@test isequal(cumsum(A,3),A3)

R = repeat([1, 2], inner = [1], outer = [1])
R = repeat([1, 2])
@test R == [1, 2]
R = repeat([1, 2], inner = [2], outer = [1])
R = repeat([1, 2], inner=1)
@test R == [1, 2]
R = repeat([1, 2], outer=1)
@test R == [1, 2]
R = repeat([1, 2], inner=(1,))
@test R == [1, 2]
R = repeat([1, 2], outer=(1,))
@test R == [1, 2]
R = repeat([1, 2], inner=[1])
@test R == [1, 2]
R = repeat([1, 2], outer=[1])
@test R == [1, 2]
R = repeat([1, 2], inner=1, outer=1)
@test R == [1, 2]
R = repeat([1, 2], inner=(1,), outer=(1,))
@test R == [1, 2]
R = repeat([1, 2], inner=[1], outer=[1])
@test R == [1, 2]

R = repeat([1, 2], inner=2)
@test R == [1, 1, 2, 2]
R = repeat([1, 2], outer=2)
@test R == [1, 2, 1, 2]
R = repeat([1, 2], inner=(2,))
@test R == [1, 1, 2, 2]
R = repeat([1, 2], inner = [1], outer = [2])
R = repeat([1, 2], outer=(2,))
@test R == [1, 2, 1, 2]
R = repeat([1, 2], inner = [2], outer = [2])
R = repeat([1, 2], inner=[2])
@test R == [1, 1, 2, 2]
R = repeat([1, 2], outer=[2])
@test R == [1, 2, 1, 2]

R = repeat([1, 2], inner=2, outer=2)
@test R == [1, 1, 2, 2, 1, 1, 2, 2]
R = repeat([1, 2], inner = [1, 1], outer = [1, 1])
R = repeat([1, 2], inner=(2,), outer=(2,))
@test R == [1, 1, 2, 2, 1, 1, 2, 2]
R = repeat([1, 2], inner=[2], outer=[2])
@test R == [1, 1, 2, 2, 1, 1, 2, 2]

R = repeat([1, 2], inner = (1, 1), outer = (1, 1))
@test R == [1, 2]''
R = repeat([1, 2], inner = [2, 1], outer = [1, 1])
R = repeat([1, 2], inner = (2, 1), outer = (1, 1))
@test R == [1, 1, 2, 2]''
R = repeat([1, 2], inner = [1, 2], outer = [1, 1])
R = repeat([1, 2], inner = (1, 2), outer = (1, 1))
@test R == [1 1; 2 2]
R = repeat([1, 2], inner = [1, 1], outer = [2, 1])
R = repeat([1, 2], inner = (1, 1), outer = (2, 1))
@test R == [1, 2, 1, 2]''
R = repeat([1, 2], inner = [1, 1], outer = [1, 2])
R = repeat([1, 2], inner = (1, 1), outer = (1, 2))
@test R == [1 1; 2 2]

R = repeat([1 2;
3 4], inner = [1, 1], outer = [1, 1])
3 4], inner = (1, 1), outer = (1, 1))
@test R == [1 2;
3 4]
R = repeat([1 2;
3 4], inner = [1, 1], outer = [2, 1])
3 4], inner = (1, 1), outer = (2, 1))
@test R == [1 2;
3 4;
1 2;
3 4]
R = repeat([1 2;
3 4], inner = [1, 1], outer = [1, 2])
3 4], inner = (1, 1), outer = (1, 2))
@test R == [1 2 1 2;
3 4 3 4]
R = repeat([1 2;
3 4], inner = [1, 1], outer = [2, 2])
3 4], inner = (1, 1), outer = (2, 2))
@test R == [1 2 1 2;
3 4 3 4;
1 2 1 2;
3 4 3 4]
R = repeat([1 2;
3 4], inner = [2, 1], outer = [1, 1])
3 4], inner = (2, 1), outer = (1, 1))
@test R == [1 2;
1 2;
3 4;
3 4]
R = repeat([1 2;
3 4], inner = [2, 1], outer = [2, 1])
3 4], inner = (2, 1), outer = (2, 1))
@test R == [1 2;
1 2;
3 4;
Expand All @@ -558,13 +591,13 @@ let
3 4;
3 4]
R = repeat([1 2;
3 4], inner = [2, 1], outer = [1, 2])
3 4], inner = (2, 1), outer = (1, 2))
@test R == [1 2 1 2;
1 2 1 2;
3 4 3 4;
3 4 3 4;]
R = repeat([1 2;
3 4], inner = [2, 1], outer = [2, 2])
3 4], inner = (2, 1), outer = (2, 2))
@test R == [1 2 1 2;
1 2 1 2;
3 4 3 4;
Expand All @@ -574,33 +607,33 @@ let
3 4 3 4;
3 4 3 4]
R = repeat([1 2;
3 4], inner = [1, 2], outer = [1, 1])
3 4], inner = (1, 2), outer = (1, 1))
@test R == [1 1 2 2;
3 3 4 4]
R = repeat([1 2;
3 4], inner = [1, 2], outer = [2, 1])
3 4], inner = (1, 2), outer = (2, 1))
@test R == [1 1 2 2;
3 3 4 4;
1 1 2 2;
3 3 4 4]
R = repeat([1 2;
3 4], inner = [1, 2], outer = [1, 2])
3 4], inner = (1, 2), outer = (1, 2))
@test R == [1 1 2 2 1 1 2 2;
3 3 4 4 3 3 4 4]
R = repeat([1 2;
3 4], inner = [1, 2], outer = [2, 2])
3 4], inner = (1, 2), outer = (2, 2))
@test R == [1 1 2 2 1 1 2 2;
3 3 4 4 3 3 4 4;
1 1 2 2 1 1 2 2;
3 3 4 4 3 3 4 4]
R = repeat([1 2;
3 4], inner = [2, 2], outer = [1, 1])
3 4], inner = (2, 2), outer = [1, 1])
@test R == [1 1 2 2;
1 1 2 2;
3 3 4 4;
3 3 4 4]
R = repeat([1 2;
3 4], inner = [2, 2], outer = [2, 1])
3 4], inner = (2, 2), outer = (2, 1))
@test R == [1 1 2 2;
1 1 2 2;
3 3 4 4;
Expand All @@ -610,13 +643,13 @@ let
3 3 4 4;
3 3 4 4]
R = repeat([1 2;
3 4], inner = [2, 2], outer = [1, 2])
3 4], inner = (2, 2), outer = (1, 2))
@test R == [1 1 2 2 1 1 2 2;
1 1 2 2 1 1 2 2;
3 3 4 4 3 3 4 4;
3 3 4 4 3 3 4 4]
R = repeat([1 2;
3 4], inner = [2, 2], outer = [2, 2])
3 4], inner = (2, 2), outer = (2, 2))
@test R == [1 1 2 2 1 1 2 2;
1 1 2 2 1 1 2 2;
3 3 4 4 3 3 4 4;
Expand All @@ -625,17 +658,25 @@ let
1 1 2 2 1 1 2 2;
3 3 4 4 3 3 4 4;
3 3 4 4 3 3 4 4]
@test_throws ArgumentError repeat([1 2;
3 4], inner=2, outer=(2, 2))
@test_throws ArgumentError repeat([1 2;
3 4], inner=(2, 2), outer=2)
@test_throws ArgumentError repeat([1 2;
3 4], inner=(2,), outer=(2, 2))
@test_throws ArgumentError repeat([1 2;
3 4], inner=(2, 2), outer=(2,))

A = reshape(1:8, 2, 2, 2)
R = repeat(A, inner = [1, 1, 2], outer = [1, 1, 1])
R = repeat(A, inner = (1, 1, 2), outer = (1, 1, 1))
T = reshape([1:4; 1:4; 5:8; 5:8], 2, 2, 4)
@test R == T
A = Array(Int, 2, 2, 2)
A[:, :, 1] = [1 2;
3 4]
A[:, :, 2] = [5 6;
7 8]
R = repeat(A, inner = [2, 2, 2], outer = [2, 2, 2])
R = repeat(A, inner = (2, 2, 2), outer = (2, 2, 2))
@test R[1, 1, 1] == 1
@test R[2, 2, 2] == 1
@test R[3, 3, 3] == 8
Expand All @@ -645,6 +686,19 @@ let
@test R[7, 7, 7] == 8
@test R[8, 8, 8] == 8

R = repeat(1:2)
@test R == [1, 2]
R = repeat(1:2, inner=1)
@test R == [1, 2]
R = repeat(1:2, inner=2)
@test R == [1, 1, 2, 2]
R = repeat(1:2, outer=1)
@test R == [1, 2]
R = repeat(1:2, outer=2)
@test R == [1, 2, 1, 2]
R = repeat(1:2, inner=(3,), outer=(2,))
@test R == [1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2]

A = rand(4,4)
for s in Any[A[1:2:4, 1:2:4], sub(A, 1:2:4, 1:2:4)]
c = cumsum(s, 1)
Expand Down

0 comments on commit 93966d5

Please sign in to comment.