From 9f5f540e52fd3928fb32f50b91f6f66754362ef0 Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Tue, 30 Jan 2024 13:58:00 -0500 Subject: [PATCH] remove length from Stateful (#51747) Stateful iterators do not have a consistent notion of length, as it is continuously changing as elements are removed. As the main purpose of Stateful is to take elements from multiple places, any notion of HaveShape is invalid for those cases, and thus not useful in general. Fix #47790 --- base/iterators.jl | 50 ++++++++++++----------------------------------- test/iterators.jl | 31 +++++++++++++++++++---------- 2 files changed, 33 insertions(+), 48 deletions(-) diff --git a/base/iterators.jl b/base/iterators.jl index a03d426e05622..b51920cdddb68 100644 --- a/base/iterators.jl +++ b/base/iterators.jl @@ -973,7 +973,7 @@ cycle(xs) = Cycle(xs) eltype(::Type{Cycle{I}}) where {I} = eltype(I) IteratorEltype(::Type{Cycle{I}}) where {I} = IteratorEltype(I) -IteratorSize(::Type{Cycle{I}}) where {I} = IsInfinite() +IteratorSize(::Type{Cycle{I}}) where {I} = IsInfinite() # XXX: this is false if iterator ever becomes empty iterate(it::Cycle) = iterate(it.xs) isdone(it::Cycle) = isdone(it.xs) @@ -1422,43 +1422,30 @@ julia> sum(a) # Sum the remaining elements 7 ``` """ -mutable struct Stateful{T, VS, N<:Integer} +mutable struct Stateful{T, VS} itr::T # A bit awkward right now, but adapted to the new iteration protocol nextvalstate::Union{VS, Nothing} - - # Number of remaining elements, if itr is HasLength or HasShape. - # if not, store -1 - number_of_consumed_elements. - # This allows us to defer calculating length until asked for. - # See PR #45924 - remaining::N @inline function Stateful{<:Any, Any}(itr::T) where {T} - itl = iterlength(itr) - new{T, Any, typeof(itl)}(itr, iterate(itr), itl) + return new{T, Any}(itr, iterate(itr)) end @inline function Stateful(itr::T) where {T} VS = approx_iter_type(T) - itl = iterlength(itr) - return new{T, VS, typeof(itl)}(itr, iterate(itr)::VS, itl) + return new{T, VS}(itr, iterate(itr)::VS) end end -function iterlength(it)::Signed - if IteratorSize(it) isa Union{HasShape, HasLength} - return length(it) - else - -1 - end +function reset!(s::Stateful) + setfield!(s, :nextvalstate, iterate(s.itr)) # bypass convert call of setproperty! + return s end - -function reset!(s::Stateful{T,VS}, itr::T=s.itr) where {T,VS} +function reset!(s::Stateful{T}, itr::T) where {T} s.itr = itr - itl = iterlength(itr) - setfield!(s, :nextvalstate, iterate(itr)) - s.remaining = itl - s + reset!(s) + return s end + # Try to find an appropriate type for the (value, state tuple), # by doing a recursive unrolling of the iteration protocol up to # fixpoint. @@ -1480,7 +1467,6 @@ end Stateful(x::Stateful) = x convert(::Type{Stateful}, itr) = Stateful(itr) - @inline isdone(s::Stateful, st=nothing) = s.nextvalstate === nothing @inline function popfirst!(s::Stateful) @@ -1490,8 +1476,6 @@ convert(::Type{Stateful}, itr) = Stateful(itr) else val, state = vs Core.setfield!(s, :nextvalstate, iterate(s.itr, state)) - rem = s.remaining - s.remaining = rem - typeof(rem)(1) return val end end @@ -1501,20 +1485,10 @@ end return ns !== nothing ? ns[1] : sentinel end @inline iterate(s::Stateful, state=nothing) = s.nextvalstate === nothing ? nothing : (popfirst!(s), nothing) -IteratorSize(::Type{<:Stateful{T}}) where {T} = IteratorSize(T) isa HasShape ? HasLength() : IteratorSize(T) +IteratorSize(::Type{<:Stateful{T}}) where {T} = IteratorSize(T) isa IsInfinite ? IsInfinite() : SizeUnknown() eltype(::Type{<:Stateful{T}}) where {T} = eltype(T) IteratorEltype(::Type{<:Stateful{T}}) where {T} = IteratorEltype(T) -function length(s::Stateful) - rem = s.remaining - # If rem is actually remaining length, return it. - # else, rem is number of consumed elements. - if rem >= 0 - rem - else - length(s.itr) - (typeof(rem)(1) - rem) - end -end end # if statement several hundred lines above """ diff --git a/test/iterators.jl b/test/iterators.jl index d8184eab7b656..6fd308a31d746 100644 --- a/test/iterators.jl +++ b/test/iterators.jl @@ -853,8 +853,10 @@ end v, s = iterate(z) @test Base.isdone(z, s) end - # Stateful wrapping mutable iterators of known length (#43245) - @test length(Iterators.Stateful(Iterators.Stateful(1:5))) == 5 + # Stateful does not define length + let s = Iterators.Stateful(Iterators.Stateful(1:5)) + @test_throws MethodError length(s) + end end @testset "pair for Svec" begin @@ -866,6 +868,10 @@ end @testset "inference for large zip #26765" begin x = zip(1:2, ["a", "b"], (1.0, 2.0), Base.OneTo(2), Iterators.repeated("a"), 1.0:0.2:2.0, (1 for i in 1:2), Iterators.Stateful(["a", "b", "c"]), (1.0 for i in 1:2, j in 1:3)) + @test Base.IteratorSize(x) isa Base.SizeUnknown + x = zip(1:2, ["a", "b"], (1.0, 2.0), Base.OneTo(2), Iterators.repeated("a"), 1.0:0.2:2.0, + (1 for i in 1:2), Iterators.cycle(Iterators.Stateful(["a", "b", "c"])), (1.0 for i in 1:2, j in 1:3)) + @test Base.IteratorSize(x) isa Base.HasLength @test @inferred(length(x)) == 2 z = Iterators.filter(x -> x[1] >= 1, x) @test @inferred(eltype(z)) <: Tuple{Int,String,Float64,Int,String,Float64,Any,String,Any} @@ -874,20 +880,20 @@ end end @testset "Stateful fix #30643" begin - @test Base.IteratorSize(1:10) isa Base.HasShape + @test Base.IteratorSize(1:10) isa Base.HasShape{1} a = Iterators.Stateful(1:10) - @test Base.IteratorSize(a) isa Base.HasLength - @test length(a) == 10 + @test Base.IteratorSize(a) isa Base.SizeUnknown + @test !Base.isdone(a) @test length(collect(a)) == 10 - @test length(a) == 0 + @test Base.isdone(a) b = Iterators.Stateful(Iterators.take(1:10,3)) - @test Base.IteratorSize(b) isa Base.HasLength - @test length(b) == 3 + @test Base.IteratorSize(b) isa Base.SizeUnknown + @test !Base.isdone(b) @test length(collect(b)) == 3 - @test length(b) == 0 + @test Base.isdone(b) c = Iterators.Stateful(Iterators.countfrom(1)) @test Base.IteratorSize(c) isa Base.IsInfinite - @test length(Iterators.take(c,3)) == 3 + @test !Base.isdone(Iterators.take(c,3)) @test length(collect(Iterators.take(c,3))) == 3 d = Iterators.Stateful(Iterators.filter(isodd,1:10)) @test Base.IteratorSize(d) isa Base.SizeUnknown @@ -1010,6 +1016,11 @@ end @test collect(Iterators.partition(lstrip("01111", '0'), 2)) == ["11", "11"] end +let itr = (i for i in 1:9) # Base.eltype == Any + @test first(Iterators.partition(itr, 3)) isa Vector{Any} + @test collect(zip(repeat([Iterators.Stateful(itr)], 3)...)) == [(1, 2, 3), (4, 5, 6), (7, 8, 9)] +end + @testset "no single-argument map methods" begin maps = (tuple, Returns(nothing), (() -> nothing)) mappers = (Iterators.map, map, foreach)