diff --git a/base/essentials.jl b/base/essentials.jl index 1af26f6bb01ac3..827a4fa3f75ba7 100644 --- a/base/essentials.jl +++ b/base/essentials.jl @@ -764,3 +764,6 @@ Indicate whether `x` is [`missing`](@ref). """ ismissing(::Any) = false ismissing(::Missing) = true + +function take! end +function peek end \ No newline at end of file diff --git a/base/iterators.jl b/base/iterators.jl index 811f44a327344a..5d7018369af33e 100644 --- a/base/iterators.jl +++ b/base/iterators.jl @@ -18,7 +18,8 @@ import .Base: isempty, length, size, axes, ndims, eltype, IteratorSize, IteratorEltype, haskey, keys, values, pairs, - getindex, setindex!, get + getindex, setindex!, get, take!, + peek export enumerate, zip, rest, countfrom, take, drop, cycle, repeated, product, flatten, partition @@ -957,4 +958,97 @@ function next(itr::PartitionIterator, state) return resize!(v, i), state end +""" + Stateful(itr) + +There are several different ways to think about this iterator wrapper: + 1. It provides a mutable wrapper around an iterator and + its iteration state. + 2. It turns an iterator-like abstraction into a Channel-like + abstraction. + 3. It's an iterator that mutates to become its own rest iterator + whenever an item is produced. + +`Stateful` provides the regular iterator interface. Like other mutable iterators +(e.g. Channel), if iteration is stopped early (e.g. by a `break` in a for loop), +iteration can be resumed from the same spot by continuing to iterate over the +same iterator object (in contrast, an immutable iterator would restart from the +beginning). + +# Example: +```jldoctest +julia> a = Iterators.Stateful("abcdef"); + +julia> isempty(a) +false + +julia> take!(a) +'a': ASCII/Unicode U+0061 (category Ll: Letter, lowercase) + +julia> collect(Iterators.take(a, 3)) +3-element Array{Any,1}: + 'b' + 'c' + 'd' + +julia> collect(a) +2-element Array{Char,1}: + 'e' + 'f' +``` + +```jldoctest +julia> a = Iterators.Stateful([1,1,1,2,3,4]); + +# Skip any leading ones +julia> for x in a; x == 1 || break; end + +# Sum the remaining elements +julia> sum(a) +7 +```` +""" +struct Stateful{VS,T} + itr::T + # A bit awkward right now, but adapted to the new iteration protocol + nextvalstate::Ref{Union{VS, Nothing}} + taken::Ref{Int} +end + +convert(::Type{Stateful}, itr) = Stateful(itr) +convert(::Type{Stateful{S}}, itr) where {S} = Stateful{S}(itr) + +function Stateful(itr::T) where {T} + state = start(itr) + vs = done(itr, state) ? nothing : next(itr, start(itr)) + VS = typeof(vs) + Stateful{VS, T}(itr, Ref{Union{VS, Nothing}}(vs), Ref{Int}(0)) +end + +function Stateful{VS}(itr::T) where {VS,T} + state = start(itr) + sv = done(itr, state) ? nothing : next(itr, start(itr))::VS + Stateful{VS, T}(sv, Ref{Union{VS, Nothing}}(vs), Ref{Int}(0)) +end + +isempty(s::Stateful) = s.nextvalstate[] === nothing + +function take!(s::Stateful) + isempty(s) && throw(EOFError()) + val, state = s.nextvalstate[] + s.nextvalstate[] = done(s.itr, state) ? nothing : next(s.itr, state) + s.taken[] += 1 + val +end + +peek(s::Stateful, sentinel=nothing) = s.nextvalstate[] === nothing ? s.nextvalstate[][1] : sentinel +start(s::Stateful) = nothing +next(s::Stateful, state) = take!(s), nothing +done(s::Stateful, state) = isempty(s) +IteratorSize(::Type{Stateful{VS,T}} where VS) where {T} = + isa(IteratorSize(T), SizeUnknown) ? SizeUnknown() : HasLength() +eltype(::Type{Stateful{VS, T}} where VS) where {T} = eltype(T) +IteratorEltype(::Type{Stateful{VS,T}} where VS) where {T} = IteratorEltype(T) +length(s::Stateful) = length(s.itr) - s.taken[] + end diff --git a/test/iterators.jl b/test/iterators.jl index ee8be576b2b74a..766c4d0d9c8eb0 100644 --- a/test/iterators.jl +++ b/test/iterators.jl @@ -493,3 +493,16 @@ end @test Iterators.reverse(Iterators.reverse(t)) === t end end + +@testset "Iterators.Stateful" begin + let a = Iterators.Stateful("abcdef") + @test !isempty(a) + @test take!(a) == 'a' + @test collect(Iterators.take(a, 3)) == ['b','c','d'] + @test collect(a) == ['e', 'f'] + end + let a = Iterators.Stateful([1, 1, 1, 2, 3, 4]) + for x in a; x == 1 || break; end + @test sum(a) == 7 + end +end \ No newline at end of file