Skip to content

Commit 6672c16

Browse files
committed
Auto merge of rust-lang#121204 - cuviper:flatten-one-shot, r=the8472
Specialize flattening iterators with only one inner item For iterators like `Once` and `option::IntoIter` that only ever have a single item at most, the front and back iterator states in `FlatMap` and `Flatten` are a waste, as they're always consumed already. We can use specialization for these types to simplify the iterator methods. It's a somewhat common pattern to use `flatten()` for options and results, even recommended by [multiple][1] [clippy][2] [lints][3]. The implementation is more efficient with `filter_map`, as mentioned in [clippy#9377], but this new specialization should close some of that gap for existing code that flattens. [1]: https://rust-lang.github.io/rust-clippy/master/#filter_map_identity [2]: https://rust-lang.github.io/rust-clippy/master/#option_filter_map [3]: https://rust-lang.github.io/rust-clippy/master/#result_filter_map [clippy#9377]: rust-lang/rust-clippy#9377
2 parents cabdf3a + c36ae93 commit 6672c16

File tree

2 files changed

+275
-12
lines changed

2 files changed

+275
-12
lines changed

library/core/src/iter/adapters/flatten.rs

+209-12
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use crate::iter::{
33
Cloned, Copied, Filter, FilterMap, Fuse, FusedIterator, InPlaceIterable, Map, TrustedFused,
44
TrustedLen,
55
};
6-
use crate::iter::{Once, OnceWith};
6+
use crate::iter::{Empty, Once, OnceWith};
77
use crate::num::NonZero;
88
use crate::ops::{ControlFlow, Try};
99
use crate::result;
@@ -593,6 +593,7 @@ where
593593
}
594594
}
595595

596+
// See also the `OneShot` specialization below.
596597
impl<I, U> Iterator for FlattenCompat<I, U>
597598
where
598599
I: Iterator<Item: IntoIterator<IntoIter = U, Item = U::Item>>,
@@ -601,7 +602,7 @@ where
601602
type Item = U::Item;
602603

