diff --git a/src/iter/collect/mod.rs b/src/iter/collect/mod.rs index 7cbf215c4..f4ff49aa1 100644 --- a/src/iter/collect/mod.rs +++ b/src/iter/collect/mod.rs @@ -1,4 +1,4 @@ -use super::{IndexedParallelIterator, IntoParallelIterator, ParallelExtend, ParallelIterator}; +use super::{IndexedParallelIterator, ParallelIterator}; use std::mem::MaybeUninit; use std::slice; @@ -33,7 +33,7 @@ where /// *any* `ParallelIterator` here, and `CollectConsumer` has to also implement /// `UnindexedConsumer`. That implementation panics `unreachable!` in case /// there's a bug where we actually do try to use this unindexed. -fn special_extend(pi: I, len: usize, v: &mut Vec) +pub(super) fn special_extend(pi: I, len: usize, v: &mut Vec) where I: ParallelIterator, T: Send, @@ -141,33 +141,3 @@ impl<'c, T: Send + 'c> Collect<'c, T> { unsafe { slice::from_raw_parts_mut(tail_ptr, len) } } } - -/// Extends a vector with items from a parallel iterator. -impl ParallelExtend for Vec -where - T: Send, -{ - fn par_extend(&mut self, par_iter: I) - where - I: IntoParallelIterator, - { - // See the vec_collect benchmarks in rayon-demo for different strategies. - let par_iter = par_iter.into_par_iter(); - match par_iter.opt_len() { - Some(len) => { - // When Rust gets specialization, we can get here for indexed iterators - // without relying on `opt_len`. Until then, `special_extend()` fakes - // an unindexed mode on the promise that `opt_len()` is accurate. - special_extend(par_iter, len, self); - } - None => { - // This works like `extend`, but `Vec::append` is more efficient. - let list = super::extend::collect(par_iter); - self.reserve(super::extend::len(&list)); - for mut vec in list { - self.append(&mut vec); - } - } - } - } -} diff --git a/src/iter/extend.rs b/src/iter/extend.rs index fb89249f5..1769d476b 100644 --- a/src/iter/extend.rs +++ b/src/iter/extend.rs @@ -1,4 +1,5 @@ use super::noop::NoopConsumer; +use super::plumbing::{Consumer, Folder, Reducer, UnindexedConsumer}; use super::{IntoParallelIterator, ParallelExtend, ParallelIterator}; use std::borrow::Cow; @@ -9,55 +10,91 @@ use std::hash::{BuildHasher, Hash}; /// Performs a generic `par_extend` by collecting to a `LinkedList>` in /// parallel, then extending the collection sequentially. -fn extend(collection: &mut C, par_iter: I, reserve: F) -where - I: IntoParallelIterator, - F: FnOnce(&mut C, &LinkedList>), - C: Extend, -{ - let list = collect(par_iter); - reserve(collection, &list); - for vec in list { - collection.extend(vec); - } +macro_rules! extend { + ($self:ident, $par_iter:ident, $extend:ident) => { + $extend( + $self, + $par_iter.into_par_iter().drive_unindexed(ListVecConsumer), + ); + }; } -pub(super) fn collect(par_iter: I) -> LinkedList> -where - I: IntoParallelIterator, -{ - par_iter - .into_par_iter() - .fold(Vec::new, vec_push) - .map(as_list) - .reduce(LinkedList::new, list_append) +/// Computes the total length of a `LinkedList>`. +fn len(list: &LinkedList>) -> usize { + list.iter().map(Vec::len).sum() } -fn vec_push(mut vec: Vec, elem: T) -> Vec { - vec.push(elem); - vec -} +struct ListVecConsumer; -fn as_list(item: T) -> LinkedList { - let mut list = LinkedList::new(); - list.push_back(item); - list +struct ListVecFolder { + vec: Vec, } -fn list_append(mut list1: LinkedList, mut list2: LinkedList) -> LinkedList { - list1.append(&mut list2); - list1 +impl Consumer for ListVecConsumer { + type Folder = ListVecFolder; + type Reducer = ListReducer; + type Result = LinkedList>; + + fn split_at(self, _index: usize) -> (Self, Self, Self::Reducer) { + (Self, Self, ListReducer) + } + + fn into_folder(self) -> Self::Folder { + ListVecFolder { vec: Vec::new() } + } + + fn full(&self) -> bool { + false + } } -/// Computes the total length of a `LinkedList>`. -pub(super) fn len(list: &LinkedList>) -> usize { - list.iter().map(Vec::len).sum() +impl UnindexedConsumer for ListVecConsumer { + fn split_off_left(&self) -> Self { + Self + } + + fn to_reducer(&self) -> Self::Reducer { + ListReducer + } } -fn no_reserve(_: &mut C, _: &LinkedList>) {} +impl Folder for ListVecFolder { + type Result = LinkedList>; + + fn consume(mut self, item: T) -> Self { + self.vec.push(item); + self + } + + fn consume_iter(mut self, iter: I) -> Self + where + I: IntoIterator, + { + self.vec.extend(iter); + self + } -fn heap_reserve(heap: &mut BinaryHeap, list: &LinkedList>) { - heap.reserve(len(list)); + fn complete(self) -> Self::Result { + let mut list = LinkedList::new(); + if !self.vec.is_empty() { + list.push_back(self.vec); + } + list + } + + fn full(&self) -> bool { + false + } +} + +fn heap_extend(heap: &mut BinaryHeap, list: LinkedList>) +where + BinaryHeap: Extend, +{ + heap.reserve(len(&list)); + for vec in list { + heap.extend(vec); + } } /// Extends a binary heap with items from a parallel iterator. @@ -69,7 +106,7 @@ where where I: IntoParallelIterator, { - extend(self, par_iter, heap_reserve); + extend!(self, par_iter, heap_extend); } } @@ -82,7 +119,16 @@ where where I: IntoParallelIterator, { - extend(self, par_iter, heap_reserve); + extend!(self, par_iter, heap_extend); + } +} + +fn btree_map_extend(map: &mut BTreeMap, list: LinkedList>) +where + BTreeMap: Extend, +{ + for vec in list { + map.extend(vec); } } @@ -96,7 +142,7 @@ where where I: IntoParallelIterator, { - extend(self, par_iter, no_reserve); + extend!(self, par_iter, btree_map_extend); } } @@ -110,7 +156,16 @@ where where I: IntoParallelIterator, { - extend(self, par_iter, no_reserve); + extend!(self, par_iter, btree_map_extend); + } +} + +fn btree_set_extend(set: &mut BTreeSet, list: LinkedList>) +where + BTreeSet: Extend, +{ + for vec in list { + set.extend(vec); } } @@ -123,7 +178,7 @@ where where I: IntoParallelIterator, { - extend(self, par_iter, no_reserve); + extend!(self, par_iter, btree_set_extend); } } @@ -136,16 +191,20 @@ where where I: IntoParallelIterator, { - extend(self, par_iter, no_reserve); + extend!(self, par_iter, btree_set_extend); } } -fn map_reserve(map: &mut HashMap, list: &LinkedList>) +fn hash_map_extend(map: &mut HashMap, list: LinkedList>) where + HashMap: Extend, K: Eq + Hash, S: BuildHasher, { - map.reserve(len(list)); + map.reserve(len(&list)); + for vec in list { + map.extend(vec); + } } /// Extends a hash map with items from a parallel iterator. @@ -160,7 +219,7 @@ where I: IntoParallelIterator, { // See the map_collect benchmarks in rayon-demo for different strategies. - extend(self, par_iter, map_reserve); + extend!(self, par_iter, hash_map_extend); } } @@ -175,16 +234,20 @@ where where I: IntoParallelIterator, { - extend(self, par_iter, map_reserve); + extend!(self, par_iter, hash_map_extend); } } -fn set_reserve(set: &mut HashSet, list: &LinkedList>) +fn hash_set_extend(set: &mut HashSet, list: LinkedList>) where + HashSet: Extend, T: Eq + Hash, S: BuildHasher, { - set.reserve(len(list)); + set.reserve(len(&list)); + for vec in list { + set.extend(vec); + } } /// Extends a hash set with items from a parallel iterator. @@ -197,7 +260,7 @@ where where I: IntoParallelIterator, { - extend(self, par_iter, set_reserve); + extend!(self, par_iter, hash_set_extend); } } @@ -211,15 +274,10 @@ where where I: IntoParallelIterator, { - extend(self, par_iter, set_reserve); + extend!(self, par_iter, hash_set_extend); } } -fn list_push_back(mut list: LinkedList, elem: T) -> LinkedList { - list.push_back(elem); - list -} - /// Extends a linked list with items from a parallel iterator. impl ParallelExtend for LinkedList where @@ -229,10 +287,7 @@ where where I: IntoParallelIterator, { - let mut list = par_iter - .into_par_iter() - .fold(LinkedList::new, list_push_back) - .reduce(LinkedList::new, list_append); + let mut list = par_iter.into_par_iter().drive_unindexed(ListConsumer); self.append(&mut list); } } @@ -246,13 +301,83 @@ where where I: IntoParallelIterator, { - self.par_extend(par_iter.into_par_iter().cloned()) + self.par_extend(par_iter.into_par_iter().copied()) + } +} + +struct ListConsumer; + +struct ListFolder { + list: LinkedList, +} + +struct ListReducer; + +impl Consumer for ListConsumer { + type Folder = ListFolder; + type Reducer = ListReducer; + type Result = LinkedList; + + fn split_at(self, _index: usize) -> (Self, Self, Self::Reducer) { + (Self, Self, ListReducer) + } + + fn into_folder(self) -> Self::Folder { + ListFolder { + list: LinkedList::new(), + } + } + + fn full(&self) -> bool { + false + } +} + +impl UnindexedConsumer for ListConsumer { + fn split_off_left(&self) -> Self { + Self + } + + fn to_reducer(&self) -> Self::Reducer { + ListReducer + } +} + +impl Folder for ListFolder { + type Result = LinkedList; + + fn consume(mut self, item: T) -> Self { + self.list.push_back(item); + self + } + + fn consume_iter(mut self, iter: I) -> Self + where + I: IntoIterator, + { + self.list.extend(iter); + self + } + + fn complete(self) -> Self::Result { + self.list + } + + fn full(&self) -> bool { + false + } +} + +impl Reducer> for ListReducer { + fn reduce(self, mut left: LinkedList, mut right: LinkedList) -> LinkedList { + left.append(&mut right); + left } } -fn string_push(mut string: String, ch: char) -> String { - string.push(ch); - string +fn flat_string_extend(string: &mut String, list: LinkedList) { + string.reserve(list.iter().map(String::len).sum()); + string.extend(list); } /// Extends a string with characters from a parallel iterator. @@ -263,14 +388,8 @@ impl ParallelExtend for String { { // This is like `extend`, but `Vec` is less efficient to deal // with than `String`, so instead collect to `LinkedList`. - let list: LinkedList<_> = par_iter - .into_par_iter() - .fold(String::new, string_push) - .map(as_list) - .reduce(LinkedList::new, list_append); - - self.reserve(list.iter().map(String::len).sum()); - self.extend(list) + let list = par_iter.into_par_iter().drive_unindexed(ListStringConsumer); + flat_string_extend(self, list); } } @@ -280,13 +399,85 @@ impl<'a> ParallelExtend<&'a char> for String { where I: IntoParallelIterator, { - self.par_extend(par_iter.into_par_iter().cloned()) + self.par_extend(par_iter.into_par_iter().copied()) } } -fn string_reserve>(string: &mut String, list: &LinkedList>) { - let len = list.iter().flatten().map(T::as_ref).map(str::len).sum(); +struct ListStringConsumer; + +struct ListStringFolder { + string: String, +} + +impl Consumer for ListStringConsumer { + type Folder = ListStringFolder; + type Reducer = ListReducer; + type Result = LinkedList; + + fn split_at(self, _index: usize) -> (Self, Self, Self::Reducer) { + (Self, Self, ListReducer) + } + + fn into_folder(self) -> Self::Folder { + ListStringFolder { + string: String::new(), + } + } + + fn full(&self) -> bool { + false + } +} + +impl UnindexedConsumer for ListStringConsumer { + fn split_off_left(&self) -> Self { + Self + } + + fn to_reducer(&self) -> Self::Reducer { + ListReducer + } +} + +impl Folder for ListStringFolder { + type Result = LinkedList; + + fn consume(mut self, item: char) -> Self { + self.string.push(item); + self + } + + fn consume_iter(mut self, iter: I) -> Self + where + I: IntoIterator, + { + self.string.extend(iter); + self + } + + fn complete(self) -> Self::Result { + let mut list = LinkedList::new(); + if !self.string.is_empty() { + list.push_back(self.string); + } + list + } + + fn full(&self) -> bool { + false + } +} + +fn string_extend(string: &mut String, list: LinkedList>) +where + String: Extend, + Item: AsRef, +{ + let len = list.iter().flatten().map(Item::as_ref).map(str::len).sum(); string.reserve(len); + for vec in list { + string.extend(vec); + } } /// Extends a string with string slices from a parallel iterator. @@ -295,7 +486,7 @@ impl<'a> ParallelExtend<&'a str> for String { where I: IntoParallelIterator, { - extend(self, par_iter, string_reserve); + extend!(self, par_iter, string_extend); } } @@ -305,7 +496,7 @@ impl ParallelExtend for String { where I: IntoParallelIterator, { - extend(self, par_iter, string_reserve); + extend!(self, par_iter, string_extend); } } @@ -315,12 +506,18 @@ impl<'a> ParallelExtend> for String { where I: IntoParallelIterator>, { - extend(self, par_iter, string_reserve); + extend!(self, par_iter, string_extend); } } -fn deque_reserve(deque: &mut VecDeque, list: &LinkedList>) { - deque.reserve(len(list)); +fn deque_extend(deque: &mut VecDeque, list: LinkedList>) +where + VecDeque: Extend, +{ + deque.reserve(len(&list)); + for vec in list { + deque.extend(vec); + } } /// Extends a deque with items from a parallel iterator. @@ -332,7 +529,7 @@ where where I: IntoParallelIterator, { - extend(self, par_iter, deque_reserve); + extend!(self, par_iter, deque_extend); } } @@ -345,12 +542,43 @@ where where I: IntoParallelIterator, { - extend(self, par_iter, deque_reserve); + extend!(self, par_iter, deque_extend); } } -// See the `collect` module for the `Vec` implementation. -// impl ParallelExtend for Vec +fn vec_append(vec: &mut Vec, list: LinkedList>) { + vec.reserve(len(&list)); + for mut other in list { + vec.append(&mut other); + } +} + +/// Extends a vector with items from a parallel iterator. +impl ParallelExtend for Vec +where + T: Send, +{ + fn par_extend(&mut self, par_iter: I) + where + I: IntoParallelIterator, + { + // See the vec_collect benchmarks in rayon-demo for different strategies. + let par_iter = par_iter.into_par_iter(); + match par_iter.opt_len() { + Some(len) => { + // When Rust gets specialization, we can get here for indexed iterators + // without relying on `opt_len`. Until then, `special_extend()` fakes + // an unindexed mode on the promise that `opt_len()` is accurate. + super::collect::special_extend(par_iter, len, self); + } + None => { + // This works like `extend`, but `Vec::append` is more efficient. + let list = par_iter.drive_unindexed(ListVecConsumer); + vec_append(self, list); + } + } + } +} /// Extends a vector with copied items from a parallel iterator. impl<'a, T> ParallelExtend<&'a T> for Vec @@ -361,7 +589,7 @@ where where I: IntoParallelIterator, { - self.par_extend(par_iter.into_par_iter().cloned()) + self.par_extend(par_iter.into_par_iter().copied()) } }