Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Change iteratorsize trait of product(itr1, itr2) #16437

Merged
merged 5 commits into from
May 29, 2016
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Improved behaviour for multi-dimensional arguments
Davide Lasagna committed May 21, 2016
commit 5c082ec4a7710a62be1d54fa3c5811d0c90860f8
29 changes: 21 additions & 8 deletions base/iterator.jl
Original file line number Diff line number Diff line change
@@ -302,7 +302,24 @@ repeated(x, n::Int) = take(repeated(x), n)
# Product -- cartesian product of iterators

abstract AbstractProdIterator

length(p::AbstractProdIterator) = prod(size(p))
size(p::AbstractProdIterator) = _prod_size(p.a, p.b, iteratorsize(p.a), iteratorsize(p.b))
ndims(p::AbstractProdIterator) = length(size(p))

# generic methods to handle size of Prod* types
_prod_size(a, ::HasShape) = size(a)
_prod_size(a, ::HasLength) = (length(a), )
_prod_size(a, A) =
throw(ArgumentError("Cannot compute size for object of type $(typeof(a))"))
_prod_size(a, b, ::HasLength, ::HasLength) = (length(a), length(b))
_prod_size(a, b, ::HasLength, ::HasShape) = (length(a), size(b)...)
_prod_size(a, b, ::HasShape, ::HasLength) = (size(a)..., length(b))
_prod_size(a, b, ::HasShape, ::HasShape) = (size(a)..., size(b)...)
_prod_size(a, b, A, ::Union{HasShape, HasLength}) =
throw(ArgumentError("Cannot compute size for object of type $(typeof(a))"))
_prod_size(a, b, ::Union{HasShape, HasLength}, B) =
throw(ArgumentError("Cannot compute size for object of type $(typeof(b))"))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one is ambiguous

# one iterator
immutable Prod1{I} <: AbstractProdIterator
@@ -311,13 +328,12 @@ end
product(a) = Prod1(a)

eltype{I}(::Type{Prod1{I}}) = Tuple{eltype(I)}
size(p::Prod1) = size(p.a)
ndims(p::Prod1) = ndims(p.a)
size(p::Prod1) = _prod_size(p.a, iteratorsize(p.a))

@inline start(p::Prod1) = start(p.a)
@inline function next(p::Prod1, st)
n, st = next(p.a, st)
return (n, ), st
(n, ), st
end
@inline done(p::Prod1, st) = done(p.a, st)

@@ -349,8 +365,6 @@ changes the fastest. Example:
product(a, b) = Prod2(a, b)

eltype{I1,I2}(::Type{Prod2{I1,I2}}) = Tuple{eltype(I1), eltype(I2)}
size(p::Prod2) = (length(p.a), length(p.b))
ndims(p::Prod2) = 2

iteratoreltype{I1,I2}(::Type{Prod2{I1,I2}}) = and_iteratoreltype(iteratoreltype(I1),iteratoreltype(I2))
iteratorsize{I1,I2}(::Type{Prod2{I1,I2}}) = prod_iteratorsize(iteratorsize(I1),iteratorsize(I2))
@@ -388,8 +402,6 @@ end
product(a, b, c...) = Prod(a, product(b, c...))

eltype{I1,I2}(::Type{Prod{I1,I2}}) = tuple_type_cons(eltype(I1), eltype(I2))
size(p::Prod) = (length(p.a), size(p.b)...)
ndims(p::Prod) = length(size(p))

