Skip to content

Commit

Permalink
Merge pull request #1120 from cuviper/split_inclusive
Browse files Browse the repository at this point in the history
Add inclusive splits on strings and slices
  • Loading branch information
cuviper authored Jan 27, 2024
2 parents ba8e1a1 + 06b546c commit 2734f99
Show file tree
Hide file tree
Showing 7 changed files with 307 additions and 35 deletions.
39 changes: 39 additions & 0 deletions src/iter/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -983,6 +983,25 @@ fn check_slice_split() {
assert_eq!(v, &[&slice[..1], &slice[..0], &slice[3..]]);
}

/// `par_split_inclusive` must yield exactly the same pieces, in the same
/// order, as the sequential `split_inclusive` from the standard library.
#[test]
fn check_slice_split_inclusive() {
    let numbers: Vec<_> = (0..1000).collect();
    for modulus in 1..100 {
        // One predicate shared by both sides so they cannot drift apart.
        let is_terminator = |x: &i32| x % modulus == 0;
        let sequential: Vec<_> = numbers.split_inclusive(is_terminator).collect();
        let parallel: Vec<_> = numbers.par_split_inclusive(is_terminator).collect();
        assert_eq!(sequential, parallel);
    }

    // same as std::slice::split_inclusive() examples
    let slice = [10, 40, 33, 20];
    let pieces: Vec<_> = slice.par_split_inclusive(|num| num % 3 == 0).collect();
    assert_eq!(pieces, &[&slice[..3], &slice[3..]]);

    // A matching final element terminates the last piece; no empty piece follows.
    let slice = [3, 10, 40, 33];
    let pieces: Vec<_> = slice.par_split_inclusive(|num| num % 3 == 0).collect();
    assert_eq!(pieces, &[&slice[..1], &slice[1..]]);
}

