Skip to content

Commit

Permalink
RFC: Change iteratorsize trait of product(itr1, itr2) (#16437)
Browse files Browse the repository at this point in the history
change iteratorsize trait of `product(itr1, itr2)`

fixes #16436
- Adds many tests to product function and tests more thoroughly the iterator traits
- Adds a Prod1 type
- Adds ndims(::Base.Prod*)
- Change state of Prod1 iterator from tuple to integer
  • Loading branch information
gasagna authored and JeffBezanson committed May 29, 2016
1 parent 2fa728a commit 26a6bfc
Show file tree
Hide file tree
Showing 2 changed files with 221 additions and 17 deletions.
58 changes: 49 additions & 9 deletions base/iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -302,10 +302,47 @@ done(it::Repeated, state) = false
iteratorsize{O}(::Type{Repeated{O}}) = IsInfinite()
iteratoreltype{O}(::Type{Repeated{O}}) = HasEltype()

# product

# Product -- cartesian product of iterators

abstract AbstractProdIterator

length(p::AbstractProdIterator) = prod(size(p))
size(p::AbstractProdIterator) = _prod_size(p.a, p.b, iteratorsize(p.a), iteratorsize(p.b))
ndims(p::AbstractProdIterator) = length(size(p))

# generic methods to handle size of Prod* types
_prod_size(a, ::HasShape) = size(a)
_prod_size(a, ::HasLength) = (length(a), )
_prod_size(a, A) =
throw(ArgumentError("Cannot compute size for object of type $(typeof(a))"))
_prod_size(a, b, ::HasLength, ::HasLength) = (length(a), length(b))
_prod_size(a, b, ::HasLength, ::HasShape) = (length(a), size(b)...)
_prod_size(a, b, ::HasShape, ::HasLength) = (size(a)..., length(b))
_prod_size(a, b, ::HasShape, ::HasShape) = (size(a)..., size(b)...)
_prod_size(a, b, A, B) =
throw(ArgumentError("Cannot construct size for objects of types $(typeof(a)) and $(typeof(b))"))

# one iterator
immutable Prod1{I} <: AbstractProdIterator
a::I
end
product(a) = Prod1(a)

eltype{I}(::Type{Prod1{I}}) = Tuple{eltype(I)}
size(p::Prod1) = _prod_size(p.a, iteratorsize(p.a))

@inline start(p::Prod1) = start(p.a)
@inline function next(p::Prod1, st)
n, st = next(p.a, st)
(n, ), st
end
@inline done(p::Prod1, st) = done(p.a, st)

iteratoreltype{I}(::Type{Prod1{I}}) = iteratoreltype(I)
iteratorsize{I}(::Type{Prod1{I}}) = iteratorsize(I)

# two iterators
immutable Prod2{I1, I2} <: AbstractProdIterator
a::I1
b::I2
Expand All @@ -327,11 +364,11 @@ changes the fastest. Example:
(1,5)
(2,5)
"""
product(a) = Zip1(a)
product(a, b) = Prod2(a, b)

eltype{I1,I2}(::Type{Prod2{I1,I2}}) = Tuple{eltype(I1), eltype(I2)}

iteratoreltype{I1,I2}(::Type{Prod2{I1,I2}}) = and_iteratoreltype(iteratoreltype(I1),iteratoreltype(I2))
length(p::AbstractProdIterator) = length(p.a)*length(p.b)
iteratorsize{I1,I2}(::Type{Prod2{I1,I2}}) = prod_iteratorsize(iteratorsize(I1),iteratorsize(I2))

function start(p::AbstractProdIterator)
Expand Down Expand Up @@ -359,13 +396,15 @@ end
@inline next(p::Prod2, st) = prod_next(p, st)
@inline done(p::AbstractProdIterator, st) = st[4]

# n iterators
immutable Prod{I1, I2<:AbstractProdIterator} <: AbstractProdIterator
a::I1
b::I2
end

product(a, b, c...) = Prod(a, product(b, c...))

eltype{I1,I2}(::Type{Prod{I1,I2}}) = tuple_type_cons(eltype(I1), eltype(I2))

iteratoreltype{I1,I2}(::Type{Prod{I1,I2}}) = and_iteratoreltype(iteratoreltype(I1),iteratoreltype(I2))
iteratorsize{I1,I2}(::Type{Prod{I1,I2}}) = prod_iteratorsize(iteratorsize(I1),iteratorsize(I2))

Expand All @@ -374,12 +413,13 @@ iteratorsize{I1,I2}(::Type{Prod{I1,I2}}) = prod_iteratorsize(iteratorsize(I1),it
((x[1][1],x[1][2]...), x[2])
end

prod_iteratorsize(::Union{HasLength,HasShape}, ::Union{HasLength,HasShape}) = HasLength()
prod_iteratorsize(a, ::IsInfinite) = IsInfinite() # products can have an infinite last iterator (which moves slowest)
prod_iteratorsize(::Union{HasLength,HasShape}, ::Union{HasLength,HasShape}) = HasShape()
# products can have an infinite iterator
prod_iteratorsize(::IsInfinite, ::IsInfinite) = IsInfinite()
prod_iteratorsize(a, ::IsInfinite) = IsInfinite()
prod_iteratorsize(::IsInfinite, b) = IsInfinite()
prod_iteratorsize(a, b) = SizeUnknown()

_size(p::Prod2) = (length(p.a), length(p.b))
_size(p::Prod) = (length(p.a), _size(p.b)...)

"""
IteratorND(iter, dims)
Expand All @@ -400,7 +440,7 @@ immutable IteratorND{I,N}
end
new{I,N}(iter, shape)
end
(::Type{IteratorND}){I<:AbstractProdIterator}(p::I) = IteratorND(p, _size(p))
(::Type{IteratorND}){I<:AbstractProdIterator}(p::I) = IteratorND(p, size(p))
end

start(i::IteratorND) = start(i.iter)
Expand Down
180 changes: 172 additions & 8 deletions test/functional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,178 @@ end
# product
# -------

@test isempty(Base.product(1:2,1:0))
@test isempty(Base.product(1:2,1:0,1:10))
@test isempty(Base.product(1:2,1:10,1:0))
@test isempty(Base.product(1:0,1:2,1:10))
@test collect(Base.product(1:2,3:4)) == [(1,3),(2,3),(1,4),(2,4)]
@test isempty(collect(Base.product(1:0,1:2)))
@test length(Base.product(1:2,1:10,4:6)) == 60
@test Base.iteratorsize(Base.product(1:2, countfrom(1))) == Base.IsInfinite()
# empty?
for itr in [Base.product(1:0),
Base.product(1:2, 1:0),
Base.product(1:0, 1:2),
Base.product(1:0, 1:1, 1:2),
Base.product(1:1, 1:0, 1:2),
Base.product(1:1, 1:2 ,1:0)]
@test isempty(itr)
@test isempty(collect(itr))
end

# collect a product - first iterators runs faster
@test collect(Base.product(1:2)) == [(i,) for i=1:2]
@test collect(Base.product(1:2, 3:4)) == [(i, j) for i=1:2, j=3:4]
@test collect(Base.product(1:2, 3:4, 5:6)) == [(i, j, k) for i=1:2, j=3:4, k=5:6]

# iteration order
let
expected = [(1,3,5), (2,3,5), (1,4,5), (2,4,5), (1,3,6), (2,3,6), (1,4,6), (2,4,6)]
actual = Base.product(1:2, 3:4, 5:6)
for (exp, act) in zip(expected, actual)
@test exp == act
end
end

# collect multidimensional array
let
a, b = 1:3, [4 6;
5 7]
p = Base.product(a, b)
@test size(p) == (3, 2, 2)
@test length(p) == 12
@test ndims(p) == 3
@test eltype(p) == NTuple{2, Int}
cp = collect(p)
for i = 1:3
@test cp[i, :, :] == [(i, 4) (i, 6);
(i, 5) (i, 7)]
end
end

# with 1D inputs
let
a, b, c = 1:2, 1.0:10.0, Int32(1):Int32(0)

# length
@test length(Base.product(a)) == 2
@test length(Base.product(a, b)) == 20
@test length(Base.product(a, b, c)) == 0

# size
@test size(Base.product(a)) == (2, )
@test size(Base.product(a, b)) == (2, 10)
@test size(Base.product(a, b, c)) == (2, 10, 0)

# eltype
@test eltype(Base.product(a)) == Tuple{Int}
@test eltype(Base.product(a, b)) == Tuple{Int, Float64}
@test eltype(Base.product(a, b, c)) == Tuple{Int, Float64, Int32}

# ndims
@test ndims(Base.product(a)) == 1
@test ndims(Base.product(a, b)) == 2
@test ndims(Base.product(a, b, c)) == 3
end

# with multidimensional inputs
let
a, b, c = randn(4, 4), randn(3, 3, 3), randn(2, 2, 2, 2)
args = Any[(a,),
(a, a),
(a, b),
(a, a, a),
(a, b, c)]
sizes = Any[(4, 4),
(4, 4, 4, 4),
(4, 4, 3, 3, 3),
(4, 4, 4, 4, 4, 4),
(4, 4, 3, 3, 3, 2, 2, 2, 2)]
for (method, fun) in zip([size, ndims, length], [x->x, length, prod])
for i in 1:length(args)
@test method(Base.product(args[i]...)) == method(collect(Base.product(args[i]...))) == fun(sizes[i])
end
end
end

# more tests on product with iterators of various type
let
iters = (1:2,
rand(2, 2, 2),
take(1:4, 2),
Base.product(1:2, 1:3),
Base.product(rand(2, 2), rand(1, 1, 1))
)
for method in [size, length, ndims, eltype]
for i = 1:length(iters)
args = iters[i]
@test method(Base.product(args...)) == method(collect(Base.product(args...)))
for j = 1:length(iters)
args = iters[i], iters[j]
@test method(Base.product(args...)) == method(collect(Base.product(args...)))
for k = 1:length(iters)
args = iters[i], iters[j], iters[k]
@test method(Base.product(args...)) == method(collect(Base.product(args...)))
end
end
end
end
end

# product of finite length and infinite length iterators
let
a = 1:2
b = countfrom(1)
ab = Base.product(a, b)
ba = Base.product(b, a)
abexp = [(1, 1), (2, 1), (1, 2), (2, 2), (1, 3), (2, 3)]
baexp = [(1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1)]
for (expected, actual) in zip([abexp, baexp], [ab, ba])
for (i, el) in enumerate(actual)
@test el == expected[i]
i == length(expected) && break
end
@test_throws ArgumentError length(actual)
@test_throws ArgumentError size(actual)
@test_throws ArgumentError ndims(actual)
end

# size infinite or unknown raises an error
for itr in Any[countfrom(1), Filter(i->0, 1:10)]
@test_throws ArgumentError length(Base.product(itr))
@test_throws ArgumentError size(Base.product(itr))
@test_throws ArgumentError ndims(Base.product(itr))
end
end

# iteratorsize trait business
let f1 = Filter(i->i>0, 1:10)
@test Base.iteratorsize(Base.product(f1)) == Base.SizeUnknown()
@test Base.iteratorsize(Base.product(1:2, f1)) == Base.SizeUnknown()
@test Base.iteratorsize(Base.product(f1, 1:2)) == Base.SizeUnknown()
@test Base.iteratorsize(Base.product(f1, f1)) == Base.SizeUnknown()
@test Base.iteratorsize(Base.product(f1, countfrom(1))) == Base.IsInfinite()
@test Base.iteratorsize(Base.product(countfrom(1), f1)) == Base.IsInfinite()
end
@test Base.iteratorsize(Base.product(1:2, countfrom(1))) == Base.IsInfinite()
@test Base.iteratorsize(Base.product(countfrom(2), countfrom(1))) == Base.IsInfinite()
@test Base.iteratorsize(Base.product(countfrom(1), 1:2)) == Base.IsInfinite()
@test Base.iteratorsize(Base.product(1:2)) == Base.HasShape()
@test Base.iteratorsize(Base.product(1:2, 1:2)) == Base.HasShape()
@test Base.iteratorsize(Base.product(take(1:2, 1), take(1:2, 1))) == Base.HasShape()
@test Base.iteratorsize(Base.product(take(1:2, 2))) == Base.HasLength()
@test Base.iteratorsize(Base.product([1 2; 3 4])) == Base.HasShape()

# iteratoreltype trait business
let f1 = Filter(i->i>0, 1:10)
@test Base.iteratoreltype(Base.product(f1)) == Base.HasEltype() # FIXME? eltype(f1) is Any
@test Base.iteratoreltype(Base.product(1:2, f1)) == Base.HasEltype() # FIXME? eltype(f1) is Any
@test Base.iteratoreltype(Base.product(f1, 1:2)) == Base.HasEltype() # FIXME? eltype(f1) is Any
@test Base.iteratoreltype(Base.product(f1, f1)) == Base.HasEltype() # FIXME? eltype(f1) is Any
@test Base.iteratoreltype(Base.product(f1, countfrom(1))) == Base.HasEltype() # FIXME? eltype(f1) is Any
@test Base.iteratoreltype(Base.product(countfrom(1), f1)) == Base.HasEltype() # FIXME? eltype(f1) is Any
end
@test Base.iteratoreltype(Base.product(1:2, countfrom(1))) == Base.HasEltype()
@test Base.iteratoreltype(Base.product(countfrom(1), 1:2)) == Base.HasEltype()
@test Base.iteratoreltype(Base.product(1:2)) == Base.HasEltype()
@test Base.iteratoreltype(Base.product(1:2, 1:2)) == Base.HasEltype()
@test Base.iteratoreltype(Base.product(take(1:2, 1), take(1:2, 1))) == Base.HasEltype()
@test Base.iteratoreltype(Base.product(take(1:2, 2))) == Base.HasEltype()
@test Base.iteratoreltype(Base.product([1 2; 3 4])) == Base.HasEltype()



# flatten
# -------
Expand Down

0 comments on commit 26a6bfc

Please sign in to comment.