Skip to content

Commit

Permalink
remove length from Stateful
Browse files Browse the repository at this point in the history
Stateful iterators do not have a consistent notion of length, as it is
continuously changing as elements are removed. As the main purpose of
Stateful is to take elements from multiple places, any notion of
HaveShape is invalid for those cases, and thus not useful in general.

Fix #47790
  • Loading branch information
vtjnash committed Oct 17, 2023
1 parent 4d2d849 commit fef0b07
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 48 deletions.
50 changes: 12 additions & 38 deletions base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -952,7 +952,7 @@ cycle(xs) = Cycle(xs)

eltype(::Type{Cycle{I}}) where {I} = eltype(I)
IteratorEltype(::Type{Cycle{I}}) where {I} = IteratorEltype(I)
IteratorSize(::Type{Cycle{I}}) where {I} = IsInfinite()
IteratorSize(::Type{Cycle{I}}) where {I} = IsInfinite() # XXX: this is false if iterator ever becomes empty

iterate(it::Cycle) = iterate(it.xs)
isdone(it::Cycle) = isdone(it.xs)
Expand Down Expand Up @@ -1401,43 +1401,30 @@ julia> sum(a) # Sum the remaining elements
7
```
"""
mutable struct Stateful{T, VS, N<:Integer}
mutable struct Stateful{T, VS}
itr::T
# A bit awkward right now, but adapted to the new iteration protocol
nextvalstate::Union{VS, Nothing}

# Number of remaining elements, if itr is HasLength or HasShape.
# if not, store -1 - number_of_consumed_elements.
# This allows us to defer calculating length until asked for.
# See PR #45924
remaining::N
@inline function Stateful{<:Any, Any}(itr::T) where {T}
itl = iterlength(itr)
new{T, Any, typeof(itl)}(itr, iterate(itr), itl)
return new{T, Any}(itr, iterate(itr))
end
@inline function Stateful(itr::T) where {T}
VS = approx_iter_type(T)
itl = iterlength(itr)
return new{T, VS, typeof(itl)}(itr, iterate(itr)::VS, itl)
return new{T, VS}(itr, iterate(itr)::VS)
end
end

function iterlength(it)::Signed
if IteratorSize(it) isa Union{HasShape, HasLength}
return length(it)
else
-1
end
function reset!(s::Stateful)
setfield!(s, :nextvalstate, iterate(s.itr)) # bypass convert call of setproperty!
return s
end

function reset!(s::Stateful{T,VS}, itr::T=s.itr) where {T,VS}
function reset!(s::Stateful{T}, itr::T) where {T}
s.itr = itr
itl = iterlength(itr)
setfield!(s, :nextvalstate, iterate(itr))
s.remaining = itl
s
reset!(s)
return s
end


# Try to find an appropriate type for the (value, state tuple),
# by doing a recursive unrolling of the iteration protocol up to
# fixpoint.
Expand All @@ -1459,7 +1446,6 @@ end

Stateful(x::Stateful) = x
convert(::Type{Stateful}, itr) = Stateful(itr)

@inline isdone(s::Stateful, st=nothing) = s.nextvalstate === nothing

@inline function popfirst!(s::Stateful)
Expand All @@ -1469,8 +1455,6 @@ convert(::Type{Stateful}, itr) = Stateful(itr)
else
val, state = vs
Core.setfield!(s, :nextvalstate, iterate(s.itr, state))
rem = s.remaining
s.remaining = rem - typeof(rem)(1)
return val
end
end
Expand All @@ -1480,20 +1464,10 @@ end
return ns !== nothing ? ns[1] : sentinel
end
@inline iterate(s::Stateful, state=nothing) = s.nextvalstate === nothing ? nothing : (popfirst!(s), nothing)
IteratorSize(::Type{<:Stateful{T}}) where {T} = IteratorSize(T) isa HasShape ? HasLength() : IteratorSize(T)
IteratorSize(::Type{<:Stateful{T}}) where {T} = IteratorSize(T) isa IsInfinite ? IsInfinite() : SizeUnknown()
eltype(::Type{<:Stateful{T}}) where {T} = eltype(T)
IteratorEltype(::Type{<:Stateful{T}}) where {T} = IteratorEltype(T)

function length(s::Stateful)
rem = s.remaining
# If rem is actually remaining length, return it.
# else, rem is number of consumed elements.
if rem >= 0
rem
else
length(s.itr) - (typeof(rem)(1) - rem)
end
end
end # if statement several hundred lines above

"""
Expand Down
31 changes: 21 additions & 10 deletions test/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -849,8 +849,10 @@ end
v, s = iterate(z)
@test Base.isdone(z, s)
end
# Stateful wrapping mutable iterators of known length (#43245)
@test length(Iterators.Stateful(Iterators.Stateful(1:5))) == 5
# Stateful does not define length
let s = Iterators.Stateful(Iterators.Stateful(1:5))
@test_throws MethodError length(s)
end
end

@testset "pair for Svec" begin
Expand All @@ -862,6 +864,10 @@ end
@testset "inference for large zip #26765" begin
x = zip(1:2, ["a", "b"], (1.0, 2.0), Base.OneTo(2), Iterators.repeated("a"), 1.0:0.2:2.0,
(1 for i in 1:2), Iterators.Stateful(["a", "b", "c"]), (1.0 for i in 1:2, j in 1:3))
@test Base.IteratorSize(x) isa Base.SizeUnknown
x = zip(1:2, ["a", "b"], (1.0, 2.0), Base.OneTo(2), Iterators.repeated("a"), 1.0:0.2:2.0,
(1 for i in 1:2), Iterators.cycle(Iterators.Stateful(["a", "b", "c"])), (1.0 for i in 1:2, j in 1:3))
@test Base.IteratorSize(x) isa Base.HasLength
@test @inferred(length(x)) == 2
z = Iterators.filter(x -> x[1] >= 1, x)
@test @inferred(eltype(z)) <: Tuple{Int,String,Float64,Int,String,Float64,Any,String,Any}
Expand All @@ -870,20 +876,20 @@ end
end

@testset "Stateful fix #30643" begin
@test Base.IteratorSize(1:10) isa Base.HasShape
@test Base.IteratorSize(1:10) isa Base.HasShape{1}
a = Iterators.Stateful(1:10)
@test Base.IteratorSize(a) isa Base.HasLength
@test length(a) == 10
@test Base.IteratorSize(a) isa Base.SizeUnknown
@test !Base.isdone(a)
@test length(collect(a)) == 10
@test length(a) == 0
@test Base.isdone(a)
b = Iterators.Stateful(Iterators.take(1:10,3))
@test Base.IteratorSize(b) isa Base.HasLength
@test length(b) == 3
@test Base.IteratorSize(b) isa Base.SizeUnknown
@test !Base.isdone(b)
@test length(collect(b)) == 3
@test length(b) == 0
@test Base.isdone(b)
c = Iterators.Stateful(Iterators.countfrom(1))
@test Base.IteratorSize(c) isa Base.IsInfinite
@test length(Iterators.take(c,3)) == 3
@test !Base.isdone(Iterators.take(c,3))
@test length(collect(Iterators.take(c,3))) == 3
d = Iterators.Stateful(Iterators.filter(isodd,1:10))
@test Base.IteratorSize(d) isa Base.SizeUnknown
Expand Down Expand Up @@ -1001,3 +1007,8 @@ end
end
@test v == ()
end

let itr = (i for i in 1:9) # Base.eltype == Any
@test first(Iterators.partition(itr, 3)) isa Vector{Any}
@test collect(zip(repeat([Iterators.Stateful(itr)], 3)...)) == [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
end

0 comments on commit fef0b07

Please sign in to comment.