Skip to content

Commit ca8a108

Browse files
Rollup merge of #87091 - the8472:more-advance-by-impls, r=joshtriplett
implement advance_(back_)_by on more iterators Add more efficient, non-default implementations for `feature(iter_advance_by)` (#77404) on more iterators and adapters. This PR only contains implementations where skipping over items doesn't elide any observable side-effects such as user-provided closures or `clone()` functions. I'll put those in a separate PR.
2 parents 4e9cf04 + ffd7ade commit ca8a108

File tree

15 files changed

+399
-3
lines changed

15 files changed

+399
-3
lines changed

library/alloc/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@
111111
// that the feature-gate isn't enabled. Ideally, it wouldn't check for the feature gate for docs
112112
// from other crates, but since this can only appear for lang items, it doesn't seem worth fixing.
113113
#![feature(intra_doc_pointers)]
114+
#![feature(iter_advance_by)]
114115
#![feature(iter_zip)]
115116
#![feature(lang_items)]
116117
#![feature(layout_for_ptr)]

library/alloc/src/vec/into_iter.rs

+46
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,29 @@ impl<T, A: Allocator> Iterator for IntoIter<T, A> {
161161
(exact, Some(exact))
162162
}
163163

164+
#[inline]
165+
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
166+
let step_size = self.len().min(n);
167+
let to_drop = ptr::slice_from_raw_parts_mut(self.ptr as *mut T, step_size);
168+
if mem::size_of::<T>() == 0 {
169+
// SAFETY: due to unchecked casts of unsigned amounts to signed offsets the wraparound
170+
// effectively results in unsigned pointers representing positions 0..usize::MAX,
171+
// which is valid for ZSTs.
172+
self.ptr = unsafe { arith_offset(self.ptr as *const i8, step_size as isize) as *mut T }
173+
} else {
174+
// SAFETY: the min() above ensures that step_size is in bounds
175+
self.ptr = unsafe { self.ptr.add(step_size) };
176+
}
177+
// SAFETY: the min() above ensures that step_size is in bounds
178+
unsafe {
179+
ptr::drop_in_place(to_drop);
180+
}
181+
if step_size < n {
182+
return Err(step_size);
183+
}
184+
Ok(())
185+
}
186+
164187
#[inline]
165188
fn count(self) -> usize {
166189
self.len()
@@ -203,6 +226,29 @@ impl<T, A: Allocator> DoubleEndedIterator for IntoIter<T, A> {
203226
Some(unsafe { ptr::read(self.end) })
204227
}
205228
}
229+
230+
#[inline]
231+
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
232+
let step_size = self.len().min(n);
233+
if mem::size_of::<T>() == 0 {
234+
// SAFETY: same as for advance_by()
235+
self.end = unsafe {
236+
arith_offset(self.end as *const i8, step_size.wrapping_neg() as isize) as *mut T
237+
}
238+
} else {
239+
// SAFETY: same as for advance_by()
240+
self.end = unsafe { self.end.offset(step_size.wrapping_neg() as isize) };
241+
}
242+
let to_drop = ptr::slice_from_raw_parts_mut(self.end as *mut T, step_size);
243+
// SAFETY: same as for advance_by()
244+
unsafe {
245+
ptr::drop_in_place(to_drop);
246+
}
247+
if step_size < n {
248+
return Err(step_size);
249+
}
250+
Ok(())
251+
}
206252
}
207253

208254
#[stable(feature = "rust1", since = "1.0.0")]

library/alloc/tests/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#![feature(binary_heap_retain)]
1919
#![feature(binary_heap_as_slice)]
2020
#![feature(inplace_iteration)]
21+
#![feature(iter_advance_by)]
2122
#![feature(slice_group_by)]
2223
#![feature(slice_partition_dedup)]
2324
#![feature(vec_spare_capacity)]

library/alloc/tests/vec.rs

+18
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,24 @@ fn test_into_iter_leak() {
970970
assert_eq!(unsafe { DROPS }, 3);
971971
}
972972