603604
#[inline]
604-
fn next(&mut self) -> Option<U::Item> {
605+
default fn next(&mut self) -> Option<U::Item> {
605606
loop {
606607
if let elt @ Some(_) = and_then_or_clear(&mut self.frontiter, Iterator::next) {
607608
return elt;
@@ -614,7 +615,7 @@ where
614615
}
615616

616617
#[inline]
617-
fn size_hint(&self) -> (usize, Option<usize>) {
618+
default fn size_hint(&self) -> (usize, Option<usize>) {
618619
let (flo, fhi) = self.frontiter.as_ref().map_or((0, Some(0)), U::size_hint);
619620
let (blo, bhi) = self.backiter.as_ref().map_or((0, Some(0)), U::size_hint);
620621
let lo = flo.saturating_add(blo);
@@ -636,7 +637,7 @@ where
636637
}
637638

638639
#[inline]
639-
fn try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
640+
default fn try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
640641
where
641642
Self: Sized,
642643
Fold: FnMut(Acc, Self::Item) -> R,
@@ -653,7 +654,7 @@ where
653654
}
654655

655656
#[inline]
656-
fn fold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
657+
default fn fold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
657658
where
658659
Fold: FnMut(Acc, Self::Item) -> Acc,
659660
{
@@ -669,7 +670,7 @@ where
669670

670671
#[inline]
671672
#[rustc_inherit_overflow_checks]
672-
fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
673+
default fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
673674
#[inline]
674675
#[rustc_inherit_overflow_checks]
675676
fn advance<U: Iterator>(n: usize, iter: &mut U) -> ControlFlow<(), usize> {
@@ -686,7 +687,7 @@ where
686687
}
687688

688689
#[inline]
689-
fn count(self) -> usize {
690+
default fn count(self) -> usize {
690691
#[inline]
691692
#[rustc_inherit_overflow_checks]
692693
fn count<U: Iterator>(acc: usize, iter: U) -> usize {
@@ -697,7 +698,7 @@ where
697698
}
698699

699700
#[inline]
700-
fn last(self) -> Option<Self::Item> {
701+
default fn last(self) -> Option<Self::Item> {
701702
#[inline]
702703
fn last<U: Iterator>(last: Option<U::Item>, iter: U) -> Option<U::Item> {
703704
iter.last().or(last)
@@ -707,13 +708,14 @@ where
707708
}
708709
}
709710

711+
// See also the `OneShot` specialization below.
710712
impl<I, U> DoubleEndedIterator for FlattenCompat<I, U>
711713
where
712714
I: DoubleEndedIterator<Item: IntoIterator<IntoIter = U, Item = U::Item>>,
713715
U: DoubleEndedIterator,
714716
{
715717
#[inline]
716-
fn next_back(&mut self) -> Option<U::Item> {
718+
default fn next_back(&mut self) -> Option<U::Item> {
717719
loop {
718720
if let elt @ Some(_) = and_then_or_clear(&mut self.backiter, |b| b.next_back()) {
719721
return elt;
@@ -726,7 +728,7 @@ where
726728
}
727729

728730
#[inline]
729-
fn try_rfold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
731+
default fn try_rfold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
730732
where
731733
Self: Sized,
732734
Fold: FnMut(Acc, Self::Item) -> R,
@@ -743,7 +745,7 @@ where
743745
}
744746

745747
#[inline]
746-
fn rfold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
748+
default fn rfold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
747749
where
748750
Fold: FnMut(Acc, Self::Item) -> Acc,
749751
{
@@ -759,7 +761,7 @@ where
759761

760762
#[inline]
761763
#[rustc_inherit_overflow_checks]
762-
fn advance_back_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
764+
default fn advance_back_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
763765
#[inline]
764766
#[rustc_inherit_overflow_checks]
765767
fn advance<U: DoubleEndedIterator>(n: usize, iter: &mut U) -> ControlFlow<(), usize> {
@@ -841,3 +843,198 @@ fn and_then_or_clear<T, U>(opt: &mut Option<T>, f: impl FnOnce(&mut T) -> Option
841843
}
842844
x
843845
}
846+
847+
/// Specialization trait for iterator types that never return more than one item.
848+
///
849+
/// Note that we still have to deal with the possibility that the iterator was
850+
/// already exhausted before it came into our control.
851+
#[rustc_specialization_trait]
852+
trait OneShot {}
853+
854+
// These all have exactly one item, if not already consumed.
855+
impl<T> OneShot for Once<T> {}
856+
impl<F> OneShot for OnceWith<F> {}
857+
impl<T> OneShot for array::IntoIter<T, 1> {}
858+
impl<T> OneShot for option::IntoIter<T> {}
859+
impl<T> OneShot for option::Iter<'_, T> {}
860+
impl<T> OneShot for option::IterMut<'_, T> {}
861+
impl<T> OneShot for result::IntoIter<T> {}
862+
impl<T> OneShot for result::Iter<'_, T> {}
863+
impl<T> OneShot for result::IterMut<'_, T> {}
864+
865+
// These are always empty, which is fine to optimize too.
866+
impl<T> OneShot for Empty<T> {}
867+
impl<T> OneShot for array::IntoIter<T, 0> {}
868+
869+
// These adaptors never increase the number of items.
870+
// (There are more possible, but for now this matches BoundedSize above.)
871+
impl<I: OneShot> OneShot for Cloned<I> {}
872+
impl<I: OneShot> OneShot for Copied<I> {}
873+
impl<I: OneShot, P> OneShot for Filter<I, P> {}
874+
impl<I: OneShot, P> OneShot for FilterMap<I, P> {}
875+
impl<I: OneShot, F> OneShot for Map<I, F> {}
876+
877+
// Blanket impls pass this property through as well
878+
// (but we can't do `Box<I>` unless we expose this trait to alloc)
879+
impl<I: OneShot> OneShot for &mut I {}
880+
881+
#[inline]
882+
fn into_item<I>(inner: I) -> Option<I::Item>
883+
where
884+
I: IntoIterator<IntoIter: OneShot>,
885+
{
886+
inner.into_iter().next()
887+
}
888+
889+
#[inline]
890+
fn flatten_one<I: IntoIterator<IntoIter: OneShot>, Acc>(
891+
mut fold: impl FnMut(Acc, I::Item) -> Acc,
892+
) -> impl FnMut(Acc, I) -> Acc {
893+
move |acc, inner| match inner.into_iter().next() {
894+
Some(item) => fold(acc, item),
895+
None => acc,
896+
}
897+
}
898+
899+
#[inline]
900+
fn try_flatten_one<I: IntoIterator<IntoIter: OneShot>, Acc, R: Try<Output = Acc>>(
901+
mut fold: impl FnMut(Acc, I::Item) -> R,
902+
) -> impl FnMut(Acc, I) -> R {
903+
move |acc, inner| match inner.into_iter().next() {
904+
Some(item) => fold(acc, item),
905+
None => try { acc },
906+
}
907+
}
908+
909+
#[inline]
910+
fn advance_by_one<I>(n: NonZero<usize>, inner: I) -> Option<NonZero<usize>>
911+
where
912+
I: IntoIterator<IntoIter: OneShot>,
913+
{
914+
match inner.into_iter().next() {
915+
Some(_) => NonZero::new(n.get() - 1),
916+
None => Some(n),
917+
}
918+
}
919+
920+
// Specialization: When the inner iterator `U` never returns more than one item, the `frontiter` and
921+
// `backiter` states are a waste, because they'll always have already consumed their item. So in
922+
// this impl, we completely ignore them and just focus on `self.iter`, and we only call the inner
923+
// `U::next()` one time.
924+
//
925+
// It's mostly fine if we accidentally mix this with the more generic impls, e.g. by forgetting to
926+
// specialize one of the methods. If the other impl did set the front or back, we wouldn't see it
927+
// here, but it would be empty anyway; and if the other impl looked for a front or back that we
928+
// didn't bother setting, it would just see `None` (or a previous empty) and move on.
929+
//
930+
// An exception to that is `advance_by(0)` and `advance_back_by(0)`, where the generic impls may set
931+
// `frontiter` or `backiter` without consuming the item, so we **must** override those.
932+
impl<I, U> Iterator for FlattenCompat<I, U>
933+
where
934+
I: Iterator<Item: IntoIterator<IntoIter = U, Item = U::Item>>,
935+
U: Iterator + OneShot,
936+
{
937+
#[inline]
938+
fn next(&mut self) -> Option<U::Item> {
939+
while let Some(inner) = self.iter.next() {
940+
if let item @ Some(_) = inner.into_iter().next() {
941+
return item;
942+
}
943+
}
944+
None
945+
}
946+
947+
#[inline]
948+
fn size_hint(&self) -> (usize, Option<usize>) {
949+
let (lower, upper) = self.iter.size_hint();
950+
match <I::Item as ConstSizeIntoIterator>::size() {
951+
Some(0) => (0, Some(0)),
952+
Some(1) => (lower, upper),
953+
_ => (0, upper),
954+
}
955+
}
956+
957+
#[inline]
958+
fn try_fold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
959+
where
960+
Self: Sized,
961+
Fold: FnMut(Acc, Self::Item) -> R,
962+
R: Try<Output = Acc>,
963+
{
964+
self.iter.try_fold(init, try_flatten_one(fold))
965+
}
966+
967+
#[inline]
968+
fn fold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
969+
where
970+
Fold: FnMut(Acc, Self::Item) -> Acc,
971+
{
972+
self.iter.fold(init, flatten_one(fold))
973+
}
974+
975+
#[inline]
976+
fn advance_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
977+
if let Some(n) = NonZero::new(n) {
978+
self.iter.try_fold(n, advance_by_one).map_or(Ok(()), Err)
979+
} else {
980+
// Just advance the outer iterator
981+
self.iter.advance_by(0)
982+
}
983+
}
984+
985+
#[inline]
986+
fn count(self) -> usize {
987+
self.iter.filter_map(into_item).count()
988+
}
989+
990+
#[inline]
991+
fn last(self) -> Option<Self::Item> {
992+
self.iter.filter_map(into_item).last()
993+
}
994+
}
995+
996+
// Note: We don't actually care about `U: DoubleEndedIterator`, since forward and backward are the
997+
// same for a one-shot iterator, but we have to keep that to match the default specialization.
998+
impl<I, U> DoubleEndedIterator for FlattenCompat<I, U>
999+
where
1000+
I: DoubleEndedIterator<Item: IntoIterator<IntoIter = U, Item = U::Item>>,
1001+
U: DoubleEndedIterator + OneShot,
1002+
{
1003+
#[inline]
1004+
fn next_back(&mut self) -> Option<U::Item> {
1005+
while let Some(inner) = self.iter.next_back() {
1006+
if let item @ Some(_) = inner.into_iter().next() {
1007+
return item;
1008+
}
1009+
}
1010+
None
1011+
}
1012+
1013+
#[inline]
1014+
fn try_rfold<Acc, Fold, R>(&mut self, init: Acc, fold: Fold) -> R
1015+
where
1016+
Self: Sized,
1017+
Fold: FnMut(Acc, Self::Item) -> R,
1018+
R: Try<Output = Acc>,
1019+
{
1020+
self.iter.try_rfold(init, try_flatten_one(fold))
1021+
}
1022+
1023+
#[inline]
1024+
fn rfold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
1025+
where
1026+
Fold: FnMut(Acc, Self::Item) -> Acc,
1027+
{
1028+
self.iter.rfold(init, flatten_one(fold))
1029+
}
1030+
1031+
#[inline]
1032+
fn advance_back_by(&mut self, n: usize) -> Result<(), NonZero<usize>> {
1033+
if let Some(n) = NonZero::new(n) {
1034+
self.iter.try_rfold(n, advance_by_one).map_or(Ok(()), Err)
1035+
} else {
1036+
// Just advance the outer iterator
1037+
self.iter.advance_back_by(0)
1038+
}
1039+
}
1040+
}

library/core/tests/iter/adapters/flatten.rs

+66
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,69 @@ fn test_flatten_last() {
212212
assert_eq!(it.advance_by(3), Ok(())); // 22..22
213213
assert_eq!(it.clone().last(), None);
214214
}
215+
216+
#[test]
217+
fn test_flatten_one_shot() {
218+
// This could be `filter_map`, but people often do flatten options.
219+
let mut it = (0i8..10).flat_map(|i| NonZero::new(i % 7));
220+
assert_eq!(it.size_hint(), (0, Some(10)));
221+
assert_eq!(it.clone().count(), 8);
222+
assert_eq!(it.clone().last(), NonZero::new(2));
223+
224+
// sum -> fold
225+
let sum: i8 = it.clone().map(|n| n.get()).sum();
226+
assert_eq!(sum, 24);
227+
228+
// the product overflows at 6, remaining are 7,8,9 -> 1,2
229+
let one = NonZero::new(1i8).unwrap();
230+
let product = it.try_fold(one, |acc, x| acc.checked_mul(x));
231+
assert_eq!(product, None);
232+
assert_eq!(it.size_hint(), (0, Some(3)));
233+
assert_eq!(it.clone().count(), 2);
234+
235+
assert_eq!(it.advance_by(0), Ok(()));
236+
assert_eq!(it.clone().next(), NonZero::new(1));
237+
assert_eq!(it.advance_by(1), Ok(()));
238+
assert_eq!(it.clone().next(), NonZero::new(2));
239+
assert_eq!(it.advance_by(100), Err(NonZero::new(99).unwrap()));
240+
assert_eq!(it.next(), None);
241+
}
242+
243+
#[test]
244+
fn test_flatten_one_shot_rev() {
245+
let mut it = (0i8..10).flat_map(|i| NonZero::new(i % 7)).rev();
246+
assert_eq!(it.size_hint(), (0, Some(10)));
247+
assert_eq!(it.clone().count(), 8);
248+
assert_eq!(it.clone().last(), NonZero::new(1));
249+
250+
// sum -> Rev fold -> rfold
251+
let sum: i8 = it.clone().map(|n| n.get()).sum();
252+
assert_eq!(sum, 24);
253+
254+
// Rev try_fold -> try_rfold
255+
// the product overflows at 4, remaining are 3,2,1,0 -> 3,2,1
256+
let one = NonZero::new(1i8).unwrap();
257+
let product = it.try_fold(one, |acc, x| acc.checked_mul(x));
258+
assert_eq!(product, None);
259+
assert_eq!(it.size_hint(), (0, Some(4)));
260+
assert_eq!(it.clone().count(), 3);
261+
262+
// Rev advance_by -> advance_back_by
263+
assert_eq!(it.advance_by(0), Ok(()));
264+
assert_eq!(it.clone().next(), NonZero::new(3));
265+
assert_eq!(it.advance_by(1), Ok(()));
266+
assert_eq!(it.clone().next(), NonZero::new(2));
267+
assert_eq!(it.advance_by(100), Err(NonZero::new(98).unwrap()));
268+
assert_eq!(it.next(), None);
269+
}
270+
271+
#[test]
272+
fn test_flatten_one_shot_arrays() {
273+
let it = (0..10).flat_map(|i| [i]);
274+
assert_eq!(it.size_hint(), (10, Some(10)));
275+
assert_eq!(it.sum::<i32>(), 45);
276+
277+
let mut it = (0..10).flat_map(|_| -> [i32; 0] { [] });
278+
assert_eq!(it.size_hint(), (0, Some(0)));
279+
assert_eq!(it.next(), None);
280+
}

0 commit comments

Comments
 (0)