Skip to content

Commit

Permalink
Merge pull request #925 from jakeKonrad/master
Browse files Browse the repository at this point in the history
Added A `par_chunk_by` method
  • Loading branch information
cuviper authored Mar 24, 2024
2 parents 9ee7649 + e37ec9e commit ac2fa4d
Show file tree
Hide file tree
Showing 5 changed files with 342 additions and 0 deletions.
244 changes: 244 additions & 0 deletions src/slice/chunk_by.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
use crate::iter::plumbing::*;
use crate::iter::*;
use std::marker::PhantomData;
use std::{fmt, mem};

trait ChunkBySlice<T>: AsRef<[T]> + Default + Send {
fn split(self, index: usize) -> (Self, Self);

fn find(&self, pred: &impl Fn(&T, &T) -> bool, start: usize, end: usize) -> Option<usize> {
self.as_ref()[start..end]
.windows(2)
.position(move |w| !pred(&w[0], &w[1]))
.map(|i| i + 1)
}

fn rfind(&self, pred: &impl Fn(&T, &T) -> bool, end: usize) -> Option<usize> {
self.as_ref()[..end]
.windows(2)
.rposition(move |w| !pred(&w[0], &w[1]))
.map(|i| i + 1)
}
}

impl<T: Sync> ChunkBySlice<T> for &[T] {
fn split(self, index: usize) -> (Self, Self) {
self.split_at(index)
}
}

impl<T: Send> ChunkBySlice<T> for &mut [T] {
fn split(self, index: usize) -> (Self, Self) {
self.split_at_mut(index)
}
}

struct ChunkByProducer<'p, T, Slice, Pred> {
slice: Slice,
pred: &'p Pred,
tail: usize,
marker: PhantomData<fn(&T)>,
}

// Note: this implementation is very similar to `SplitProducer`.
impl<T, Slice, Pred> UnindexedProducer for ChunkByProducer<'_, T, Slice, Pred>
where
Slice: ChunkBySlice<T>,
Pred: Fn(&T, &T) -> bool + Send + Sync,
{
type Item = Slice;

fn split(self) -> (Self, Option<Self>) {
if self.tail < 2 {
return (Self { tail: 0, ..self }, None);
}

// Look forward for the separator, and failing that look backward.
let mid = self.tail / 2;
let index = match self.slice.find(self.pred, mid, self.tail) {
Some(i) => Some(mid + i),
None => self.slice.rfind(self.pred, mid + 1),
};

if let Some(index) = index {
let (left, right) = self.slice.split(index);

let (left_tail, right_tail) = if index <= mid {
// If we scanned backwards to find the separator, everything in
// the right side is exhausted, with no separators left to find.
(index, 0)
} else {
(mid + 1, self.tail - index)
};

// Create the left split before the separator.
let left = Self {
slice: left,
tail: left_tail,
..self
};

// Create the right split following the separator.
let right = Self {
slice: right,
tail: right_tail,
..self
};

(left, Some(right))
} else {
// The search is exhausted, no more separators...
(Self { tail: 0, ..self }, None)
}
}

fn fold_with<F>(self, mut folder: F) -> F
where
F: Folder<Self::Item>,
{
let Self {
slice, pred, tail, ..
} = self;

let (slice, tail) = if tail == slice.as_ref().len() {
// No tail section, so just let `consume_iter` do it all.
(Some(slice), None)
} else if let Some(index) = slice.rfind(pred, tail) {
// We found the last separator to complete the tail, so
// end with that slice after `consume_iter` finds the rest.
let (left, right) = slice.split(index);
(Some(left), Some(right))
} else {
// We know there are no separators at all, so it's all "tail".
(None, Some(slice))
};

if let Some(mut slice) = slice {
// TODO (MSRV 1.77) use either:
// folder.consume_iter(slice.chunk_by(pred))
// folder.consume_iter(slice.chunk_by_mut(pred))

folder = folder.consume_iter(std::iter::from_fn(move || {
let len = slice.as_ref().len();
if len > 0 {
let i = slice.find(pred, 0, len).unwrap_or(len);
let (head, tail) = mem::take(&mut slice).split(i);
slice = tail;
Some(head)
} else {
None
}
}));
}

if let Some(tail) = tail {
folder = folder.consume(tail);
}

folder
}
}

/// Parallel iterator over slice in (non-overlapping) chunks separated by a predicate.
///
/// This struct is created by the [`par_chunk_by`] method on `&[T]`.
///
/// [`par_chunk_by`]: trait.ParallelSlice.html#method.par_chunk_by
pub struct ChunkBy<'data, T, P> {
pred: P,
slice: &'data [T],
}

