diff --git a/src/slice/chunk_by.rs b/src/slice/chunk_by.rs new file mode 100644 index 000000000..25833cabe --- /dev/null +++ b/src/slice/chunk_by.rs @@ -0,0 +1,244 @@ +use crate::iter::plumbing::*; +use crate::iter::*; +use std::marker::PhantomData; +use std::{fmt, mem}; + +trait ChunkBySlice: AsRef<[T]> + Default + Send { + fn split(self, index: usize) -> (Self, Self); + + fn find(&self, pred: &impl Fn(&T, &T) -> bool, start: usize, end: usize) -> Option { + self.as_ref()[start..end] + .windows(2) + .position(move |w| !pred(&w[0], &w[1])) + .map(|i| i + 1) + } + + fn rfind(&self, pred: &impl Fn(&T, &T) -> bool, end: usize) -> Option { + self.as_ref()[..end] + .windows(2) + .rposition(move |w| !pred(&w[0], &w[1])) + .map(|i| i + 1) + } +} + +impl ChunkBySlice for &[T] { + fn split(self, index: usize) -> (Self, Self) { + self.split_at(index) + } +} + +impl ChunkBySlice for &mut [T] { + fn split(self, index: usize) -> (Self, Self) { + self.split_at_mut(index) + } +} + +struct ChunkByProducer<'p, T, Slice, Pred> { + slice: Slice, + pred: &'p Pred, + tail: usize, + marker: PhantomData, +} + +// Note: this implementation is very similar to `SplitProducer`. +impl UnindexedProducer for ChunkByProducer<'_, T, Slice, Pred> +where + Slice: ChunkBySlice, + Pred: Fn(&T, &T) -> bool + Send + Sync, +{ + type Item = Slice; + + fn split(self) -> (Self, Option) { + if self.tail < 2 { + return (Self { tail: 0, ..self }, None); + } + + // Look forward for the separator, and failing that look backward. + let mid = self.tail / 2; + let index = match self.slice.find(self.pred, mid, self.tail) { + Some(i) => Some(mid + i), + None => self.slice.rfind(self.pred, mid + 1), + }; + + if let Some(index) = index { + let (left, right) = self.slice.split(index); + + let (left_tail, right_tail) = if index <= mid { + // If we scanned backwards to find the separator, everything in + // the right side is exhausted, with no separators left to find. + (index, 0) + } else { + (mid + 1, self.tail - index) + }; + + // Create the left split before the separator. + let left = Self { + slice: left, + tail: left_tail, + ..self + }; + + // Create the right split following the separator. + let right = Self { + slice: right, + tail: right_tail, + ..self + }; + + (left, Some(right)) + } else { + // The search is exhausted, no more separators... + (Self { tail: 0, ..self }, None) + } + } + + fn fold_with(self, mut folder: F) -> F + where + F: Folder, + { + let Self { + slice, pred, tail, .. + } = self; + + let (slice, tail) = if tail == slice.as_ref().len() { + // No tail section, so just let `consume_iter` do it all. + (Some(slice), None) + } else if let Some(index) = slice.rfind(pred, tail) { + // We found the last separator to complete the tail, so + // end with that slice after `consume_iter` finds the rest. + let (left, right) = slice.split(index); + (Some(left), Some(right)) + } else { + // We know there are no separators at all, so it's all "tail". + (None, Some(slice)) + }; + + if let Some(mut slice) = slice { + // TODO (MSRV 1.77) use either: + // folder.consume_iter(slice.chunk_by(pred)) + // folder.consume_iter(slice.chunk_by_mut(pred)) + + folder = folder.consume_iter(std::iter::from_fn(move || { + let len = slice.as_ref().len(); + if len > 0 { + let i = slice.find(pred, 0, len).unwrap_or(len); + let (head, tail) = mem::take(&mut slice).split(i); + slice = tail; + Some(head) + } else { + None + } + })); + } + + if let Some(tail) = tail { + folder = folder.consume(tail); + } + + folder + } +} + +/// Parallel iterator over slice in (non-overlapping) chunks separated by a predicate. +/// +/// This struct is created by the [`par_chunk_by`] method on `&[T]`. +/// +/// [`par_chunk_by`]: trait.ParallelSlice.html#method.par_chunk_by +pub struct ChunkBy<'data, T, P> { + pred: P, + slice: &'data [T], +} + +impl<'data, T, P: Clone> Clone for ChunkBy<'data, T, P> { + fn clone(&self) -> Self { + ChunkBy { + pred: self.pred.clone(), + slice: self.slice, + } + } +} + +impl<'data, T: fmt::Debug, P> fmt::Debug for ChunkBy<'data, T, P> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ChunkBy") + .field("slice", &self.slice) + .finish() + } +} + +impl<'data, T, P> ChunkBy<'data, T, P> { + pub(super) fn new(slice: &'data [T], pred: P) -> Self { + Self { pred, slice } + } +} + +impl<'data, T, P> ParallelIterator for ChunkBy<'data, T, P> +where + T: Sync, + P: Fn(&T, &T) -> bool + Send + Sync, +{ + type Item = &'data [T]; + + fn drive_unindexed(self, consumer: C) -> C::Result + where + C: UnindexedConsumer, + { + bridge_unindexed( + ChunkByProducer { + tail: self.slice.len(), + slice: self.slice, + pred: &self.pred, + marker: PhantomData, + }, + consumer, + ) + } +} + +/// Parallel iterator over slice in (non-overlapping) mutable chunks +/// separated by a predicate. +/// +/// This struct is created by the [`par_chunk_by_mut`] method on `&mut [T]`. +/// +/// [`par_chunk_by_mut`]: trait.ParallelSliceMut.html#method.par_chunk_by_mut +pub struct ChunkByMut<'data, T, P> { + pred: P, + slice: &'data mut [T], +} + +impl<'data, T: fmt::Debug, P> fmt::Debug for ChunkByMut<'data, T, P> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ChunkByMut") + .field("slice", &self.slice) + .finish() + } +} + +impl<'data, T, P> ChunkByMut<'data, T, P> { + pub(super) fn new(slice: &'data mut [T], pred: P) -> Self { + Self { pred, slice } + } +} + +impl<'data, T, P> ParallelIterator for ChunkByMut<'data, T, P> +where + T: Send, + P: Fn(&T, &T) -> bool + Send + Sync, +{ + type Item = &'data mut [T]; + + fn drive_unindexed(self, consumer: C) -> C::Result + where + C: UnindexedConsumer, + { + bridge_unindexed( + ChunkByProducer { + tail: self.slice.len(), + slice: self.slice, + pred: &self.pred, + marker: PhantomData, + }, + consumer, + ) + } +} diff --git a/src/slice/mod.rs b/src/slice/mod.rs index 9e8dcc906..171675e53 100644 --- a/src/slice/mod.rs +++ b/src/slice/mod.rs @@ -5,6 +5,7 @@ //! //! [std::slice]: https://doc.rust-lang.org/stable/std/slice/ +mod chunk_by; mod chunks; mod mergesort; mod quicksort; @@ -22,6 +23,7 @@ use std::cmp::Ordering; use std::fmt::{self, Debug}; use std::mem; +pub use self::chunk_by::{ChunkBy, ChunkByMut}; pub use self::chunks::{Chunks, ChunksExact, ChunksExactMut, ChunksMut}; pub use self::rchunks::{RChunks, RChunksExact, RChunksExactMut, RChunksMut}; @@ -173,6 +175,29 @@ pub trait ParallelSlice { assert!(chunk_size != 0, "chunk_size must not be zero"); RChunksExact::new(chunk_size, self.as_parallel_slice()) } + + /// Returns a parallel iterator over the slice producing non-overlapping runs + /// of elements using the predicate to separate them. + /// + /// The predicate is called on two elements following themselves, + /// it means the predicate is called on `slice[0]` and `slice[1]` + /// then on `slice[1]` and `slice[2]` and so on. + /// + /// # Examples + /// + /// ``` + /// use rayon::prelude::*; + /// let chunks: Vec<_> = [1, 2, 2, 3, 3, 3].par_chunk_by(|&x, &y| x == y).collect(); + /// assert_eq!(chunks[0], &[1]); + /// assert_eq!(chunks[1], &[2, 2]); + /// assert_eq!(chunks[2], &[3, 3, 3]); + /// ``` + fn par_chunk_by(&self, pred: F) -> ChunkBy<'_, T, F> + where + F: Fn(&T, &T) -> bool + Send + Sync, + { + ChunkBy::new(self.as_parallel_slice(), pred) + } } impl ParallelSlice for [T] { @@ -704,6 +729,30 @@ pub trait ParallelSliceMut { { par_quicksort(self.as_parallel_slice_mut(), |a, b| f(a).lt(&f(b))); } + + /// Returns a parallel iterator over the slice producing non-overlapping mutable + /// runs of elements using the predicate to separate them. + /// + /// The predicate is called on two elements following themselves, + /// it means the predicate is called on `slice[0]` and `slice[1]` + /// then on `slice[1]` and `slice[2]` and so on. + /// + /// # Examples + /// + /// ``` + /// use rayon::prelude::*; + /// let mut xs = [1, 2, 2, 3, 3, 3]; + /// let chunks: Vec<_> = xs.par_chunk_by_mut(|&x, &y| x == y).collect(); + /// assert_eq!(chunks[0], &mut [1]); + /// assert_eq!(chunks[1], &mut [2, 2]); + /// assert_eq!(chunks[2], &mut [3, 3, 3]); + /// ``` + fn par_chunk_by_mut(&mut self, pred: F) -> ChunkByMut<'_, T, F> + where + F: Fn(&T, &T) -> bool + Send + Sync, + { + ChunkByMut::new(self.as_parallel_slice_mut(), pred) + } } impl ParallelSliceMut for [T] { diff --git a/src/slice/test.rs b/src/slice/test.rs index f74ca0f74..2538a86b9 100644 --- a/src/slice/test.rs +++ b/src/slice/test.rs @@ -5,6 +5,7 @@ use rand::distributions::Uniform; use rand::seq::SliceRandom; use rand::{thread_rng, Rng}; use std::cmp::Ordering::{Equal, Greater, Less}; +use std::sync::atomic::{AtomicUsize, Ordering::Relaxed}; macro_rules! sort { ($f:ident, $name:ident) => { @@ -168,3 +169,48 @@ fn test_par_rchunks_exact_mut_remainder() { assert_eq!(c.take_remainder(), &[]); assert_eq!(c.len(), 2); } + +#[test] +fn slice_chunk_by() { + let v: Vec<_> = (0..1000).collect(); + assert_eq!(v[..0].par_chunk_by(|_, _| todo!()).count(), 0); + assert_eq!(v[..1].par_chunk_by(|_, _| todo!()).count(), 1); + assert_eq!(v[..2].par_chunk_by(|_, _| true).count(), 1); + assert_eq!(v[..2].par_chunk_by(|_, _| false).count(), 2); + + let count = AtomicUsize::new(0); + let par: Vec<_> = v + .par_chunk_by(|x, y| { + count.fetch_add(1, Relaxed); + (x % 10 < 3) == (y % 10 < 3) + }) + .collect(); + assert_eq!(count.into_inner(), v.len() - 1); + + let seq: Vec<_> = v.chunk_by(|x, y| (x % 10 < 3) == (y % 10 < 3)).collect(); + assert_eq!(par, seq); +} + +#[test] +fn slice_chunk_by_mut() { + let mut v: Vec<_> = (0..1000).collect(); + assert_eq!(v[..0].par_chunk_by_mut(|_, _| todo!()).count(), 0); + assert_eq!(v[..1].par_chunk_by_mut(|_, _| todo!()).count(), 1); + assert_eq!(v[..2].par_chunk_by_mut(|_, _| true).count(), 1); + assert_eq!(v[..2].par_chunk_by_mut(|_, _| false).count(), 2); + + let mut v2 = v.clone(); + let count = AtomicUsize::new(0); + let par: Vec<_> = v + .par_chunk_by_mut(|x, y| { + count.fetch_add(1, Relaxed); + (x % 10 < 3) == (y % 10 < 3) + }) + .collect(); + assert_eq!(count.into_inner(), v2.len() - 1); + + let seq: Vec<_> = v2 + .chunk_by_mut(|x, y| (x % 10 < 3) == (y % 10 < 3)) + .collect(); + assert_eq!(par, seq); +} diff --git a/tests/clones.rs b/tests/clones.rs index 1306147f5..9ffa1d131 100644 --- a/tests/clones.rs +++ b/tests/clones.rs @@ -109,6 +109,7 @@ fn clone_str() { fn clone_vec() { let v: Vec<_> = (0..1000).collect(); check(v.par_iter()); + check(v.par_chunk_by(i32::eq)); check(v.par_chunks(42)); check(v.par_chunks_exact(42)); check(v.par_rchunks(42)); diff --git a/tests/debug.rs b/tests/debug.rs index bf16a2fdd..97d89cd6a 100644 --- a/tests/debug.rs +++ b/tests/debug.rs @@ -121,6 +121,8 @@ fn debug_vec() { let mut v: Vec<_> = (0..10).collect(); check(v.par_iter()); check(v.par_iter_mut()); + check(v.par_chunk_by(i32::eq)); + check(v.par_chunk_by_mut(i32::eq)); check(v.par_chunks(42)); check(v.par_chunks_exact(42)); check(v.par_chunks_mut(42));