Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend the ParallelExtend implementations #1129

Merged
merged 4 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 109 additions & 109 deletions src/iter/extend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,83 @@ use super::noop::NoopConsumer;
use super::plumbing::{Consumer, Folder, Reducer, UnindexedConsumer};
use super::{IntoParallelIterator, ParallelExtend, ParallelIterator};

use either::Either;
use std::borrow::Cow;
use std::collections::LinkedList;
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::collections::{BinaryHeap, VecDeque};
use std::ffi::{OsStr, OsString};
use std::hash::{BuildHasher, Hash};

/// Performs a generic `par_extend` by collecting to a `LinkedList<Vec<_>>` in
/// parallel, then extending the collection sequentially.
macro_rules! extend {
($self:ident, $par_iter:ident, $extend:ident) => {
$extend($self, drive_list_vec($par_iter));
($self:ident, $par_iter:ident) => {
extend!($self <- fast_collect($par_iter))
};
($self:ident <- $vecs:expr) => {
match $vecs {
Either::Left(vec) => $self.extend(vec),
Either::Right(list) => {
for vec in list {
$self.extend(vec);
}
}
}
};
}
macro_rules! extend_reserved {
($self:ident, $par_iter:ident, $len:ident) => {
let vecs = fast_collect($par_iter);
$self.reserve($len(&vecs));
extend!($self <- vecs)
};
($self:ident, $par_iter:ident) => {
extend_reserved!($self, $par_iter, len)
};
}

/// Computes the total length of a `fast_collect` result.
fn len<T>(vecs: &Either<Vec<T>, LinkedList<Vec<T>>>) -> usize {
match vecs {
Either::Left(vec) => vec.len(),
Either::Right(list) => list.iter().map(Vec::len).sum(),
}
}

/// Computes the total length of a `LinkedList<Vec<_>>`.
fn len<T>(list: &LinkedList<Vec<T>>) -> usize {
list.iter().map(Vec::len).sum()
/// Computes the total string length of a `fast_collect` result.
fn string_len<T: AsRef<str>>(vecs: &Either<Vec<T>, LinkedList<Vec<T>>>) -> usize {
let strs = match vecs {
Either::Left(vec) => Either::Left(vec.iter()),
Either::Right(list) => Either::Right(list.iter().flatten()),
};
strs.map(AsRef::as_ref).map(str::len).sum()
}

pub(super) fn drive_list_vec<I, T>(pi: I) -> LinkedList<Vec<T>>
/// Computes the total OS-string length of a `fast_collect` result.
fn osstring_len<T: AsRef<OsStr>>(vecs: &Either<Vec<T>, LinkedList<Vec<T>>>) -> usize {
let osstrs = match vecs {
Either::Left(vec) => Either::Left(vec.iter()),
Either::Right(list) => Either::Right(list.iter().flatten()),
};
osstrs.map(AsRef::as_ref).map(OsStr::len).sum()
}

pub(super) fn fast_collect<I, T>(pi: I) -> Either<Vec<T>, LinkedList<Vec<T>>>
where
I: IntoParallelIterator<Item = T>,
T: Send,
{
pi.into_par_iter().drive_unindexed(ListVecConsumer)
let par_iter = pi.into_par_iter();
match par_iter.opt_len() {
Some(len) => {
// Pseudo-specialization. See impl of ParallelExtend for Vec for more details.
let mut vec = Vec::new();
super::collect::special_extend(par_iter, len, &mut vec);
Either::Left(vec)
}
None => Either::Right(par_iter.drive_unindexed(ListVecConsumer)),
}
}

struct ListVecConsumer;
Expand Down Expand Up @@ -92,16 +144,6 @@ impl<T> Folder<T> for ListVecFolder<T> {
}
}

fn heap_extend<T, Item>(heap: &mut BinaryHeap<T>, list: LinkedList<Vec<Item>>)
where
BinaryHeap<T>: Extend<Item>,
{
heap.reserve(len(&list));
for vec in list {
heap.extend(vec);
}
}

/// Extends a binary heap with items from a parallel iterator.
impl<T> ParallelExtend<T> for BinaryHeap<T>
where
Expand All @@ -111,7 +153,7 @@ where
where
I: IntoParallelIterator<Item = T>,
{
extend!(self, par_iter, heap_extend);
extend_reserved!(self, par_iter);
}
}