973+
#[test]
974+
fn test_into_iter_advance_by() {
975+
let mut i = vec![1, 2, 3, 4, 5].into_iter();
976+
i.advance_by(0).unwrap();
977+
i.advance_back_by(0).unwrap();
978+
assert_eq!(i.as_slice(), [1, 2, 3, 4, 5]);
979+
980+
i.advance_by(1).unwrap();
981+
i.advance_back_by(1).unwrap();
982+
assert_eq!(i.as_slice(), [2, 3, 4]);
983+
984+
assert_eq!(i.advance_back_by(usize::MAX), Err(3));
985+
986+
assert_eq!(i.advance_by(usize::MAX), Err(0));
987+
988+
assert_eq!(i.len(), 0);
989+
}
990+
973991
#[test]
974992
fn test_from_iter_specialization() {
975993
let src: Vec<usize> = vec![0usize; 1];

library/core/src/iter/adapters/copied.rs

+10
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ where
7676
self.it.count()
7777
}
7878

79+
#[inline]
80+
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
81+
self.it.advance_by(n)
82+
}
83+
7984
#[doc(hidden)]
8085
unsafe fn __iterator_get_unchecked(&mut self, idx: usize) -> T
8186
where
@@ -112,6 +117,11 @@ where
112117
{
113118
self.it.rfold(init, copy_fold(f))
114119
}
120+
121+
#[inline]
122+
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
123+
self.it.advance_back_by(n)
124+
}
115125
}
116126

117127
#[stable(feature = "iter_copied", since = "1.36.0")]

library/core/src/iter/adapters/cycle.rs

+21
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,27 @@ where
7979
}
8080
}
8181

82+
#[inline]
83+
#[rustc_inherit_overflow_checks]
84+
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
85+
let mut rem = n;
86+
match self.iter.advance_by(rem) {
87+
ret @ Ok(_) => return ret,
88+
Err(advanced) => rem -= advanced,
89+
}
90+
91+
while rem > 0 {
92+
self.iter = self.orig.clone();
93+
match self.iter.advance_by(rem) {
94+
ret @ Ok(_) => return ret,
95+
Err(0) => return Err(n - rem),
96+
Err(advanced) => rem -= advanced,
97+
}
98+
}
99+
100+
Ok(())
101+
}
102+
82103
// No `fold` override, because `fold` doesn't make much sense for `Cycle`,
83104
// and we can't do anything better than the default.
84105
}

library/core/src/iter/adapters/enumerate.rs

+22
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,21 @@ where
112112
self.iter.fold(init, enumerate(self.count, fold))
113113
}
114114

115+
#[inline]
116+
#[rustc_inherit_overflow_checks]
117+
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
118+
match self.iter.advance_by(n) {
119+
ret @ Ok(_) => {
120+
self.count += n;
121+
ret
122+
}
123+
ret @ Err(advanced) => {
124+
self.count += advanced;
125+
ret
126+
}
127+
}
128+
}
129+
115130
#[rustc_inherit_overflow_checks]
116131
#[doc(hidden)]
117132
unsafe fn __iterator_get_unchecked(&mut self, idx: usize) -> <Self as Iterator>::Item
@@ -191,6 +206,13 @@ where
191206
let count = self.count + self.iter.len();
192207
self.iter.rfold(init, enumerate(count, fold))
193208
}
209+
210+
#[inline]
211+
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
212+
// we do not need to update the count since that only tallies the number of items
213+
// consumed from the front. consuming items from the back can never reduce that.
214+
self.iter.advance_back_by(n)
215+
}
194216
}
195217

196218
#[stable(feature = "rust1", since = "1.0.0")]

library/core/src/iter/adapters/flatten.rs

+70
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,41 @@ where
391391

392392
init
393393
}
394+
395+
#[inline]
396+
#[rustc_inherit_overflow_checks]
397+
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
398+
let mut rem = n;
399+
loop {
400+
if let Some(ref mut front) = self.frontiter {
401+
match front.advance_by(rem) {
402+
ret @ Ok(_) => return ret,
403+
Err(advanced) => rem -= advanced,
404+
}
405+
}
406+
self.frontiter = match self.iter.next() {
407+
Some(iterable) => Some(iterable.into_iter()),
408+
_ => break,
409+
}
410+
}
411+
412+
self.frontiter = None;
413+
414+
if let Some(ref mut back) = self.backiter {
415+
match back.advance_by(rem) {
416+
ret @ Ok(_) => return ret,
417+
Err(advanced) => rem -= advanced,
418+
}
419+
}
420+
421+
if rem > 0 {
422+
return Err(n - rem);
423+
}
424+
425+
self.backiter = None;
426+
427+
Ok(())
428+
}
394429
}
395430

