Skip to content

Commit cbbefc1

Browse files
jakobnissenKristofferC
authored andcommitted
Fix collect on stateful generator (#41919)
Previously this code would drop 1 from the length of some generators. Fixes #35530 (cherry picked from commit 8364a4c)
1 parent c750018 commit cbbefc1

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

base/array.jl

+7-3
Original file line numberDiff line numberDiff line change
@@ -666,21 +666,25 @@ else
666666
end
667667
end
668668

669-
_array_for(::Type{T}, itr, ::HasLength) where {T} = Vector{T}(undef, Int(length(itr)::Integer))
670-
_array_for(::Type{T}, itr, ::HasShape{N}) where {T,N} = similar(Array{T,N}, axes(itr))
669+
_array_for(::Type{T}, itr, isz::HasLength) where {T} = _array_for(T, itr, isz, length(itr))
670+
_array_for(::Type{T}, itr, isz::HasShape{N}) where {T,N} = _array_for(T, itr, isz, axes(itr))
671+
_array_for(::Type{T}, itr, ::HasLength, len) where {T} = Vector{T}(undef, len)
672+
_array_for(::Type{T}, itr, ::HasShape{N}, axs) where {T,N} = similar(Array{T,N}, axs)
671673

672674
function collect(itr::Generator)
673675
isz = IteratorSize(itr.iter)
674676
et = @default_eltype(itr)
675677
if isa(isz, SizeUnknown)
676678
return grow_to!(Vector{et}(), itr)
677679
else
680+
shape = isz isa HasLength ? length(itr) : axes(itr)
678681
y = iterate(itr)
679682
if y === nothing
680683
return _array_for(et, itr.iter, isz)
681684
end
682685
v1, st = y
683-
collect_to_with_first!(_array_for(typeof(v1), itr.iter, isz), v1, itr, st)
686+
arr = _array_for(typeof(v1), itr.iter, isz, shape)
687+
return collect_to_with_first!(arr, v1, itr, st)
684688
end
685689
end
686690

test/iterators.jl

+9
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,15 @@ let (a, b) = (1:3, [4 6;
291291
end
292292
end
293293

294+
# collect stateful iterator
295+
let
296+
itr = (i+1 for i in Base.Stateful([1,2,3]))
297+
@test collect(itr) == [2, 3, 4]
298+
A = zeros(Int, 0, 0)
299+
itr = (i-1 for i in Base.Stateful(A))
300+
@test collect(itr) == Int[] # Stateful do not preserve shape
301+
end
302+
294303
# with 1D inputs
295304
let a = 1:2,
296305
b = 1.0:10.0,

0 commit comments

Comments
 (0)