#[test]
fn check_slice_split_mut() {
let mut v1: Vec<_> = (0..1000).collect();
Expand All @@ -1001,6 +1020,26 @@ fn check_slice_split_mut() {
assert_eq!(v, [1, 40, 30, 1, 60, 1]);
}

/// `par_split_inclusive_mut` must yield exactly the same mutable pieces as
/// the sequential `split_inclusive_mut` from the standard library.
#[test]
fn check_slice_split_inclusive_mut() {
    // Two identical buffers, because each iterator needs exclusive access.
    let mut sequential_data: Vec<_> = (0..1000).collect();
    let mut parallel_data = sequential_data.clone();
    for modulus in 1..100 {
        let is_terminator = |x: &i32| x % modulus == 0;
        let sequential: Vec<_> = sequential_data
            .split_inclusive_mut(is_terminator)
            .collect();
        let parallel: Vec<_> = parallel_data
            .par_split_inclusive_mut(is_terminator)
            .collect();
        assert_eq!(sequential, parallel);
    }

    // same as std::slice::split_inclusive_mut() example
    let mut v = [10, 40, 30, 20, 60, 50];
    v.par_split_inclusive_mut(|num| num % 3 == 0)
        .for_each(|group| {
            // Overwrite the final element of every group.
            let terminator = group
                .last_mut()
                .expect("every group in this example has at least one element");
            *terminator = 1;
        });
    assert_eq!(v, [10, 40, 1, 20, 1, 1]);
}

#[test]
fn check_chunks() {
let a: Vec<i32> = vec![1, 5, 10, 4, 100, 3, 1000, 2, 10000, 1];
Expand Down
176 changes: 156 additions & 20 deletions src/slice/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ pub trait ParallelSlice<T: Sync> {
///
/// ```
/// use rayon::prelude::*;
/// let smallest = [1, 2, 3, 0, 2, 4, 8, 0, 3, 6, 9]
/// let products: Vec<_> = [1, 2, 3, 0, 2, 4, 8, 0, 3, 6, 9]
/// .par_split(|i| *i == 0)
/// .map(|numbers| numbers.iter().min().unwrap())
/// .min();
/// assert_eq!(Some(&1), smallest);
/// .map(|numbers| numbers.iter().product::<i32>())
/// .collect();
/// assert_eq!(products, [6, 64, 162]);
/// ```
fn par_split<P>(&self, separator: P) -> Split<'_, T, P>
where
Expand All @@ -54,6 +54,29 @@ pub trait ParallelSlice<T: Sync> {
}
}

/// Returns a parallel iterator over subslices of `self`, where each
/// subslice ends with an element matched by `separator`. The matched
/// element is kept as the terminator of its subslice rather than dropped.
///
/// # Examples
///
/// ```
/// use rayon::prelude::*;
/// let lengths: Vec<_> = [1, 2, 3, 0, 2, 4, 8, 0, 3, 6, 9]
///     .par_split_inclusive(|i| *i == 0)
///     .map(|numbers| numbers.len())
///     .collect();
/// assert_eq!(lengths, [4, 4, 3]);
/// ```
fn par_split_inclusive<P>(&self, separator: P) -> SplitInclusive<'_, T, P>
where
    P: Fn(&T) -> bool + Sync + Send,
{
    // Borrow the underlying data once, then hand it to the lazy iterator.
    let slice = self.as_parallel_slice();
    SplitInclusive { slice, separator }
}

/// Returns a parallel iterator over all contiguous windows of length
/// `window_size`. The windows overlap.
///
Expand Down Expand Up @@ -187,6 +210,28 @@ pub trait ParallelSliceMut<T: Send> {
}
}

/// Returns a parallel iterator over mutable subslices of `self`, where
/// each subslice ends with an element matched by `separator`. The matched
/// element is kept as the terminator of its subslice rather than dropped.
///
/// # Examples
///
/// ```
/// use rayon::prelude::*;
/// let mut array = [1, 2, 3, 0, 2, 4, 8, 0, 3, 6, 9];
/// array.par_split_inclusive_mut(|i| *i == 0)
///      .for_each(|slice| slice.reverse());
/// assert_eq!(array, [0, 3, 2, 1, 0, 8, 4, 2, 9, 6, 3]);
/// ```
fn par_split_inclusive_mut<P>(&mut self, separator: P) -> SplitInclusiveMut<'_, T, P>
where
    P: Fn(&T) -> bool + Sync + Send,
{
    // Borrow the underlying data mutably once, then hand it to the lazy iterator.
    let slice = self.as_parallel_slice_mut();
    SplitInclusiveMut { slice, separator }
}

/// Returns a parallel iterator over at most `chunk_size` elements of
/// `self` at a time. The chunks are mutable and do not overlap.
///
Expand Down Expand Up @@ -932,6 +977,46 @@ where
}
}

/// Parallel iterator over slices separated by a predicate,
/// including the matched part as a terminator.
///
/// Values of this type are built by `par_split_inclusive`, which stores the
/// borrowed slice and the caller's predicate without doing any work up front.
pub struct SplitInclusive<'data, T, P> {
/// The shared slice that will be divided into pieces.
slice: &'data [T],
/// Predicate that marks the terminating element of each piece.
separator: P,
}

/// Cloning duplicates the predicate and re-uses the same borrowed slice;
/// the slice elements themselves are never cloned, so `T: Clone` is not needed.
impl<'data, T, P: Clone> Clone for SplitInclusive<'data, T, P> {
    fn clone(&self) -> Self {
        // A shared slice reference is `Copy`, so it can be read straight out of `self`.
        let slice = self.slice;
        let separator = self.separator.clone();
        SplitInclusive { slice, separator }
    }
}

/// Debug output lists only the `slice` field; the `separator` is left out,
/// so `P` does not have to implement `Debug`.
impl<'data, T: Debug, P> Debug for SplitInclusive<'data, T, P> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("SplitInclusive");
        out.field("slice", &self.slice);
        out.finish()
    }
}