Expand All @@ -124,16 +166,7 @@ where
where
I: IntoParallelIterator<Item = &'a T>,
{
extend!(self, par_iter, heap_extend);
}
}

fn btree_map_extend<K, V, Item>(map: &mut BTreeMap<K, V>, list: LinkedList<Vec<Item>>)
where
BTreeMap<K, V>: Extend<Item>,
{
for vec in list {
map.extend(vec);
extend_reserved!(self, par_iter);
}
}

Expand All @@ -147,7 +180,7 @@ where
where
I: IntoParallelIterator<Item = (K, V)>,
{
extend!(self, par_iter, btree_map_extend);
extend!(self, par_iter);
}
}

Expand All @@ -161,16 +194,7 @@ where
where
I: IntoParallelIterator<Item = (&'a K, &'a V)>,
{
extend!(self, par_iter, btree_map_extend);
}
}

fn btree_set_extend<T, Item>(set: &mut BTreeSet<T>, list: LinkedList<Vec<Item>>)
where
BTreeSet<T>: Extend<Item>,
{
for vec in list {
set.extend(vec);
extend!(self, par_iter);
}
}

Expand All @@ -183,7 +207,7 @@ where
where
I: IntoParallelIterator<Item = T>,
{
extend!(self, par_iter, btree_set_extend);
extend!(self, par_iter);
}
}

Expand All @@ -196,19 +220,7 @@ where
where
I: IntoParallelIterator<Item = &'a T>,
{
extend!(self, par_iter, btree_set_extend);
}
}

fn hash_map_extend<K, V, S, Item>(map: &mut HashMap<K, V, S>, list: LinkedList<Vec<Item>>)
where
HashMap<K, V, S>: Extend<Item>,
K: Eq + Hash,
S: BuildHasher,
{
map.reserve(len(&list));
for vec in list {
map.extend(vec);
extend!(self, par_iter);
}
}

Expand All @@ -224,7 +236,7 @@ where
I: IntoParallelIterator<Item = (K, V)>,
{
// See the map_collect benchmarks in rayon-demo for different strategies.
extend!(self, par_iter, hash_map_extend);
extend_reserved!(self, par_iter);
}
}

Expand All @@ -239,19 +251,7 @@ where
where
I: IntoParallelIterator<Item = (&'a K, &'a V)>,
{
extend!(self, par_iter, hash_map_extend);
}
}

fn hash_set_extend<T, S, Item>(set: &mut HashSet<T, S>, list: LinkedList<Vec<Item>>)
where
HashSet<T, S>: Extend<Item>,
T: Eq + Hash,
S: BuildHasher,
{
set.reserve(len(&list));
for vec in list {
set.extend(vec);
extend_reserved!(self, par_iter);
}
}

Expand All @@ -265,7 +265,7 @@ where
where
I: IntoParallelIterator<Item = T>,
{
extend!(self, par_iter, hash_set_extend);
extend_reserved!(self, par_iter);
}
}

Expand All @@ -279,7 +279,7 @@ where
where
I: IntoParallelIterator<Item = &'a T>,
{
extend!(self, par_iter, hash_set_extend);
extend_reserved!(self, par_iter);
}
}

Expand Down Expand Up @@ -380,9 +380,34 @@ impl<T> Reducer<LinkedList<T>> for ListReducer {
}
}

