diff --git a/src/linear_group/linear_group_by.rs b/src/linear_group/linear_group_by.rs index 5ebff64..4194320 100644 --- a/src/linear_group/linear_group_by.rs +++ b/src/linear_group/linear_group_by.rs @@ -1,5 +1,21 @@ use std::iter::FusedIterator; -use std::{mem, fmt}; +use std::{mem, fmt, slice}; + +unsafe fn split_at_unchecked(slice: &[T], mid: usize) -> (&[T], &[T]) { + (slice.get_unchecked(..mid), slice.get_unchecked(mid..)) +} + +unsafe fn split_at_mut_unchecked(slice: &mut [T], mid: usize) -> (&mut [T], &mut [T]) { + // split_at_mut_unchecked + let len = slice.len(); + let ptr = slice.as_mut_ptr(); + + // SAFETY: Caller has to check that `0 <= mid <= slice.len()`. + // + // `[ptr; mid]` and `[mid; len]` are not overlapping, so returning a mutable reference + // is fine. + (slice::from_raw_parts_mut(ptr, mid), slice::from_raw_parts_mut(ptr.add(mid), len - mid)) +} pub struct LinearGroupBy<'a, T: 'a, P> { slice: &'a [T], @@ -28,7 +44,7 @@ where while let Some([l, r]) = iter.next() { if (self.predicate)(l, r) { len += 1 } else { break } } - let (head, tail) = self.slice.split_at(len); + let (head, tail) = unsafe { split_at_unchecked(self.slice, len) }; self.slice = tail; Some(head) } @@ -59,7 +75,8 @@ where while let Some([l, r]) = iter.next_back() { if (self.predicate)(l, r) { len += 1 } else { break } } - let (head, tail) = self.slice.split_at(self.slice.len() - len); + // let (head, tail) = self.slice.split_at(self.slice.len() - len); + let (head, tail) = unsafe { split_at_unchecked(self.slice, self.slice.len() - len) }; self.slice = head; Some(tail) } @@ -102,7 +119,7 @@ where if (self.predicate)(l, r) { len += 1 } else { break } } let slice = mem::take(&mut self.slice); - let (head, tail) = slice.split_at_mut(len); + let (head, tail) = unsafe { split_at_mut_unchecked(slice, len) }; self.slice = tail; Some(head) } @@ -134,7 +151,7 @@ where if (self.predicate)(l, r) { len += 1 } else { break } } let slice = mem::take(&mut self.slice); - let (head, tail) = slice.split_at_mut(slice.len() - len); + let (head, tail) = unsafe { split_at_mut_unchecked(slice, slice.len() - len) }; self.slice = head; Some(tail) }