Redo the swap code for better tail & padding handling
scottmcm committed Dec 31, 2024
1 parent 4e5fec2 commit 932f981
Showing 12 changed files with 443 additions and 135 deletions.
17 changes: 17 additions & 0 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
@@ -498,6 +498,23 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
}
}

sym::untyped_swap_nonoverlapping => {
// The fallback impl uses memcpy, which leaves around allocas
// that don't optimize out for certain widths, so force it to
// use SSA registers instead.

let chunk_ty = fn_args.type_at(0);
let layout = self.layout_of(chunk_ty).layout;
let integer_ty = self.type_ix(layout.size().bits());
let a = args[0].immediate();
let b = args[1].immediate();
let a_val = self.load(integer_ty, a, layout.align().abi);
let b_val = self.load(integer_ty, b, layout.align().abi);
self.store(b_val, a, layout.align().abi);
self.store(a_val, b, layout.align().abi);
return Ok(());
}

sym::compare_bytes => {
// Here we assume that the `memcmp` provided by the target is a NOP for size 0.
let cmp = self.call_intrinsic("memcmp", &[
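For illustration, a hedged Rust sketch of what the `untyped_swap_nonoverlapping` lowering above amounts to for a 4-byte chunk (`swap_as_int` is a hypothetical stand-in, not compiler code): both chunks are loaded into plain SSA values of a chunk-width integer type and stored back swapped, so no alloca or memcpy is involved.

unsafe fn swap_as_int(a: *mut [u8; 4], b: *mut [u8; 4]) {
    // SAFETY (assumed): `a` and `b` are valid for reads and writes,
    // non-overlapping, and 4-byte aligned, mirroring the alignment the
    // intrinsic derives from the chunk type's layout.
    unsafe {
        let a_val = a.cast::<u32>().read();
        let b_val = b.cast::<u32>().read();
        a.cast::<u32>().write(b_val);
        b.cast::<u32>().write(a_val);
    }
}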
6 changes: 6 additions & 0 deletions compiler/rustc_hir_analysis/src/check/intrinsic.rs
@@ -504,6 +504,12 @@ pub fn check_intrinsic_type(
sym::typed_swap_nonoverlapping => {
(1, 0, vec![Ty::new_mut_ptr(tcx, param(0)); 2], tcx.types.unit)
}
sym::untyped_swap_nonoverlapping => (
1,
0,
vec![Ty::new_mut_ptr(tcx, Ty::new_maybe_uninit(tcx, param(0))); 2],
tcx.types.unit,
),

sym::discriminant_value => {
let assoc_items = tcx.associated_item_def_ids(
1 change: 1 addition & 0 deletions compiler/rustc_span/src/symbol.rs
@@ -2142,6 +2142,7 @@ symbols! {
unstable location; did you mean to load this crate \
from crates.io via `Cargo.toml` instead?",
untagged_unions,
untyped_swap_nonoverlapping,
unused_imports,
unwind,
unwind_attributes,
34 changes: 32 additions & 2 deletions library/core/src/intrinsics/mod.rs
@@ -66,7 +66,7 @@

use crate::marker::{DiscriminantKind, Tuple};
use crate::mem::SizedTypeProperties;
use crate::{ptr, ub_checks};
use crate::{mem, ptr, ub_checks};

pub mod fallback;
pub mod mir;
@@ -4003,7 +4003,37 @@ pub use typed_swap as typed_swap_nonoverlapping;
pub const unsafe fn typed_swap_nonoverlapping<T>(x: *mut T, y: *mut T) {
// SAFETY: The caller provided single non-overlapping items behind
// pointers, so swapping them with `count: 1` is fine.
unsafe { ptr::swap_nonoverlapping(x, y, 1) };
unsafe { crate::swapping::swap_nonoverlapping(x, y, 1) };
}
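For context, `mem::swap` bottoms out in `typed_swap_nonoverlapping`, so the rerouting above is what an ordinary swap now compiles down to; a minimal usage sketch:

fn main() {
    let mut a = [1u8, 2, 3];
    let mut b = [4u8, 5, 6];
    // Swapping two locals reaches typed_swap_nonoverlapping, which in
    // turn now calls crate::swapping::swap_nonoverlapping.
    core::mem::swap(&mut a, &mut b);
    assert_eq!((a, b), ([4, 5, 6], [1, 2, 3]));
}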

/// Swaps the `size_of::<C>()` untyped, non-overlapping bytes behind the two pointers.
///
/// Split out from `typed_swap` so that the internal swaps in
/// `swap_nonoverlapping` don't cause cycles between the two fallback
/// implementations on backends where neither intrinsic is overridden.
///
/// # Safety
///
/// `x` and `y` are readable and writable as `MaybeUninit<C>` and non-overlapping.
#[inline]
#[rustc_nounwind]
#[cfg_attr(not(bootstrap), rustc_intrinsic)]
#[miri::intrinsic_fallback_is_spec]
#[rustc_const_stable_indirect]
pub const unsafe fn untyped_swap_nonoverlapping<C>(
x: *mut mem::MaybeUninit<C>,
y: *mut mem::MaybeUninit<C>,
) {
// This intentionally uses untyped memory copies, not reads/writes,
// to avoid any risk of losing padding in things like (u16, u8).
let mut temp = mem::MaybeUninit::<C>::uninit();
// SAFETY: The caller promised that `x` and `y` are non-overlapping, readable,
// and writable, and the fresh local `temp` is disjoint from both.
unsafe {
(&raw mut temp).copy_from_nonoverlapping(x, 1);
x.copy_from_nonoverlapping(y, 1);
y.copy_from_nonoverlapping(&raw const temp, 1);
}
}
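To see why `untyped_swap_nonoverlapping` copies as `MaybeUninit` rather than by typed value, consider a padded type; a hedged sketch (`Padded` and `untyped_swap_sketch` are illustrative names, not from this commit):

use core::mem::MaybeUninit;
use core::ptr;

#[repr(C)]
struct Padded(u16, u8); // size 4, align 2: one trailing padding byte

// A typed read like `let v = *x;` yields a value whose padding byte is
// uninitialized, so writing it back elsewhere need not reproduce the
// original byte at offset 3. Untyped copies move all 4 bytes verbatim.
unsafe fn untyped_swap_sketch(x: *mut Padded, y: *mut Padded) {
    // SAFETY (assumed): `x` and `y` are valid, aligned, and non-overlapping.
    unsafe {
        let mut tmp = MaybeUninit::<Padded>::uninit();
        ptr::copy_nonoverlapping(x, tmp.as_mut_ptr(), 1); // memcpy semantics
        ptr::copy_nonoverlapping(y, x, 1);
        ptr::copy_nonoverlapping(tmp.as_ptr(), y, 1);
    }
}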

/// Returns whether we should perform some UB-checking at runtime. This eventually evaluates to
1 change: 1 addition & 0 deletions library/core/src/lib.rs
@@ -376,6 +376,7 @@ pub mod alloc;
// note: does not need to be public
mod bool;
mod escape;
pub(crate) mod swapping;
mod tuple;
mod unit;

81 changes: 2 additions & 79 deletions library/core/src/ptr/mod.rs
@@ -395,7 +395,6 @@
#![allow(clippy::not_unsafe_ptr_arg_deref)]

use crate::cmp::Ordering;
use crate::intrinsics::const_eval_select;
use crate::marker::FnPtr;
use crate::mem::{self, MaybeUninit, SizedTypeProperties};
use crate::{fmt, hash, intrinsics, ub_checks};
@@ -1092,84 +1091,8 @@ pub const unsafe fn swap_nonoverlapping<T>(x: *mut T, y: *mut T, count: usize) {
}
);

const_eval_select!(
@capture[T] { x: *mut T, y: *mut T, count: usize }:
if const {
// At compile-time we want to always copy this in chunks of `T`, to ensure that if there
// are pointers inside `T` we will copy them in one go rather than trying to copy a part
// of a pointer (which would not work).
// SAFETY: Same preconditions as this function
unsafe { swap_nonoverlapping_simple_untyped(x, y, count) }
} else {
macro_rules! attempt_swap_as_chunks {
($ChunkTy:ty) => {
if mem::align_of::<T>() >= mem::align_of::<$ChunkTy>()
&& mem::size_of::<T>() % mem::size_of::<$ChunkTy>() == 0
{
let x: *mut $ChunkTy = x.cast();
let y: *mut $ChunkTy = y.cast();
let count = count * (mem::size_of::<T>() / mem::size_of::<$ChunkTy>());
// SAFETY: these are the same bytes that the caller promised were
// ok, just typed as `MaybeUninit<ChunkTy>`s instead of as `T`s.
// The `if` condition above ensures that we're not violating
// alignment requirements, and that the division is exact so
// that we don't lose any bytes off the end.
return unsafe { swap_nonoverlapping_simple_untyped(x, y, count) };
}
};
}

// Split up the slice into small power-of-two-sized chunks that LLVM is able
// to vectorize (unless it's a special type with more-than-pointer alignment,
// because we don't want to pessimize things like slices of SIMD vectors.)
if mem::align_of::<T>() <= mem::size_of::<usize>()
&& (!mem::size_of::<T>().is_power_of_two()
|| mem::size_of::<T>() > mem::size_of::<usize>() * 2)
{
attempt_swap_as_chunks!(usize);
attempt_swap_as_chunks!(u8);
}

// SAFETY: Same preconditions as this function
unsafe { swap_nonoverlapping_simple_untyped(x, y, count) }
}
)
}

/// Same behavior and safety conditions as [`swap_nonoverlapping`]
///
/// LLVM can vectorize this (at least it can for the power-of-two-sized types
/// `swap_nonoverlapping` tries to use) so no need to manually SIMD it.
#[inline]
const unsafe fn swap_nonoverlapping_simple_untyped<T>(x: *mut T, y: *mut T, count: usize) {
let x = x.cast::<MaybeUninit<T>>();
let y = y.cast::<MaybeUninit<T>>();
let mut i = 0;
while i < count {
// SAFETY: By precondition, `i` is in-bounds because it's below `n`
let x = unsafe { x.add(i) };
// SAFETY: By precondition, `i` is in-bounds because it's below `n`
// and it's distinct from `x` since the ranges are non-overlapping
let y = unsafe { y.add(i) };

// If we end up here, it's because we're using a simple type -- like
// a small power-of-two-sized thing -- or a special type with particularly
// large alignment, particularly SIMD types.
// Thus, we're fine just reading-and-writing it, as either it's small
// and that works well anyway or it's special and the type's author
// presumably wanted things to be done in the larger chunk.

// SAFETY: we're only ever given pointers that are valid to read/write,
// including being aligned, and nothing here panics so it's drop-safe.
unsafe {
let a: MaybeUninit<T> = read(x);
let b: MaybeUninit<T> = read(y);
write(x, b);
write(y, a);
}

i += 1;
}
// SAFETY: Same preconditions as this function
unsafe { crate::swapping::swap_nonoverlapping(x, y, count) }
}

/// Moves `src` into the pointed `dst`, returning the previous `dst` value.
182 changes: 182 additions & 0 deletions library/core/src/swapping.rs
@@ -0,0 +1,182 @@
use crate::{hint, intrinsics, mem, ptr};

//#[rustc_const_stable_indirect]
//#[rustc_allow_const_fn_unstable(const_eval_select)]
#[rustc_const_unstable(feature = "const_swap_nonoverlapping", issue = "133668")]
#[inline]
pub(crate) const unsafe fn swap_nonoverlapping<T>(x: *mut T, y: *mut T, count: usize) {
intrinsics::const_eval_select!(
@capture[T] { x: *mut T, y: *mut T, count: usize }:
if const {
// At compile-time we want to always copy this in chunks of `T`, to ensure that if there
// are pointers inside `T` we will copy them in one go rather than trying to copy a part
// of a pointer (which would not work).
// SAFETY: Same preconditions as this function
unsafe { swap_nonoverlapping_const(x, y, count) }
} else {
// At runtime we want to make sure not to swap byte-for-byte for types like [u8; 15],
// and swapping as `MaybeUninit<T>` doesn't actually work as untyped for things like
// T = (u16, u8), so we type-erase to raw bytes and swap that way.
// SAFETY: Same preconditions as this function
unsafe { swap_nonoverlapping_runtime(x, y, count) }
}
)
}
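The const arm has to move whole `T`s because a value carrying provenance, such as a `&i32`, cannot be reassembled from separately copied bytes during compile-time evaluation. A hedged sketch of the const path in use (assumes the unstable `const_swap_nonoverlapping` feature tracked in issue 133668):

// With `#![feature(const_swap_nonoverlapping)]` at the crate root:
const SWAPPED: (u8, u8) = {
    let mut a = 1u8;
    let mut b = 2u8;
    // SAFETY: distinct locals are non-overlapping and valid for one element.
    unsafe { core::ptr::swap_nonoverlapping(&mut a, &mut b, 1) };
    (a, b)
};
// SWAPPED == (2, 1), evaluated through swap_nonoverlapping_const.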

/// Same behavior and safety conditions as [`swap_nonoverlapping`]
#[rustc_const_stable_indirect]
#[inline]
const unsafe fn swap_nonoverlapping_const<T>(x: *mut T, y: *mut T, count: usize) {
let x = x.cast::<mem::MaybeUninit<T>>();
let y = y.cast::<mem::MaybeUninit<T>>();
let mut i = 0;
while i < count {
// SAFETY: By precondition, `i` is in-bounds because it's below `count`,
// and because the two input ranges are non-overlapping and readable/writable,
// these individual items inside them are too.
unsafe {
intrinsics::untyped_swap_nonoverlapping::<T>(x.add(i), y.add(i));
}

i += 1;
}
}

// Scale the monomorphizations with the size of the machine, roughly.
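// (e.g. `8usize.pow(2) == 64` on typical 64-bit targets, where `usize`
// has size and alignment 8; `4usize.pow(2) == 16` on 32-bit ones)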
const MAX_ALIGN: usize = align_of::<usize>().pow(2);

/// Same behavior and safety conditions as [`swap_nonoverlapping`]
#[inline]
unsafe fn swap_nonoverlapping_runtime<T>(x: *mut T, y: *mut T, count: usize) {
let bytes = {
let slice = ptr::slice_from_raw_parts(x, count);
// SAFETY: Because they both exist in memory and don't overlap, they
// must be legal slice sizes (below `isize::MAX` bytes).
unsafe { mem::size_of_val_raw(slice) }
};

// Generating *untyped* loops for every type is silly, so we polymorphize away
// the actual type, but we want to take advantage of alignment if possible,
// so monomorphize for a restricted set of possible alignments.
macro_rules! delegate_by_alignment {
($($p:pat => $align:expr,)+) => {{
#![allow(unreachable_patterns)]
match const { align_of::<T>() } {
$(
$p => {
swap_nonoverlapping_bytes::<$align>(x.cast(), y.cast(), bytes);
}
)+
}
}};
}

// SAFETY: The match arms only delegate with `ALIGN <= align_of::<T>()`,
// so `x` and `y` are sufficiently aligned, and `bytes` (a multiple of
// `size_of::<T>()`, which is itself a multiple of the alignment) is a
// multiple of `ALIGN`. Readability, writability, and non-overlap come
// from the caller's preconditions.
unsafe {
delegate_by_alignment! {
MAX_ALIGN.. => MAX_ALIGN,
64.. => 64,
32.. => 32,
16.. => 16,
8.. => 8,
4.. => 4,
2.. => 2,
_ => 1,
}
}
}
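A hedged illustration of what the dispatch selects (the types are examples, not from the commit): `[u8; 15]` has alignment 1, so it falls through to the `_` arm and swaps with `ALIGN = 1`, while `u64` (alignment 8) takes the `8..` arm. Over-aligned types saturate at `MAX_ALIGN`:

#[repr(align(128))]
struct Big([u8; 128]); // alignment 128 >= MAX_ALIGN, so the `MAX_ALIGN..` arm is taken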

/// # Safety
/// - `x` and `y` must be aligned to `ALIGN`
/// - `bytes` must be a multiple of `ALIGN`
/// - They must be readable, writable, and non-overlapping for `bytes` bytes
#[inline]
unsafe fn swap_nonoverlapping_bytes<const ALIGN: usize>(
x: *mut mem::MaybeUninit<u8>,
y: *mut mem::MaybeUninit<u8>,
bytes: usize,
) {
// SAFETY: Two legal non-overlapping regions can't be bigger than this.
// (And they couldn't have made allocations any bigger either anyway.)
// FIXME: Would be nice to have a type for this instead of the assume.
unsafe { hint::assert_unchecked(bytes < isize::MAX as usize) };

let mut i = 0;
macro_rules! swap_next_n {
($n:expr) => {{
let x: *mut mem::MaybeUninit<[u8; $n]> = x.add(i).cast();
let y: *mut mem::MaybeUninit<[u8; $n]> = y.add(i).cast();
swap_nonoverlapping_aligned_chunk::<ALIGN, [u8; $n]>(
x.as_mut_unchecked(),
y.as_mut_unchecked(),
);
i += $n;
}};
}

while bytes - i >= MAX_ALIGN {
const { assert!(MAX_ALIGN >= ALIGN) };
// SAFETY: the const-assert above confirms we're only ever called with
// an alignment equal to or smaller than max align, so this is necessarily
// aligned, and the while loop ensures there's enough read/write memory.
unsafe {
swap_next_n!(MAX_ALIGN);
}
}

macro_rules! handle_tail {
($($n:literal)+) => {$(
if const { $n % ALIGN == 0 } {
// Checking this way simplifies the block end to just add+test,
// rather than needing extra math before the check.
if (bytes & $n) != 0 {
// SAFETY: The above swaps were bigger, so could not have
// impacted the `$n`-relevant bit, so checking `bytes & $n`
// was equivalent to `bytes - i >= $n`, and thus we have
// enough space left to swap another `$n` bytes.
unsafe {
swap_next_n!($n);
}
}
}
)+};
}
const { assert!(MAX_ALIGN <= 64) };
handle_tail!(32 16 8 4 2 1);
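
// Worked example (assuming a 64-bit target, so `MAX_ALIGN == 64`), with
// `ALIGN == 2` and `bytes == 150`: the main loop swaps two 64-byte chunks
// (`i == 128`, remainder 22), then `150 == 0b1001_0110` fires the 16-, 4-,
// and 2-byte branches, ending at `i == 128 + 16 + 4 + 2 == 150 == bytes`.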

debug_assert_eq!(i, bytes);
}

/// Swaps the `C` behind `x` and `y` as untyped memory
///
/// # Safety
///
/// Both `x` and `y` must be aligned to `ALIGN`, in addition to their normal alignment.
/// They must be readable and writeable for `sizeof(C)` bytes, as usual for `&mut`s.
///
/// (The actual instantiations are usually `C = [u8; _]`, so we get the alignment
/// information from the loads by `assume`ing the passed-in alignment.)
// Don't let MIR inline this, because we really want it to keep its noalias metadata
#[rustc_no_mir_inline]
#[inline]
unsafe fn swap_nonoverlapping_aligned_chunk<const ALIGN: usize, C>(
x: &mut mem::MaybeUninit<C>,
y: &mut mem::MaybeUninit<C>,
) {
assert!(size_of::<C>() % ALIGN == 0);

let x = ptr::from_mut(x);
let y = ptr::from_mut(y);

// SAFETY: One of our preconditions.
unsafe {
hint::assert_unchecked(x.is_aligned_to(ALIGN));
hint::assert_unchecked(y.is_aligned_to(ALIGN));
}

// SAFETY: The memory is readable and writable because these were passed to
// us as mutable references, and the untyped swap doesn't need validity.
unsafe {
intrinsics::untyped_swap_nonoverlapping::<C>(x, y);
}
}
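Finally, a hedged end-to-end sketch of the public entry point this module backs; with `T = u8` it takes the `ALIGN = 1` path, and 5 bytes (`0b101`) exercises the 4- and 1-byte tail branches:

fn main() {
    let mut a = [1u8; 5];
    let mut b = [2u8; 5];
    // SAFETY: `a` and `b` are distinct locals, each valid for 5 elements.
    unsafe { core::ptr::swap_nonoverlapping(a.as_mut_ptr(), b.as_mut_ptr(), 5) };
    assert_eq!(a, [2u8; 5]);
    assert_eq!(b, [1u8; 5]);
}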