396431
impl<I, U> DoubleEndedIterator for FlattenCompat<I, U>
@@ -486,6 +521,41 @@ where
486521

487522
init
488523
}
524+
525+
#[inline]
526+
#[rustc_inherit_overflow_checks]
527+
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
528+
let mut rem = n;
529+
loop {
530+
if let Some(ref mut back) = self.backiter {
531+
match back.advance_back_by(rem) {
532+
ret @ Ok(_) => return ret,
533+
Err(advanced) => rem -= advanced,
534+
}
535+
}
536+
match self.iter.next_back() {
537+
Some(iterable) => self.backiter = Some(iterable.into_iter()),
538+
_ => break,
539+
}
540+
}
541+
542+
self.backiter = None;
543+
544+
if let Some(ref mut front) = self.frontiter {
545+
match front.advance_back_by(rem) {
546+
ret @ Ok(_) => return ret,
547+
Err(advanced) => rem -= advanced,
548+
}
549+
}
550+
551+
if rem > 0 {
552+
return Err(n - rem);
553+
}
554+
555+
self.frontiter = None;
556+
557+
Ok(())
558+
}
489559
}
490560

491561
trait ConstSizeIntoIterator: IntoIterator {

library/core/src/iter/adapters/skip.rs

+42
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,38 @@ where
114114
}
115115
self.iter.fold(init, fold)
116116
}
117+
118+
#[inline]
119+
#[rustc_inherit_overflow_checks]
120+
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
121+
let mut rem = n;
122+
123+
let step_one = self.n.saturating_add(rem);
124+
match self.iter.advance_by(step_one) {
125+
Ok(_) => {
126+
rem -= step_one - self.n;
127+
self.n = 0;
128+
}
129+
Err(advanced) => {
130+
let advanced_without_skip = advanced.saturating_sub(self.n);
131+
self.n = self.n.saturating_sub(advanced);
132+
return Err(advanced_without_skip);
133+
}
134+
}
135+
136+
// step_one calculation may have saturated
137+
if unlikely(rem > 0) {
138+
return match self.iter.advance_by(rem) {
139+
ret @ Ok(_) => ret,
140+
Err(advanced) => {
141+
rem -= advanced;
142+
Err(n - rem)
143+
}
144+
};
145+
}
146+
147+
Ok(())
148+
}
117149
}
118150

119151
#[stable(feature = "rust1", since = "1.0.0")]
@@ -174,6 +206,16 @@ where
174206

175207
self.try_rfold(init, ok(fold)).unwrap()
176208
}
209+
210+
#[inline]
211+
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
212+
let min = crate::cmp::min(self.len(), n);
213+
return match self.iter.advance_back_by(min) {
214+
ret @ Ok(_) if n <= min => ret,
215+
Ok(_) => Err(min),
216+
_ => panic!("ExactSizeIterator contract violation"),
217+
};
218+
}
177219
}
178220

179221
#[stable(feature = "fused", since = "1.26.0")]

library/core/src/iter/adapters/take.rs

+34
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,22 @@ where
111111

112112
self.try_fold(init, ok(fold)).unwrap()
113113
}
114+
115+
#[inline]
116+
#[rustc_inherit_overflow_checks]
117+
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
118+
let min = self.n.min(n);
119+
match self.iter.advance_by(min) {
120+
Ok(_) => {
121+
self.n -= min;
122+
if min < n { Err(min) } else { Ok(()) }
123+
}
124+
ret @ Err(advanced) => {
125+
self.n -= advanced;
126+
ret
127+
}
128+
}
129+
}
114130
}
115131

116132
#[unstable(issue = "none", feature = "inplace_iteration")]
@@ -197,6 +213,24 @@ where
197213
}
198214
}
199215
}
216+
217+
#[inline]
218+
fn advance_back_by(&mut self, n: usize) -> Result<(), usize> {
219+
let inner_len = self.iter.len();
220+
let len = self.n;
221+
let remainder = len.saturating_sub(n);
222+
let to_advance = inner_len - remainder;
223+
match self.iter.advance_back_by(to_advance) {
224+
Ok(_) => {
225+
self.n = remainder;
226+
if n > len {
227+
return Err(len);
228+
}
229+
return Ok(());
230+
}
231+
_ => panic!("ExactSizeIterator contract violation"),
232+
}
233+
}
200234
}
201235

202236
#[stable(feature = "rust1", since = "1.0.0")]

0 commit comments

Comments
 (0)