impl<'data, T, P: Clone> Clone for ChunkBy<'data, T, P> {
fn clone(&self) -> Self {
ChunkBy {
pred: self.pred.clone(),
slice: self.slice,
}
}
}

impl<'data, T: fmt::Debug, P> fmt::Debug for ChunkBy<'data, T, P> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ChunkBy")
.field("slice", &self.slice)
.finish()
}
}

impl<'data, T, P> ChunkBy<'data, T, P> {
pub(super) fn new(slice: &'data [T], pred: P) -> Self {
Self { pred, slice }
}
}

impl<'data, T, P> ParallelIterator for ChunkBy<'data, T, P>
where
T: Sync,
P: Fn(&T, &T) -> bool + Send + Sync,
{
type Item = &'data [T];

fn drive_unindexed<C>(self, consumer: C) -> C::Result
where
C: UnindexedConsumer<Self::Item>,
{
bridge_unindexed(
ChunkByProducer {
tail: self.slice.len(),
slice: self.slice,
pred: &self.pred,
marker: PhantomData,
},
consumer,
)
}
}

/// Parallel iterator over slice in (non-overlapping) mutable chunks
/// separated by a predicate.
///
/// This struct is created by the [`par_chunk_by_mut`] method on `&mut [T]`.
///
/// [`par_chunk_by_mut`]: trait.ParallelSliceMut.html#method.par_chunk_by_mut
pub struct ChunkByMut<'data, T, P> {
pred: P,
slice: &'data mut [T],
}

impl<'data, T: fmt::Debug, P> fmt::Debug for ChunkByMut<'data, T, P> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ChunkByMut")
.field("slice", &self.slice)
.finish()
}
}

impl<'data, T, P> ChunkByMut<'data, T, P> {
pub(super) fn new(slice: &'data mut [T], pred: P) -> Self {
Self { pred, slice }
}
}

impl<'data, T, P> ParallelIterator for ChunkByMut<'data, T, P>
where
T: Send,
P: Fn(&T, &T) -> bool + Send + Sync,
{
type Item = &'data mut [T];

fn drive_unindexed<C>(self, consumer: C) -> C::Result
where
C: UnindexedConsumer<Self::Item>,
{
bridge_unindexed(
ChunkByProducer {
tail: self.slice.len(),
slice: self.slice,
pred: &self.pred,
marker: PhantomData,
},
consumer,
)
}
}
49 changes: 49 additions & 0 deletions src/slice/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
//!
//! [std::slice]: https://doc.rust-lang.org/stable/std/slice/
mod chunk_by;
mod chunks;
mod mergesort;
mod quicksort;
Expand All @@ -22,6 +23,7 @@ use std::cmp::Ordering;
use std::fmt::{self, Debug};
use std::mem;

pub use self::chunk_by::{ChunkBy, ChunkByMut};
pub use self::chunks::{Chunks, ChunksExact, ChunksExactMut, ChunksMut};
pub use self::rchunks::{RChunks, RChunksExact, RChunksExactMut, RChunksMut};

Expand Down Expand Up @@ -173,6 +175,29 @@ pub trait ParallelSlice<T: Sync> {
assert!(chunk_size != 0, "chunk_size must not be zero");
RChunksExact::new(chunk_size, self.as_parallel_slice())
}

/// Returns a parallel iterator over the slice producing non-overlapping runs
/// of elements using the predicate to separate them.
///
/// The predicate is called on two elements following themselves,
/// it means the predicate is called on `slice[0]` and `slice[1]`
/// then on `slice[1]` and `slice[2]` and so on.
///
/// # Examples
///
/// ```
/// use rayon::prelude::*;
/// let chunks: Vec<_> = [1, 2, 2, 3, 3, 3].par_chunk_by(|&x, &y| x == y).collect();
/// assert_eq!(chunks[0], &[1]);
/// assert_eq!(chunks[1], &[2, 2]);
/// assert_eq!(chunks[2], &[3, 3, 3]);
/// ```
fn par_chunk_by<F>(&self, pred: F) -> ChunkBy<'_, T, F>
where
F: Fn(&T, &T) -> bool + Send + Sync,
{
ChunkBy::new(self.as_parallel_slice(), pred)
}
}

