diff --git a/base/iterators.jl b/base/iterators.jl index 6b8d9fe75e302..7c1254fb7db23 100644 --- a/base/iterators.jl +++ b/base/iterators.jl @@ -9,10 +9,10 @@ baremodule Iterators import ..@__MODULE__, ..parentmodule const Base = parentmodule(@__MODULE__) using .Base: - @inline, Pair, Pairs, AbstractDict, IndexLinear, IndexStyle, AbstractVector, Vector, - SizeUnknown, HasLength, HasShape, IsInfinite, EltypeUnknown, HasEltype, OneTo, + @inline, @noinline, Pair, Pairs, AbstractDict, IndexLinear, IndexStyle, AbstractVector, + Vector, SizeUnknown, HasLength, HasShape, IsInfinite, EltypeUnknown, HasEltype, OneTo, @propagate_inbounds, @isdefined, @boundscheck, @inbounds, Generator, IdDict, - AbstractRange, AbstractUnitRange, UnitRange, LinearIndices, TupleOrBottom, + AbstractRange, AbstractUnitRange, UnitRange, LinearIndices, TupleOrBottom, DimensionMismatch, (:), |, +, -, *, !==, !, ==, !=, <=, <, >, >=, =>, missing, any, _counttuple, eachindex, ntuple, zero, prod, reduce, in, firstindex, lastindex, tail, fieldtypes, min, max, minimum, zero, oneunit, promote, promote_shape, LazyString @@ -1166,6 +1166,11 @@ end reverse(p::ProductIterator) = ProductIterator(Base.map(reverse, p.iterators)) last(p::ProductIterator) = Base.map(last, p.iterators) intersect(a::ProductIterator, b::ProductIterator) = ProductIterator(intersect.(a.iterators, b.iterators)) +function getindex(p::ProductIterator, inds...) + length(inds) == length(p.iterators) || prod_indexing_error(p, inds) + Base.map(getindex, p.iterators, inds) +end +@noinline prod_indexing_error(p, inds) = throw(DimensionMismatch("Attempted to index a product of $(length(p.iterators)) iterators with $(length(inds)) indices.")) # flatten an iterator of iterators diff --git a/test/iterators.jl b/test/iterators.jl index d1e7525c43465..70fe469e4f6e6 100644 --- a/test/iterators.jl +++ b/test/iterators.jl @@ -336,6 +336,18 @@ let (a, b) = (1:3, [4 6; end end +let + p1 = product(Dict(:a => 1, :b => 2), [1, 2, 3]) + p2 = product(-5:5, 12:10000) + p3 = product(rand(3), (x for x in 1:10 if rand(Bool))) + p4 = product([:a, :b, :c], ['i', 'j', 'k']) + @test p1[:b, 3] == (2, 3) + @test p2[2, 1] == (-4, 12) + @test_throws MethodError p3[1, 2] + @test p4[2, 3] == (:b, 'k') + @test_throws DimensionMismatch p4[2] +end + # collect stateful iterator let itr itr = Iterators.Stateful(Iterators.map(identity, 1:5))