Skip to content

Commit

Permalink
Stop manually SIMDing in swap_nonoverlapping
Browse files Browse the repository at this point in the history
Like I previously did for `reverse`, this leaves it to LLVM to pick how to vectorize it, since it can know better the chunk size to use, compared to the "32 bytes always" approach we currently have.

It does still need logic to type-erase where appropriate, though, as while LLVM is now smart enough to vectorize over slices of things like `[u8; 4]`, it fails to do so over slices of `[u8; 3]`.

As a bonus, this also means one no longer gets the spurious `memcpy`(s?) at the end up swapping a slice of `__m256`s: <https://rust.godbolt.org/z/joofr4v8Y>
  • Loading branch information
scottmcm committed Feb 21, 2022
1 parent 73a7423 commit 8ca47d7
Show file tree
Hide file tree
Showing 6 changed files with 263 additions and 97 deletions.
43 changes: 39 additions & 4 deletions library/core/benches/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,15 @@ fn binary_search_l3_worst_case(b: &mut Bencher) {
binary_search_worst_case(b, Cache::L3);
}

#[derive(Clone)]
struct Rgb(u8, u8, u8);

impl Rgb {
fn gen(i: usize) -> Self {
Rgb(i as u8, (i as u8).wrapping_add(7), (i as u8).wrapping_add(42))
}
}