fn flat_string_extend(string: &mut String, list: LinkedList<String>) {
string.reserve(list.iter().map(String::len).sum());
string.extend(list);
/// Extends an OS-string with string slices from a parallel iterator.
impl<'a> ParallelExtend<&'a OsStr> for OsString {
fn par_extend<I>(&mut self, par_iter: I)
where
I: IntoParallelIterator<Item = &'a OsStr>,
{
extend_reserved!(self, par_iter, osstring_len);
}
}

/// Extends an OS-string with strings from a parallel iterator.
impl ParallelExtend<OsString> for OsString {
fn par_extend<I>(&mut self, par_iter: I)
where
I: IntoParallelIterator<Item = OsString>,
{
extend_reserved!(self, par_iter, osstring_len);
}
}

/// Extends an OS-string with string slices from a parallel iterator.
impl<'a> ParallelExtend<Cow<'a, OsStr>> for OsString {
fn par_extend<I>(&mut self, par_iter: I)
where
I: IntoParallelIterator<Item = Cow<'a, OsStr>>,
{
extend_reserved!(self, par_iter, osstring_len);
}
}

/// Extends a string with characters from a parallel iterator.
Expand All @@ -394,7 +419,8 @@ impl ParallelExtend<char> for String {
// This is like `extend`, but `Vec<char>` is less efficient to deal
// with than `String`, so instead collect to `LinkedList<String>`.
let list = par_iter.into_par_iter().drive_unindexed(ListStringConsumer);
flat_string_extend(self, list);
self.reserve(list.iter().map(String::len).sum());
self.extend(list);
}
}

Expand Down Expand Up @@ -473,25 +499,13 @@ impl Folder<char> for ListStringFolder {
}
}

fn string_extend<Item>(string: &mut String, list: LinkedList<Vec<Item>>)
where
String: Extend<Item>,
Item: AsRef<str>,
{
let len = list.iter().flatten().map(Item::as_ref).map(str::len).sum();
string.reserve(len);
for vec in list {
string.extend(vec);
}
}

/// Extends a string with string slices from a parallel iterator.
impl<'a> ParallelExtend<&'a str> for String {
fn par_extend<I>(&mut self, par_iter: I)
where
I: IntoParallelIterator<Item = &'a str>,
{
extend!(self, par_iter, string_extend);
extend_reserved!(self, par_iter, string_len);
}
}

Expand All @@ -501,7 +515,7 @@ impl ParallelExtend<String> for String {
where
I: IntoParallelIterator<Item = String>,
{
extend!(self, par_iter, string_extend);
extend_reserved!(self, par_iter, string_len);
}
}

Expand All @@ -511,7 +525,7 @@ impl ParallelExtend<Box<str>> for String {
where
I: IntoParallelIterator<Item = Box<str>>,
{
extend!(self, par_iter, string_extend);
extend_reserved!(self, par_iter, string_len);
}
}

Expand All @@ -521,17 +535,7 @@ impl<'a> ParallelExtend<Cow<'a, str>> for String {
where
I: IntoParallelIterator<Item = Cow<'a, str>>,
{
extend!(self, par_iter, string_extend);
}
}

fn deque_extend<T, Item>(deque: &mut VecDeque<T>, list: LinkedList<Vec<Item>>)
where
VecDeque<T>: Extend<Item>,
{
deque.reserve(len(&list));
for vec in list {
deque.extend(vec);
extend_reserved!(self, par_iter, string_len);
}
}

Expand All @@ -544,7 +548,7 @@ where
where
I: IntoParallelIterator<Item = T>,
{
extend!(self, par_iter, deque_extend);
extend_reserved!(self, par_iter);
}
}

Expand All @@ -557,14 +561,7 @@ where
where
I: IntoParallelIterator<Item = &'a T>,
{
extend!(self, par_iter, deque_extend);
}
}

fn vec_append<T>(vec: &mut Vec<T>, list: LinkedList<Vec<T>>) {
vec.reserve(len(&list));
for mut other in list {
vec.append(&mut other);
extend_reserved!(self, par_iter);
}
}

Expand All @@ -589,7 +586,10 @@ where
None => {
// This works like `extend`, but `Vec::append` is more efficient.
let list = par_iter.drive_unindexed(ListVecConsumer);
vec_append(self, list);
self.reserve(list.iter().map(Vec::len).sum());
for mut other in list {
self.append(&mut other);
}
}
}
}
Expand Down
Loading