impl<'data, T, P> ParallelIterator for SplitInclusive<'data, T, P>
where
P: Fn(&T) -> bool + Sync + Send,
T: Sync,
{
type Item = &'data [T];

fn drive_unindexed<C>(self, consumer: C) -> C::Result
where
C: UnindexedConsumer<Self::Item>,
{
let producer = SplitInclusiveProducer::new_incl(self.slice, &self.separator);
bridge_unindexed(producer, consumer)
}
}

/// Implement support for `SplitProducer`.
impl<'data, T, P> Fissile<P> for &'data [T]
where
Expand All @@ -953,21 +1038,31 @@ where
self[..end].iter().rposition(separator)
}

fn split_once(self, index: usize) -> (Self, Self) {
let (left, right) = self.split_at(index);
(left, &right[1..]) // skip the separator
fn split_once<const INCL: bool>(self, index: usize) -> (Self, Self) {
if INCL {
// include the separator in the left side
self.split_at(index + 1)
} else {
let (left, right) = self.split_at(index);
(left, &right[1..]) // skip the separator
}
}

fn fold_splits<F>(self, separator: &P, folder: F, skip_last: bool) -> F
fn fold_splits<F, const INCL: bool>(self, separator: &P, folder: F, skip_last: bool) -> F
where
F: Folder<Self>,
Self: Send,
{
let mut split = self.split(separator);
if skip_last {
split.next_back();
if INCL {
debug_assert!(!skip_last);
folder.consume_iter(self.split_inclusive(separator))
} else {
let mut split = self.split(separator);
if skip_last {
split.next_back();
}
folder.consume_iter(split)
}
folder.consume_iter(split)
}
}

Expand Down Expand Up @@ -1001,6 +1096,37 @@ where
}
}

/// Parallel iterator over mutable slices separated by a predicate,
/// including the matched part as a terminator.
///
/// Values of this type are built by `par_split_inclusive_mut`, which stores
/// the mutably borrowed slice and the caller's predicate without doing any
/// work up front.
pub struct SplitInclusiveMut<'data, T, P> {
/// The exclusively borrowed slice that will be divided into pieces.
slice: &'data mut [T],
/// Predicate that marks the terminating element of each piece.
separator: P,
}

/// Debug output lists only the `slice` field; the `separator` is left out,
/// so `P` does not have to implement `Debug`.
impl<'data, T: Debug, P> Debug for SplitInclusiveMut<'data, T, P> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("SplitInclusiveMut");
        out.field("slice", &self.slice);
        out.finish()
    }
}

impl<'data, T, P> ParallelIterator for SplitInclusiveMut<'data, T, P>
where
P: Fn(&T) -> bool + Sync + Send,
T: Send,
{
type Item = &'data mut [T];

fn drive_unindexed<C>(self, consumer: C) -> C::Result
where
C: UnindexedConsumer<Self::Item>,
{
let producer = SplitInclusiveProducer::new_incl(self.slice, &self.separator);
bridge_unindexed(producer, consumer)
}
}

/// Implement support for `SplitProducer`.
impl<'data, T, P> Fissile<P> for &'data mut [T]
where
Expand All @@ -1022,20 +1148,30 @@ where
self[..end].iter().rposition(separator)
}

fn split_once(self, index: usize) -> (Self, Self) {
let (left, right) = self.split_at_mut(index);
(left, &mut right[1..]) // skip the separator
fn split_once<const INCL: bool>(self, index: usize) -> (Self, Self) {
if INCL {
// include the separator in the left side
self.split_at_mut(index + 1)
} else {
let (left, right) = self.split_at_mut(index);
(left, &mut right[1..]) // skip the separator
}
}

fn fold_splits<F>(self, separator: &P, folder: F, skip_last: bool) -> F
fn fold_splits<F, const INCL: bool>(self, separator: &P, folder: F, skip_last: bool) -> F
where
F: Folder<Self>,
Self: Send,
{
let mut split = self.split_mut(separator);
if skip_last {
split.next_back();
if INCL {
debug_assert!(!skip_last);
folder.consume_iter(self.split_inclusive_mut(separator))
} else {
let mut split = self.split_mut(separator);
if skip_last {
split.next_back();
}
folder.consume_iter(split)
}
folder.consume_iter(split)
}
}
36 changes: 28 additions & 8 deletions src/split_producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,24 @@
use crate::iter::plumbing::{Folder, UnindexedProducer};

/// Common producer for splitting on a predicate.
pub(super) struct SplitProducer<'p, P, V> {
pub(super) struct SplitProducer<'p, P, V, const INCL: bool = false> {
data: V,
separator: &'p P,

/// Marks the endpoint beyond which we've already found no separators.
tail: usize,
}