iteratoreltype{I1,I2}(::Type{Prod{I1,I2}}) = and_iteratoreltype(iteratoreltype(I1),iteratoreltype(I2))
iteratorsize{I1,I2}(::Type{Prod{I1,I2}}) = prod_iteratorsize(iteratorsize(I1),iteratorsize(I2))
@@ -400,7 +412,8 @@ iteratorsize{I1,I2}(::Type{Prod{I1,I2}}) = prod_iteratorsize(iteratorsize(I1),it
end

prod_iteratorsize(::Union{HasLength,HasShape}, ::Union{HasLength,HasShape}) = HasShape()
prod_iteratorsize(a, ::IsInfinite) = IsInfinite() # products can have an infinite last iterator (which moves slowest)
# products can have an infinite iterator
prod_iteratorsize(a, ::IsInfinite) = IsInfinite()
prod_iteratorsize(::IsInfinite, b) = IsInfinite()
prod_iteratorsize(a, b) = SizeUnknown()
Copy link
Contributor

@mschauer mschauer May 27, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one too, but this should be removed, because a or b might be zero and 0*infty=0 in terms of taking products.

Copy link
Contributor Author

@gasagna gasagna May 28, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But that would be a special case. How would you handle the majority of cases where you have (finite nonzero size) * (infinite size) ?


157 changes: 105 additions & 52 deletions test/functional.jl
Original file line number Diff line number Diff line change
@@ -194,70 +194,123 @@ end
@test collect(Base.product(1:2, 3:4, 5:6)) == [(i, j, k) for i=1:2, j=3:4, k=5:6]

# iteration order
let expected = [(1,3,5), (2,3,5), (1,4,5), (2,4,5), (1,3,6), (2,3,6), (1,4,6), (2,4,6)]
i = 1
for el in Base.product(1:2, 3:4, 5:6)
@test el == expected[i]
i+=1
let
expected = [(1,3,5), (2,3,5), (1,4,5), (2,4,5), (1,3,6), (2,3,6), (1,4,6), (2,4,6)]
actual = Base.product(1:2, 3:4, 5:6)
for (exp, act) in zip(expected, actual)
@test exp == act
end
end

# is this the correct behaviour?
@test collect(Base.product([1 2; 3 4], [5 6; 7 8])) == [(1,5) (1,7) (1,6) (1,8);
(3,5) (3,7) (3,6) (3,8);
(2,5) (2,7) (2,6) (2,8);
(4,5) (4,7) (4,6) (4,8)]
# collect multidimensional array
let
a, b = 1:3, [4 6;
5 7]
p = Base.product(a, b)
@test size(p) == (3, 2, 2)
@test length(p) == 12
@test ndims(p) == 3
@test eltype(p) == NTuple{2, Int}
cp = collect(p)
for i = 1:3
@test cp[i, :, :] == [(i, 4) (i, 6);
(i, 5) (i, 7)]
end
end

let
a, b, c, d, e = 1:2, 1.0:10.0, 4f0:6f0, 0x01:0x08, Int8(1):Int8(0)
# with 1D inputs
let
a, b, c = 1:2, 1.0:10.0, Int32(1):Int32(0)

# length
@test length(Base.product(a)) == 2
@test length(Base.product(a, b)) == 20
@test length(Base.product(a, b, c)) == 60
@test length(Base.product(a, b, c, d)) == 480
@test length(Base.product(a, b, c, d, e)) == 0
@test length(Base.product(a)) == 2
@test length(Base.product(a, b)) == 20
@test length(Base.product(a, b, c)) == 0

# size
@test size(Base.product(a)) == (2, )
@test size(Base.product(a, b)) == (2, 10)
@test size(Base.product(a, b, c)) == (2, 10, 3)
@test size(Base.product(a, b, c, d)) == (2, 10, 3, 8)
@test size(Base.product(a, b, c, d, e)) == (2, 10, 3, 8, 0)
@test size(Base.product(a)) == (2, )
@test size(Base.product(a, b)) == (2, 10)
@test size(Base.product(a, b, c)) == (2, 10, 0)

# eltype
@test eltype(Base.product(a)) == Tuple{Int}
@test eltype(Base.product(a, b)) == Tuple{Int, Float64}
@test eltype(Base.product(a, b, c)) == Tuple{Int, Float64, Float32}
@test eltype(Base.product(a, b, c, d)) == Tuple{Int, Float64, Float32, UInt8}
@test eltype(Base.product(a, b, c, d, e)) == Tuple{Int, Float64, Float32, UInt8, Int8}
@test eltype(Base.product(a)) == Tuple{Int}
@test eltype(Base.product(a, b)) == Tuple{Int, Float64}
@test eltype(Base.product(a, b, c)) == Tuple{Int, Float64, Int32}

# ndims
@test ndims(Base.product(a)) == 1
@test ndims(Base.product(a, b)) == 2
@test ndims(Base.product(a, b, c)) == 3
@test ndims(Base.product(a, b, c, d)) == 4
@test ndims(Base.product(a, b, c, d, e)) == 5

f, g, h = randn(1, 1), randn(1, 1, 1), randn(1, 1, 1, 1)
@test ndims(Base.product(f)) == ndims(collect(Base.product(f))) == 2
@test ndims(Base.product(g)) == ndims(collect(Base.product(g))) == 3
@test ndims(Base.product(h)) == ndims(collect(Base.product(h))) == 4
@test ndims(Base.product(f, f)) == ndims(collect(Base.product(f, f))) == 2
@test ndims(Base.product(f, g)) == ndims(collect(Base.product(f, g))) == 2
@test ndims(Base.product(g, g)) == ndims(collect(Base.product(g, g))) == 2
@test ndims(Base.product(g, h)) == ndims(collect(Base.product(g, h))) == 2
@test ndims(Base.product(h, h)) == ndims(collect(Base.product(h, h))) == 2
@test ndims(Base.product(f, f, f)) == ndims(collect(Base.product(f, f, f))) == 3
@test ndims(Base.product(f, f, g)) == ndims(collect(Base.product(f, f, g))) == 3
@test ndims(Base.product(f, g, g)) == ndims(collect(Base.product(f, g, g))) == 3
@test ndims(Base.product(g, g, g)) == ndims(collect(Base.product(g, g, g))) == 3
@test ndims(Base.product(g, g, h)) == ndims(collect(Base.product(g, g, h))) == 3
@test ndims(Base.product(g, h, h)) == ndims(collect(Base.product(g, h, h))) == 3
@test ndims(Base.product(h, h, h)) == ndims(collect(Base.product(h, h, h))) == 3
@test ndims(Base.product(f, f, f)) == ndims(collect(Base.product(f, f, f))) == 3
@test ndims(Base.product(g, g, g, g)) == ndims(collect(Base.product(g, g, g, g))) == 4
@test ndims(Base.product(h, h, h, h)) == ndims(collect(Base.product(h, h, h, h))) == 4
@test ndims(Base.product(a)) == 1
@test ndims(Base.product(a, b)) == 2
@test ndims(Base.product(a, b, c)) == 3
end

# with multidimensional inputs
let
a, b, c = randn(4, 4), randn(3, 3, 3), randn(2, 2, 2, 2)
args = Any[(a,),
(a, a),
(a, b),
(a, a, a),
(a, b, c)]
sizes = Any[(4, 4),
(4, 4, 4, 4),
(4, 4, 3, 3, 3),
(4, 4, 4, 4, 4, 4),
(4, 4, 3, 3, 3, 2, 2, 2, 2)]
for (method, fun) in zip([size, ndims, length], [x->x, length, prod])
for i in 1:length(args)
@test method(Base.product(args[i]...)) == method(collect(Base.product(args[i]...))) == fun(sizes[i])
end
end
end

# more tests on product with iterators of various type
let
iters = (1:2,
rand(2, 2, 2),
take(1:4, 2),
Base.product(1:2, 1:3),
Base.product(rand(2, 2), rand(1, 1, 1))
)
for method in [size, length, ndims, eltype]
for i = 1:length(iters)
args = iters[i]
@test method(Base.product(args...)) == method(collect(Base.product(args...)))
for j = 1:length(iters)
args = iters[i], iters[j]
@test method(Base.product(args...)) == method(collect(Base.product(args...)))
for k = 1:length(iters)
args = iters[i], iters[j], iters[k]
@test method(Base.product(args...)) == method(collect(Base.product(args...)))
end
end
end
end
end

# product of finite length and infinite length iterators
let
a = 1:2
b = countfrom(1)
ab = Base.product(a, b)
ba = Base.product(b, a)
abexp = [(1, 1), (2, 1), (1, 2), (2, 2), (1, 3), (2, 3)]
baexp = [(1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1)]
for (expected, actual) in zip([abexp, baexp], [ab, ba])
for (i, el) in enumerate(actual)
@test el == expected[i]
i == length(expected) && break
end
@test_throws ArgumentError length(actual)
@test_throws ArgumentError size(actual)
@test_throws ArgumentError ndims(actual)
end

# size infinite or unknown raises an error
for itr in Any[countfrom(1), Filter(i->0, 1:10)]
@test_throws ArgumentError length(Base.product(itr))
@test_throws ArgumentError size(Base.product(itr))
@test_throws ArgumentError ndims(Base.product(itr))
end
end

# iteratorsize trait business