From d84bc95a4e1277ea70c2b6d0a4ed693e87915ed7 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Fri, 26 May 2023 11:01:23 -0600 Subject: [PATCH 1/7] add getindex method to ProductIterator --- base/iterators.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/base/iterators.jl b/base/iterators.jl index 11e94d3384de8..ba4afa7dc4d6f 100644 --- a/base/iterators.jl +++ b/base/iterators.jl @@ -1128,6 +1128,7 @@ end reverse(p::ProductIterator) = ProductIterator(Base.map(reverse, p.iterators)) last(p::ProductIterator) = Base.map(last, p.iterators) intersect(a::ProductIterator, b::ProductIterator) = ProductIterator(intersect.(a.iterators, b.iterators)) +getindex(p::ProductIterator, inds...) = map(getindex, p.iterators, inds) # flatten an iterator of iterators From bbc374f3a7d999fb42520538ce614dc02c9bd770 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Sun, 26 Nov 2023 23:48:16 +0100 Subject: [PATCH 2/7] use `Base.map` --- base/iterators.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base/iterators.jl b/base/iterators.jl index 689e18d6da7cd..8d7c5a18292f9 100644 --- a/base/iterators.jl +++ b/base/iterators.jl @@ -1158,7 +1158,7 @@ end reverse(p::ProductIterator) = ProductIterator(Base.map(reverse, p.iterators)) last(p::ProductIterator) = Base.map(last, p.iterators) intersect(a::ProductIterator, b::ProductIterator) = ProductIterator(intersect.(a.iterators, b.iterators)) -getindex(p::ProductIterator, inds...) = map(getindex, p.iterators, inds) +getindex(p::ProductIterator, inds...) = Base.map(getindex, p.iterators, inds) # flatten an iterator of iterators From af4e77a78a55c5e234ac45a941d92ca5b520210b Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Sun, 26 Nov 2023 23:56:45 +0100 Subject: [PATCH 3/7] Update iterators.jl --- base/iterators.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/base/iterators.jl b/base/iterators.jl index 8d7c5a18292f9..91b47bd2c5072 100644 --- a/base/iterators.jl +++ b/base/iterators.jl @@ -1158,7 +1158,11 @@ end reverse(p::ProductIterator) = ProductIterator(Base.map(reverse, p.iterators)) last(p::ProductIterator) = Base.map(last, p.iterators) intersect(a::ProductIterator, b::ProductIterator) = ProductIterator(intersect.(a.iterators, b.iterators)) -getindex(p::ProductIterator, inds...) = Base.map(getindex, p.iterators, inds) +function getindex(p::ProductIterator, inds...) + length(inds) == length(p.iterators) || prod_indexing_error(p, inds) + Base.map(getindex, p.iterators, inds) +end +@noinline prod_indexing_error(p, inds) = throw(BoundsError(p, inds)) # flatten an iterator of iterators From dced4c2ea1d8c14a9e525bdd9307da846e2b1394 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Sun, 26 Nov 2023 23:57:00 +0100 Subject: [PATCH 4/7] add some tests --- test/iterators.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/iterators.jl b/test/iterators.jl index 46e7c8b454335..1c2261de4a70f 100644 --- a/test/iterators.jl +++ b/test/iterators.jl @@ -320,6 +320,18 @@ let (a, b) = (1:3, [4 6; end end +let + p1 = product(Dict(:a => 1, :b => 2), [1, 2, 3]) + p2 = product(-5:5, 12:10000) + p3 = product(rand(3), (x for x in 1:10 if rand(Bool))) + p4 = product([:a, :b, :c], ['i', 'j', 'k']) + @test p1[:b, 3] == (2, 3) + @test p2[2, 1] == (-4, 12) + @test_throws MethodError p3[1, 2] + @test p4[2, 3] == (:b, 'k') + @test_throws BoundsError p4[2] +end + # collect stateful iterator let itr itr = Iterators.Stateful(Iterators.map(identity, 1:5)) From b59e86bae247e453646c97ece5c179265e57d246 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Mon, 27 Nov 2023 00:00:59 +0100 Subject: [PATCH 5/7] whitespace --- test/iterators.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/iterators.jl b/test/iterators.jl index 1c2261de4a70f..48057b3b462aa 100644 --- a/test/iterators.jl +++ b/test/iterators.jl @@ -331,7 +331,7 @@ let @test p4[2, 3] == (:b, 'k') @test_throws BoundsError p4[2] end - + # collect stateful iterator let itr itr = Iterators.Stateful(Iterators.map(identity, 1:5)) From 0c089e561b32a99d4d8c8409b0a64d910cdd4950 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Wed, 6 Dec 2023 10:30:44 +0100 Subject: [PATCH 6/7] fixup --- base/iterators.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/base/iterators.jl b/base/iterators.jl index 2ead570b42901..2c497a9872e6e 100644 --- a/base/iterators.jl +++ b/base/iterators.jl @@ -9,10 +9,10 @@ baremodule Iterators import ..@__MODULE__, ..parentmodule const Base = parentmodule(@__MODULE__) using .Base: - @inline, Pair, Pairs, AbstractDict, IndexLinear, IndexStyle, AbstractVector, Vector, - SizeUnknown, HasLength, HasShape, IsInfinite, EltypeUnknown, HasEltype, OneTo, + @inline, @noinline, Pair, Pairs, AbstractDict, IndexLinear, IndexStyle, AbstractVector, + Vector, SizeUnknown, HasLength, HasShape, IsInfinite, EltypeUnknown, HasEltype, OneTo, @propagate_inbounds, @isdefined, @boundscheck, @inbounds, Generator, IdDict, - AbstractRange, AbstractUnitRange, UnitRange, LinearIndices, TupleOrBottom, + AbstractRange, AbstractUnitRange, UnitRange, LinearIndices, TupleOrBottom, DimensionMismatch, (:), |, +, -, *, !==, !, ==, !=, <=, <, >, >=, missing, any, _counttuple, eachindex, ntuple, zero, prod, reduce, in, firstindex, lastindex, tail, fieldtypes, min, max, minimum, zero, oneunit, promote, promote_shape @@ -1162,7 +1162,7 @@ function getindex(p::ProductIterator, inds...) length(inds) == length(p.iterators) || prod_indexing_error(p, inds) Base.map(getindex, p.iterators, inds) end -@noinline prod_indexing_error(p, inds) = throw(BoundsError(p, inds)) +@noinline prod_indexing_error(p, inds) = throw(DimensionMismatch("Attempted to index a product of $(length(p.iterators)) iterators with $(length(inds)) indices.")) # flatten an iterator of iterators From 5493114660c6f377f20321dc5191c96601afb91a Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Wed, 6 Dec 2023 10:31:10 +0100 Subject: [PATCH 7/7] Update iterators.jl --- test/iterators.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/iterators.jl b/test/iterators.jl index 48057b3b462aa..d4f741348f86d 100644 --- a/test/iterators.jl +++ b/test/iterators.jl @@ -329,7 +329,7 @@ let @test p2[2, 1] == (-4, 12) @test_throws MethodError p3[1, 2] @test p4[2, 3] == (:b, 'k') - @test_throws BoundsError p4[2] + @test_throws DimensionMismatch p4[2] end # collect stateful iterator