/// `SplitProducer` with its `INCL` const parameter fixed to `true`, used by
/// the inclusive split iterators, whose pieces keep their matched separator.
pub(super) type SplitInclusiveProducer<'p, P, V> = SplitProducer<'p, P, V, true>;

/// Helper trait so `&str`, `&[T]`, and `&mut [T]` can share `SplitProducer`.
pub(super) trait Fissile<P>: Sized {
fn length(&self) -> usize;
fn midpoint(&self, end: usize) -> usize;
fn find(&self, separator: &P, start: usize, end: usize) -> Option<usize>;
fn rfind(&self, separator: &P, end: usize) -> Option<usize>;
fn split_once(self, index: usize) -> (Self, Self);
fn fold_splits<F>(self, separator: &P, folder: F, skip_last: bool) -> F
fn split_once<const INCL: bool>(self, index: usize) -> (Self, Self);
fn fold_splits<F, const INCL: bool>(self, separator: &P, folder: F, skip_last: bool) -> F
where
F: Folder<Self>,
Self: Send;
Expand All @@ -37,7 +39,25 @@ where
separator,
}
}
}

impl<'p, P, V> SplitInclusiveProducer<'p, P, V>
where
    V: Fissile<P> + Send,
{
    /// Builds an inclusive producer over all of `data`, borrowing `separator`.
    ///
    /// `tail` starts at the full length of `data`.
    pub(super) fn new_incl(data: V, separator: &'p P) -> Self {
        // Measure the length before `data` is moved into the struct.
        let tail = data.length();
        Self {
            data,
            separator,
            tail,
        }
    }
}

impl<'p, P, V, const INCL: bool> SplitProducer<'p, P, V, INCL>
where
V: Fissile<P> + Send,
{
/// Common `fold_with` implementation, integrating `SplitTerminator`'s
/// need to sometimes skip its final empty item.
pub(super) fn fold_with<F>(self, folder: F, skip_last: bool) -> F
Expand All @@ -52,12 +72,12 @@ where

if tail == data.length() {
// No tail section, so just let `fold_splits` handle it.
data.fold_splits(separator, folder, skip_last)
data.fold_splits::<F, INCL>(separator, folder, skip_last)
} else if let Some(index) = data.rfind(separator, tail) {
// We found the last separator to complete the tail, so
// end with that slice after `fold_splits` finds the rest.
let (left, right) = data.split_once(index);
let folder = left.fold_splits(separator, folder, false);
let (left, right) = data.split_once::<INCL>(index);
let folder = left.fold_splits::<F, INCL>(separator, folder, false);
if skip_last || folder.full() {
folder
} else {
Expand All @@ -74,7 +94,7 @@ where
}
}

impl<'p, P, V> UnindexedProducer for SplitProducer<'p, P, V>
impl<'p, P, V, const INCL: bool> UnindexedProducer for SplitProducer<'p, P, V, INCL>
where
V: Fissile<P> + Send,
P: Sync,
Expand All @@ -91,7 +111,7 @@ where

if let Some(index) = index {
let len = self.data.length();
let (left, right) = self.data.split_once(index);
let (left, right) = self.data.split_once::<INCL>(index);

let (left_tail, right_tail) = if index < mid {
// If we scanned backwards to find the separator, everything in
Expand Down
Loading

0 comments on commit 2734f99

Please sign in to comment.