Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support ranges in <[T]>::get_many_mut() #133136

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 180 additions & 18 deletions library/core/src/slice/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ use crate::cmp::Ordering::{self, Equal, Greater, Less};
use crate::intrinsics::{exact_div, select_unpredictable, unchecked_sub};
use crate::mem::{self, SizedTypeProperties};
use crate::num::NonZero;
use crate::ops::{Bound, OneSidedRange, Range, RangeBounds};
use crate::ops::{Bound, OneSidedRange, Range, RangeBounds, RangeInclusive};
use crate::simd::{self, Simd};
use crate::ub_checks::assert_unsafe_precondition;
use crate::{fmt, hint, ptr, slice};
use crate::{fmt, hint, ptr, range, slice};

#[unstable(
feature = "slice_internals",
Expand Down Expand Up @@ -4467,6 +4467,12 @@ impl<T> [T] {

/// Returns mutable references to many indices at once, without doing any checks.
///
/// An index can be either a `usize`, a [`Range`] or a [`RangeInclusive`]. Note
/// that this method takes an array, so all indices must be of the same type.
/// If passed an array of `usize`s this method gives back an array of mutable references
/// to single elements, while if passed an array of ranges it gives back an array of
/// mutable references to slices.
///
/// For a safe alternative see [`get_many_mut`].
///
/// # Safety
Expand All @@ -4487,39 +4493,68 @@ impl<T> [T] {
/// *b *= 100;
/// }
/// assert_eq!(x, &[10, 2, 400]);
///
/// unsafe {
/// let [a, b] = x.get_many_unchecked_mut([0..1, 1..3]);
/// a[0] = 8;
/// b[0] = 88;
/// b[1] = 888;
/// }
/// assert_eq!(x, &[8, 88, 888]);
///
/// unsafe {
/// let [a, b] = x.get_many_unchecked_mut([1..=2, 0..=0]);
/// a[0] = 11;
/// a[1] = 111;
/// b[0] = 1;
/// }
/// assert_eq!(x, &[1, 11, 111]);
/// ```
///
/// [`get_many_mut`]: slice::get_many_mut
/// [undefined behavior]: https://doc.rust-lang.org/reference/behavior-considered-undefined.html
#[unstable(feature = "get_many_mut", issue = "104642")]
#[inline]
pub unsafe fn get_many_unchecked_mut<const N: usize>(
pub unsafe fn get_many_unchecked_mut<I, const N: usize>(
&mut self,
indices: [usize; N],
) -> [&mut T; N] {
indices: [I; N],
) -> [&mut I::Output; N]
where
I: GetManyMutIndex + SliceIndex<Self>,
{
// NB: This implementation is written as it is because any variation of
// `indices.map(|i| self.get_unchecked_mut(i))` would make miri unhappy,
// or generate worse code otherwise. This is also why we need to go
// through a raw pointer here.
let slice: *mut [T] = self;
let mut arr: mem::MaybeUninit<[&mut T; N]> = mem::MaybeUninit::uninit();
let mut arr: mem::MaybeUninit<[&mut I::Output; N]> = mem::MaybeUninit::uninit();
let arr_ptr = arr.as_mut_ptr();

// SAFETY: We expect `indices` to contain disjunct values that are
// in bounds of `self`.
unsafe {
for i in 0..N {
let idx = *indices.get_unchecked(i);
*(*arr_ptr).get_unchecked_mut(i) = &mut *slice.get_unchecked_mut(idx);
let idx = indices.get_unchecked(i).clone();
arr_ptr.cast::<&mut I::Output>().add(i).write(&mut *slice.get_unchecked_mut(idx));
}
arr.assume_init()
}
}

/// Returns mutable references to many indices at once.
///
/// Returns an error if any index is out-of-bounds, or if the same index was
/// passed more than once.
/// An index can be either a `usize`, a [`Range`] or a [`RangeInclusive`]. Note
/// that this method takes an array, so all indices must be of the same type.
/// If passed an array of `usize`s this method gives back an array of mutable references
/// to single elements, while if passed an array of ranges it gives back an array of
/// mutable references to slices.
///
/// Returns an error if any index is out-of-bounds, or if there are overlapping indices.
/// An empty range is not considered to overlap if it is located at the beginning or at
/// the end of another range, but is considered to overlap if it is located in the middle.
///
/// This method does a O(n^2) check to check that there are no overlapping indices, so be careful
/// when passing many indices.
///
/// # Examples
///
Expand All @@ -4532,13 +4567,30 @@ impl<T> [T] {
/// *b = 612;
/// }
/// assert_eq!(v, &[413, 2, 612]);
///
/// if let Ok([a, b]) = v.get_many_mut([0..1, 1..3]) {
/// a[0] = 8;
/// b[0] = 88;
/// b[1] = 888;
/// }
/// assert_eq!(v, &[8, 88, 888]);
///
/// if let Ok([a, b]) = v.get_many_mut([1..=2, 0..=0]) {
/// a[0] = 11;
/// a[1] = 111;
/// b[0] = 1;
/// }
/// assert_eq!(v, &[1, 11, 111]);
/// ```
#[unstable(feature = "get_many_mut", issue = "104642")]
#[inline]
pub fn get_many_mut<const N: usize>(
pub fn get_many_mut<I, const N: usize>(
&mut self,
indices: [usize; N],
) -> Result<[&mut T; N], GetManyMutError<N>> {
indices: [I; N],
) -> Result<[&mut I::Output; N], GetManyMutError<N>>
where
I: GetManyMutIndex + SliceIndex<Self>,
{
if !get_many_check_valid(&indices, self.len()) {
return Err(GetManyMutError { _private: () });
}
Expand Down Expand Up @@ -4883,14 +4935,15 @@ impl<T, const N: usize> SlicePattern for [T; N] {
///
/// This will do `binomial(N + 1, 2) = N * (N + 1) / 2 = 0, 1, 3, 6, 10, ..`
/// comparison operations.
fn get_many_check_valid<const N: usize>(indices: &[usize; N], len: usize) -> bool {
#[inline]
fn get_many_check_valid<I: GetManyMutIndex, const N: usize>(indices: &[I; N], len: usize) -> bool {
// NB: The optimizer should inline the loops into a sequence
// of instructions without additional branching.
let mut valid = true;
for (i, &idx) in indices.iter().enumerate() {
valid &= idx < len;
for &idx2 in &indices[..i] {
valid &= idx != idx2;
for (i, idx) in indices.iter().enumerate() {
valid &= idx.is_in_bounds(len);
for idx2 in &indices[..i] {
valid &= !idx.is_overlapping(idx2);
}
}
valid
Expand All @@ -4914,6 +4967,7 @@ fn get_many_check_valid<const N: usize>(indices: &[usize; N], len: usize) -> boo
#[unstable(feature = "get_many_mut", issue = "104642")]
// NB: The N here is there to be forward-compatible with adding more details
// to the error type at a later point
#[derive(Clone, PartialEq, Eq)]
pub struct GetManyMutError<const N: usize> {
_private: (),
}
Expand All @@ -4931,3 +4985,111 @@ impl<const N: usize> fmt::Display for GetManyMutError<N> {
fmt::Display::fmt("an index is out of bounds or appeared multiple times in the array", f)
}
}

mod private_get_many_mut_index {
use super::{Range, RangeInclusive, range};

#[unstable(feature = "get_many_mut_helpers", issue = "none")]
pub trait Sealed {}

#[unstable(feature = "get_many_mut_helpers", issue = "none")]
impl Sealed for usize {}
#[unstable(feature = "get_many_mut_helpers", issue = "none")]
impl Sealed for Range<usize> {}
#[unstable(feature = "get_many_mut_helpers", issue = "none")]
impl Sealed for RangeInclusive<usize> {}
#[unstable(feature = "get_many_mut_helpers", issue = "none")]
impl Sealed for range::Range<usize> {}
#[unstable(feature = "get_many_mut_helpers", issue = "none")]
impl Sealed for range::RangeInclusive<usize> {}
}

/// A helper trait for `<[T]>::get_many_mut()`.
///
/// # Safety
///
/// If `is_in_bounds()` returns `true` and `is_overlapping()` returns `false`,
/// it must be safe to index the slice with the indices.
#[unstable(feature = "get_many_mut_helpers", issue = "none")]
pub unsafe trait GetManyMutIndex: Clone + private_get_many_mut_index::Sealed {
/// Returns `true` if `self` is in bounds for `len` slice elements.
#[unstable(feature = "get_many_mut_helpers", issue = "none")]
fn is_in_bounds(&self, len: usize) -> bool;

/// Returns `true` if `self` overlaps with `other`.
///
/// Note that we don't consider zero-length ranges to overlap at the beginning or the end,
/// but do consider them to overlap in the middle.
#[unstable(feature = "get_many_mut_helpers", issue = "none")]
fn is_overlapping(&self, other: &Self) -> bool;
}

#[unstable(feature = "get_many_mut_helpers", issue = "none")]
// SAFETY: We implement `is_in_bounds()` and `is_overlapping()` correctly.
unsafe impl GetManyMutIndex for usize {
#[inline]
fn is_in_bounds(&self, len: usize) -> bool {
*self < len
}

#[inline]
fn is_overlapping(&self, other: &Self) -> bool {
*self == *other
}
}

#[unstable(feature = "get_many_mut_helpers", issue = "none")]
// SAFETY: We implement `is_in_bounds()` and `is_overlapping()` correctly.
unsafe impl GetManyMutIndex for Range<usize> {
#[inline]
fn is_in_bounds(&self, len: usize) -> bool {
(self.start <= self.end) & (self.end <= len)
}

#[inline]
fn is_overlapping(&self, other: &Self) -> bool {
(self.start < other.end) & (other.start < self.end)
}
}

#[unstable(feature = "get_many_mut_helpers", issue = "none")]
// SAFETY: We implement `is_in_bounds()` and `is_overlapping()` correctly.
unsafe impl GetManyMutIndex for RangeInclusive<usize> {
#[inline]
fn is_in_bounds(&self, len: usize) -> bool {
(self.start <= self.end) & (self.end < len)
}

#[inline]
fn is_overlapping(&self, other: &Self) -> bool {
(self.start <= other.end) & (other.start <= self.end)
}
}

#[unstable(feature = "get_many_mut_helpers", issue = "none")]
// SAFETY: We implement `is_in_bounds()` and `is_overlapping()` correctly.
unsafe impl GetManyMutIndex for range::Range<usize> {
#[inline]
fn is_in_bounds(&self, len: usize) -> bool {
Range::from(*self).is_in_bounds(len)
}

#[inline]
fn is_overlapping(&self, other: &Self) -> bool {
Range::from(*self).is_overlapping(&Range::from(*other))
}
}

#[unstable(feature = "get_many_mut_helpers", issue = "none")]
// SAFETY: We implement `is_in_bounds()` and `is_overlapping()` correctly.
unsafe impl GetManyMutIndex for range::RangeInclusive<usize> {
#[inline]
fn is_in_bounds(&self, len: usize) -> bool {
RangeInclusive::from(*self).is_in_bounds(len)
}

#[inline]
fn is_overlapping(&self, other: &Self) -> bool {
RangeInclusive::from(*self).is_overlapping(&RangeInclusive::from(*other))
}
}
70 changes: 69 additions & 1 deletion library/core/tests/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use core::cell::Cell;
use core::cmp::Ordering;
use core::mem::MaybeUninit;
use core::num::NonZero;
use core::ops::{Range, RangeInclusive};
use core::slice;

#[test]
Expand Down Expand Up @@ -2553,6 +2554,14 @@ fn test_get_many_mut_normal_2() {
*a += 10;
*b += 100;
assert_eq!(v, vec![101, 2, 3, 14, 5]);

let [a, b] = v.get_many_mut([0..=1, 2..=2]).unwrap();
assert_eq!(a, &mut [101, 2][..]);
assert_eq!(b, &mut [3][..]);
a[0] += 10;
a[1] += 20;
b[0] += 100;
assert_eq!(v, vec![111, 22, 103, 14, 5]);
}

#[test]
Expand All @@ -2563,12 +2572,23 @@ fn test_get_many_mut_normal_3() {
*b += 100;
*c += 1000;
assert_eq!(v, vec![11, 2, 1003, 4, 105]);

let [a, b, c] = v.get_many_mut([0..1, 4..5, 1..4]).unwrap();
assert_eq!(a, &mut [11][..]);
assert_eq!(b, &mut [105][..]);
assert_eq!(c, &mut [2, 1003, 4][..]);
a[0] += 10;
b[0] += 100;
c[0] += 1000;
assert_eq!(v, vec![21, 1002, 1003, 4, 205]);
}

#[test]
fn test_get_many_mut_empty() {
let mut v = vec![1, 2, 3, 4, 5];
let [] = v.get_many_mut([]).unwrap();
let [] = v.get_many_mut::<usize, 0>([]).unwrap();
let [] = v.get_many_mut::<RangeInclusive<usize>, 0>([]).unwrap();
let [] = v.get_many_mut::<Range<usize>, 0>([]).unwrap();
assert_eq!(v, vec![1, 2, 3, 4, 5]);
}

Expand Down Expand Up @@ -2606,6 +2626,54 @@ fn test_get_many_mut_duplicate() {
assert!(v.get_many_mut([1, 3, 3, 4]).is_err());
}

#[test]
fn test_get_many_mut_range_oob() {
let mut v = vec![1, 2, 3, 4, 5];
assert!(v.get_many_mut([0..6]).is_err());
assert!(v.get_many_mut([5..6]).is_err());
assert!(v.get_many_mut([6..6]).is_err());
assert!(v.get_many_mut([0..=5]).is_err());
assert!(v.get_many_mut([0..=6]).is_err());
assert!(v.get_many_mut([5..=5]).is_err());
}

#[test]
fn test_get_many_mut_range_overlapping() {
let mut v = vec![1, 2, 3, 4, 5];
assert!(v.get_many_mut([0..1, 0..2]).is_err());
assert!(v.get_many_mut([0..1, 1..2, 0..1]).is_err());
assert!(v.get_many_mut([0..3, 1..1]).is_err());
assert!(v.get_many_mut([0..3, 1..2]).is_err());
assert!(v.get_many_mut([0..=0, 2..=2, 0..=1]).is_err());
assert!(v.get_many_mut([0..=4, 0..=0]).is_err());
assert!(v.get_many_mut([4..=4, 0..=0, 3..=4]).is_err());
}

#[test]
fn test_get_many_mut_range_empty_at_edge() {
let mut v = vec![1, 2, 3, 4, 5];
assert_eq!(
v.get_many_mut([0..0, 0..5, 5..5]),
Ok([&mut [][..], &mut [1, 2, 3, 4, 5], &mut []]),
);
assert_eq!(
v.get_many_mut([0..0, 0..1, 1..1, 1..2, 2..2, 2..3, 3..3, 3..4, 4..4, 4..5, 5..5]),
Ok([
&mut [][..],
&mut [1],
&mut [],
&mut [2],
&mut [],
&mut [3],
&mut [],
&mut [4],
&mut [],
&mut [5],
&mut [],
]),
);
}

#[test]
fn test_slice_from_raw_parts_in_const() {
static FANCY: i32 = 4;
Expand Down
Loading