impl<T: Sync> ParallelSlice<T> for [T] {
Expand Down Expand Up @@ -704,6 +729,30 @@ pub trait ParallelSliceMut<T: Send> {
{
par_quicksort(self.as_parallel_slice_mut(), |a, b| f(a).lt(&f(b)));
}

/// Returns a parallel iterator over the slice producing non-overlapping mutable
/// runs of elements using the predicate to separate them.
///
/// The predicate is called on two elements following themselves,
/// it means the predicate is called on `slice[0]` and `slice[1]`
/// then on `slice[1]` and `slice[2]` and so on.
///
/// # Examples
///
/// ```
/// use rayon::prelude::*;
/// let mut xs = [1, 2, 2, 3, 3, 3];
/// let chunks: Vec<_> = xs.par_chunk_by_mut(|&x, &y| x == y).collect();
/// assert_eq!(chunks[0], &mut [1]);
/// assert_eq!(chunks[1], &mut [2, 2]);
/// assert_eq!(chunks[2], &mut [3, 3, 3]);
/// ```
fn par_chunk_by_mut<F>(&mut self, pred: F) -> ChunkByMut<'_, T, F>
where
F: Fn(&T, &T) -> bool + Send + Sync,
{
ChunkByMut::new(self.as_parallel_slice_mut(), pred)
}
}

impl<T: Send> ParallelSliceMut<T> for [T] {
Expand Down
46 changes: 46 additions & 0 deletions src/slice/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use rand::distributions::Uniform;
use rand::seq::SliceRandom;
use rand::{thread_rng, Rng};
use std::cmp::Ordering::{Equal, Greater, Less};
use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};

macro_rules! sort {
($f:ident, $name:ident) => {
Expand Down Expand Up @@ -168,3 +169,48 @@ fn test_par_rchunks_exact_mut_remainder() {
assert_eq!(c.take_remainder(), &[]);
assert_eq!(c.len(), 2);
}

#[test]
fn slice_chunk_by() {
let v: Vec<_> = (0..1000).collect();
assert_eq!(v[..0].par_chunk_by(|_, _| todo!()).count(), 0);
assert_eq!(v[..1].par_chunk_by(|_, _| todo!()).count(), 1);
assert_eq!(v[..2].par_chunk_by(|_, _| true).count(), 1);
assert_eq!(v[..2].par_chunk_by(|_, _| false).count(), 2);

let count = AtomicUsize::new(0);
let par: Vec<_> = v
.par_chunk_by(|x, y| {
count.fetch_add(1, Relaxed);
(x % 10 < 3) == (y % 10 < 3)
})
.collect();
assert_eq!(count.into_inner(), v.len() - 1);

let seq: Vec<_> = v.chunk_by(|x, y| (x % 10 < 3) == (y % 10 < 3)).collect();
assert_eq!(par, seq);
}

#[test]
fn slice_chunk_by_mut() {
let mut v: Vec<_> = (0..1000).collect();
assert_eq!(v[..0].par_chunk_by_mut(|_, _| todo!()).count(), 0);
assert_eq!(v[..1].par_chunk_by_mut(|_, _| todo!()).count(), 1);
assert_eq!(v[..2].par_chunk_by_mut(|_, _| true).count(), 1);
assert_eq!(v[..2].par_chunk_by_mut(|_, _| false).count(), 2);

let mut v2 = v.clone();
let count = AtomicUsize::new(0);
let par: Vec<_> = v
.par_chunk_by_mut(|x, y| {
count.fetch_add(1, Relaxed);
(x % 10 < 3) == (y % 10 < 3)
})
.collect();
assert_eq!(count.into_inner(), v2.len() - 1);

let seq: Vec<_> = v2
.chunk_by_mut(|x, y| (x % 10 < 3) == (y % 10 < 3))
.collect();
assert_eq!(par, seq);
}
1 change: 1 addition & 0 deletions tests/clones.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ fn clone_str() {
fn clone_vec() {
let v: Vec<_> = (0..1000).collect();
check(v.par_iter());
check(v.par_chunk_by(i32::eq));
check(v.par_chunks(42));
check(v.par_chunks_exact(42));
check(v.par_rchunks(42));
Expand Down
2 changes: 2 additions & 0 deletions tests/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ fn debug_vec() {
let mut v: Vec<_> = (0..10).collect();
check(v.par_iter());
check(v.par_iter_mut());
check(v.par_chunk_by(i32::eq));
check(v.par_chunk_by_mut(i32::eq));
check(v.par_chunks(42));
check(v.par_chunks_exact(42));
check(v.par_chunks_mut(42));
Expand Down

0 comments on commit ac2fa4d

Please sign in to comment.