macro_rules! rotate {
($fn:ident, $n:expr, $mapper:expr) => {
#[bench]
Expand All @@ -104,17 +113,43 @@ macro_rules! rotate {
};
}

#[derive(Clone)]
struct Rgb(u8, u8, u8);

rotate!(rotate_u8, 32, |i| i as u8);
rotate!(rotate_rgb, 32, |i| Rgb(i as u8, (i as u8).wrapping_add(7), (i as u8).wrapping_add(42)));
rotate!(rotate_rgb, 32, Rgb::gen);
rotate!(rotate_usize, 32, |i| i);
rotate!(rotate_16_usize_4, 16, |i| [i; 4]);
rotate!(rotate_16_usize_5, 16, |i| [i; 5]);
rotate!(rotate_64_usize_4, 64, |i| [i; 4]);
rotate!(rotate_64_usize_5, 64, |i| [i; 5]);

macro_rules! swap_with_slice {
($fn:ident, $n:expr, $mapper:expr) => {
#[bench]
fn $fn(b: &mut Bencher) {
let mut x = (0usize..$n).map(&$mapper).collect::<Vec<_>>();
let mut y = ($n..($n * 2)).map(&$mapper).collect::<Vec<_>>();
let mut skip = 0;
b.iter(|| {
for _ in 0..32 {
x[skip..].swap_with_slice(&mut y[..($n - skip)]);
skip = black_box(skip + 1) % 8;
}
black_box((x[$n / 3].clone(), y[$n * 2 / 3].clone()))
})
}
};
}

swap_with_slice!(swap_with_slice_u8_30, 30, |i| i as u8);
swap_with_slice!(swap_with_slice_u8_3000, 3000, |i| i as u8);
swap_with_slice!(swap_with_slice_rgb_30, 30, Rgb::gen);
swap_with_slice!(swap_with_slice_rgb_3000, 3000, Rgb::gen);
swap_with_slice!(swap_with_slice_usize_30, 30, |i| i);
swap_with_slice!(swap_with_slice_usize_3000, 3000, |i| i);
swap_with_slice!(swap_with_slice_4x_usize_30, 30, |i| [i; 4]);
swap_with_slice!(swap_with_slice_4x_usize_3000, 3000, |i| [i; 4]);
swap_with_slice!(swap_with_slice_5x_usize_30, 30, |i| [i; 5]);
swap_with_slice!(swap_with_slice_5x_usize_3000, 3000, |i| [i; 5]);

#[bench]
fn fill_byte_sized(b: &mut Bencher) {
#[derive(Copy, Clone)]
Expand Down
45 changes: 42 additions & 3 deletions library/core/src/mem/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -700,10 +700,49 @@ pub unsafe fn uninitialized<T>() -> T {
#[stable(feature = "rust1", since = "1.0.0")]
#[rustc_const_unstable(feature = "const_swap", issue = "83163")]
pub const fn swap<T>(x: &mut T, y: &mut T) {
// SAFETY: the raw pointers have been created from safe mutable references satisfying all the
// constraints on `ptr::swap_nonoverlapping_one`
// NOTE(eddyb) SPIR-V's Logical addressing model doesn't allow for arbitrary
// reinterpretation of values as (chunkable) byte arrays, and the loop in the
// block optimization in `swap_slice` is hard to rewrite back
// into the (unoptimized) direct swapping implementation, so we disable it.
// FIXME(eddyb) the block optimization also prevents MIR optimizations from
// understanding `mem::replace`, `Option::take`, etc. - a better overall
// solution might be to make `ptr::swap_nonoverlapping` into an intrinsic, which
// a backend can choose to implement using the block optimization, or not.
#[cfg(not(target_arch = "spirv"))]
{
// For types that are larger multiples of their alignment, the simple way
// tends to copy the whole thing to stack rather than doing it one part
// at a time, so instead treat them as one-element slices and piggy-back
// the slice optimizations that will split up the swaps.
if size_of::<T>() / align_of::<T>() > 4 {
// SAFETY: exclusive references always point to one non-overlapping
// element and are non-null and properly aligned.
return unsafe { ptr::swap_nonoverlapping(x, y, 1) };
}
}

// If a scalar consists of just a small number of alignment units, let
// the codegen just swap those pieces directly, as it's likely just a
// few instructions and anything else is probably overcomplicated.
//
// Most importantly, this covers primitives and simd types that tend to
// have size=align where doing anything else can be a pessimization.
// (This will also be used for ZSTs, though any solution works for them.)
swap_simple(x, y);
}

/// Same as [`swap`] semantically, but always uses the simple implementation.
///
/// Used elsewhere in `mem` and `ptr` at the bottom layer of calls.
#[rustc_const_unstable(feature = "const_swap", issue = "83163")]
#[inline]
pub(crate) const fn swap_simple<T>(x: &mut T, y: &mut T) {
// SAFETY: exclusive references are always valid to read/write,
// are non-overlapping, and nothing here panics so it's drop-safe.
unsafe {
ptr::swap_nonoverlapping_one(x, y);
let z = ptr::read(x);
ptr::copy_nonoverlapping(y, x, 1);
ptr::write(y, z);
}
}

Expand Down
132 changes: 42 additions & 90 deletions library/core/src/ptr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,106 +419,58 @@ pub const unsafe fn swap<T>(x: *mut T, y: *mut T) {
#[stable(feature = "swap_nonoverlapping", since = "1.27.0")]
#[rustc_const_unstable(feature = "const_swap", issue = "83163")]
pub const unsafe fn swap_nonoverlapping<T>(x: *mut T, y: *mut T, count: usize) {
let x = x as *mut u8;
let y = y as *mut u8;
let len = mem::size_of::<T>() * count;
// SAFETY: the caller must guarantee that `x` and `y` are
// valid for writes and properly aligned.
unsafe { swap_nonoverlapping_bytes(x, y, len) }
}
macro_rules! attempt_swap_as_chunks {
($ChunkTy:ty) => {
if mem::align_of::<T>() >= mem::align_of::<$ChunkTy>()
&& mem::size_of::<T>() % mem::size_of::<$ChunkTy>() == 0
{
let x: *mut MaybeUninit<$ChunkTy> = x.cast();
let y: *mut MaybeUninit<$ChunkTy> = y.cast();
let count = count * (mem::size_of::<T>() / mem::size_of::<$ChunkTy>());
// SAFETY: these are the same bytes that the caller promised were
// ok, just typed as `MaybeUninit<ChunkTy>`s instead of as `T`s.
// The `if` condition above ensures that we're not violating
// alignment requirements, and that the division is exact so
// that we don't lose any bytes off the end.
return unsafe { swap_nonoverlapping_simple(x, y, count) };
}
};
}

#[inline]
#[rustc_const_unstable(feature = "const_swap", issue = "83163")]
pub(crate) const unsafe fn swap_nonoverlapping_one<T>(x: *mut T, y: *mut T) {
// NOTE(eddyb) SPIR-V's Logical addressing model doesn't allow for arbitrary
// reinterpretation of values as (chunkable) byte arrays, and the loop in the
// block optimization in `swap_nonoverlapping_bytes` is hard to rewrite back
// into the (unoptimized) direct swapping implementation, so we disable it.
// FIXME(eddyb) the block optimization also prevents MIR optimizations from
// understanding `mem::replace`, `Option::take`, etc. - a better overall
// solution might be to make `swap_nonoverlapping` into an intrinsic, which
// a backend can choose to implement using the block optimization, or not.
#[cfg(not(target_arch = "spirv"))]
// Split up the slice into small power-of-two-sized chunks that LLVM is able
// to vectorize (unless it's a special type with more-than-pointer alignment,
// because we don't want to pessimize things like slices of SIMD vectors.)
if mem::align_of::<T>() <= mem::size_of::<usize>()
&& (!mem::size_of::<T>().is_power_of_two()
|| mem::size_of::<T>() > mem::size_of::<usize>() * 2)
{
// Only apply the block optimization in `swap_nonoverlapping_bytes` for types
// at least as large as the block size, to avoid pessimizing codegen.
if mem::size_of::<T>() >= 32 {
// SAFETY: the caller must uphold the safety contract for `swap_nonoverlapping`.
unsafe { swap_nonoverlapping(x, y, 1) };
return;
}
attempt_swap_as_chunks!(usize);
attempt_swap_as_chunks!(u8);
}

// Direct swapping, for the cases not going through the block optimization.
// SAFETY: the caller must guarantee that `x` and `y` are valid
// for writes, properly aligned, and non-overlapping.
unsafe {
let z = read(x);
copy_nonoverlapping(y, x, 1);
write(y, z);
}
// SAFETY: Same preconditions as this function
unsafe { swap_nonoverlapping_simple(x, y, count) }
}

/// Same behaviour and safety conditions as [`swap_nonoverlapping`]
///
/// LLVM can vectorize this (at least it can for the power-of-two-sized types
/// `swap_nonoverlapping` tries to use) so no need to manually SIMD it.
#[inline]
#[rustc_const_unstable(feature = "const_swap", issue = "83163")]
const unsafe fn swap_nonoverlapping_bytes(x: *mut u8, y: *mut u8, len: usize) {
// The approach here is to utilize simd to swap x & y efficiently. Testing reveals
// that swapping either 32 bytes or 64 bytes at a time is most efficient for Intel
// Haswell E processors. LLVM is more able to optimize if we give a struct a
// #[repr(simd)], even if we don't actually use this struct directly.
//
// FIXME repr(simd) broken on emscripten and redox
#[cfg_attr(not(any(target_os = "emscripten", target_os = "redox")), repr(simd))]
struct Block(u64, u64, u64, u64);
struct UnalignedBlock(u64, u64, u64, u64);

let block_size = mem::size_of::<Block>();

// Loop through x & y, copying them `Block` at a time
// The optimizer should unroll the loop fully for most types
// N.B. We can't use a for loop as the `range` impl calls `mem::swap` recursively
const unsafe fn swap_nonoverlapping_simple<T>(x: *mut T, y: *mut T, count: usize) {
let mut i = 0;
while i + block_size <= len {
// Create some uninitialized memory as scratch space
// Declaring `t` here avoids aligning the stack when this loop is unused
let mut t = mem::MaybeUninit::<Block>::uninit();
let t = t.as_mut_ptr() as *mut u8;

// SAFETY: As `i < len`, and as the caller must guarantee that `x` and `y` are valid
// for `len` bytes, `x + i` and `y + i` must be valid addresses, which fulfills the
// safety contract for `add`.
//
// Also, the caller must guarantee that `x` and `y` are valid for writes, properly aligned,
// and non-overlapping, which fulfills the safety contract for `copy_nonoverlapping`.
unsafe {
let x = x.add(i);
let y = y.add(i);
while i < count {
let x: &mut T =
// SAFETY: By precondition, `i` is in-bounds because it's below `n`
unsafe { &mut *x.add(i) };
let y: &mut T =
// SAFETY: By precondition, `i` is in-bounds because it's below `n`
// and it's distinct from `x` since the ranges are non-overlapping
unsafe { &mut *y.add(i) };
mem::swap_simple(x, y);

// Swap a block of bytes of x & y, using t as a temporary buffer
// This should be optimized into efficient SIMD operations where available
copy_nonoverlapping(x, t, block_size);
copy_nonoverlapping(y, x, block_size);
copy_nonoverlapping(t, y, block_size);
}
i += block_size;
}

if i < len {
// Swap any remaining bytes
let mut t = mem::MaybeUninit::<UnalignedBlock>::uninit();
let rem = len - i;

let t = t.as_mut_ptr() as *mut u8;

// SAFETY: see previous safety comment.
unsafe {
let x = x.add(i);
let y = y.add(i);

copy_nonoverlapping(x, t, rem);
copy_nonoverlapping(y, x, rem);
copy_nonoverlapping(t, y, rem);
}
i += 1;
}
}

Expand Down
64 changes: 64 additions & 0 deletions src/test/codegen/swap-large-types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// compile-flags: -O
// only-x86_64
// ignore-debug: the debug assertions get in the way

#![crate_type = "lib"]

use std::mem::swap;
use std::ptr::{read, copy_nonoverlapping, write};

type KeccakBuffer = [[u64; 5]; 5];

// A basic read+copy+write swap implementation ends up copying one of the values
// to stack for large types, which is completely unnecessary as the lack of
// overlap means we can just do whatever fits in registers at a time.

// CHECK-LABEL: @swap_basic
#[no_mangle]
pub fn swap_basic(x: &mut KeccakBuffer, y: &mut KeccakBuffer) {
// CHECK: alloca [5 x [5 x i64]]

// SAFETY: exclusive references are always valid to read/write,
// are non-overlapping, and nothing here panics so it's drop-safe.
unsafe {
let z = read(x);
copy_nonoverlapping(y, x, 1);
write(y, z);
}
}

// This test verifies that the library does something smarter, and thus
// doesn't need any scratch space on the stack.

// CHECK-LABEL: @swap_std
#[no_mangle]
pub fn swap_std(x: &mut KeccakBuffer, y: &mut KeccakBuffer) {
// CHECK-NOT: alloca
// CHECK: load <{{[0-9]+}} x i64>
// CHECK: store <{{[0-9]+}} x i64>
swap(x, y)
}

// CHECK-LABEL: @swap_slice
#[no_mangle]
pub fn swap_slice(x: &mut [KeccakBuffer], y: &mut [KeccakBuffer]) {
// CHECK-NOT: alloca
// CHECK: load <{{[0-9]+}} x i64>
// CHECK: store <{{[0-9]+}} x i64>
if x.len() == y.len() {
x.swap_with_slice(y);
}
}

type OneKilobyteBuffer = [u8; 1024];

// CHECK-LABEL: @swap_1kb_slices
#[no_mangle]
pub fn swap_1kb_slices(x: &mut [OneKilobyteBuffer], y: &mut [OneKilobyteBuffer]) {
// CHECK-NOT: alloca
// CHECK: load <{{[0-9]+}} x i8>
// CHECK: store <{{[0-9]+}} x i8>
if x.len() == y.len() {
x.swap_with_slice(y);
}
}
32 changes: 32 additions & 0 deletions src/test/codegen/swap-simd-types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// compile-flags: -O -C target-feature=+avx
// only-x86_64
// ignore-debug: the debug assertions get in the way

#![crate_type = "lib"]

use std::mem::swap;

// SIMD types are highly-aligned already, so make sure the swap code leaves their
// types alone and doesn't pessimize them (such as by swapping them as `usize`s).
extern crate core;
use core::arch::x86_64::__m256;

// CHECK-LABEL: @swap_single_m256
#[no_mangle]
pub fn swap_single_m256(x: &mut __m256, y: &mut __m256) {
// CHECK-NOT: alloca
// CHECK: load <8 x float>{{.+}}align 32
// CHECK: store <8 x float>{{.+}}align 32
swap(x, y)
}

// CHECK-LABEL: @swap_m256_slice
#[no_mangle]
pub fn swap_m256_slice(x: &mut [__m256], y: &mut [__m256]) {
// CHECK-NOT: alloca
// CHECK: load <8 x float>{{.+}}align 32
// CHECK: store <8 x float>{{.+}}align 32
if x.len() == y.len() {
x.swap_with_slice(y);
}
}
Loading

0 comments on commit 8ca47d7

Please sign in to comment.