Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve codegen for align_offset #75600

Merged
merged 2 commits into from
Aug 19, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 47 additions & 28 deletions library/core/src/ptr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1166,16 +1166,20 @@ pub unsafe fn write_volatile<T>(dst: *mut T, src: T) {
/// Any questions go to @nagisa.
#[lang = "align_offset"]
pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
// FIXME(#75598): Direct use of these intrinsics improves codegen significantly at opt-level <=
// 1, where the method versions of these operations are not inlined.
use intrinsics::{unchecked_shl, unchecked_shr, unchecked_sub, wrapping_mul, wrapping_sub};

/// Calculate multiplicative modular inverse of `x` modulo `m`.
///
/// This implementation is tailored for align_offset and has following preconditions:
/// This implementation is tailored for `align_offset` and has following preconditions:
///
/// * `m` is a power-of-two;
/// * `x < m`; (if `x ≥ m`, pass in `x % m` instead)
///
/// Implementation of this function shall not panic. Ever.
#[inline]
fn mod_inv(x: usize, m: usize) -> usize {
unsafe fn mod_inv(x: usize, m: usize) -> usize {
/// Multiplicative modular inverse table modulo 2⁴ = 16.
///
/// Note, that this table does not contain values where inverse does not exist (i.e., for
Expand All @@ -1187,8 +1191,10 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
const INV_TABLE_MOD_SQUARED: usize = INV_TABLE_MOD * INV_TABLE_MOD;

let table_inverse = INV_TABLE_MOD_16[(x & (INV_TABLE_MOD - 1)) >> 1] as usize;
// SAFETY: `m` is required to be a power-of-two, hence non-zero.
let m_minus_one = unsafe { unchecked_sub(m, 1) };
if m <= INV_TABLE_MOD {
table_inverse & (m - 1)
table_inverse & m_minus_one
} else {
// We iterate "up" using the following formula:
//
Expand All @@ -1204,49 +1210,50 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
// uses e.g., subtraction `mod n`. It is entirely fine to do them `mod
// usize::MAX` instead, because we take the result `mod n` at the end
// anyway.
inverse = inverse.wrapping_mul(2usize.wrapping_sub(x.wrapping_mul(inverse)));
inverse = wrapping_mul(inverse, wrapping_sub(2usize, wrapping_mul(x, inverse)));
if going_mod >= m {
return inverse & (m - 1);
return inverse & m_minus_one;
}
going_mod = going_mod.wrapping_mul(going_mod);
going_mod = wrapping_mul(going_mod, going_mod);
}
}
}

let stride = mem::size_of::<T>();
let a_minus_one = a.wrapping_sub(1);
let pmoda = p as usize & a_minus_one;
// SAFETY: `a` is a power-of-two, therefore non-zero.
let a_minus_one = unsafe { unchecked_sub(a, 1) };
if stride == 1 {
// `stride == 1` case can be computed more efficiently through `-p (mod a)`.
return wrapping_sub(0, p as usize) & a_minus_one;
}

let pmoda = p as usize & a_minus_one;
if pmoda == 0 {
// Already aligned. Yay!
return 0;
}

if stride <= 1 {
return if stride == 0 {
// If the pointer is not aligned, and the element is zero-sized, then no amount of
// elements will ever align the pointer.
!0
} else {
a.wrapping_sub(pmoda)
};
} else if stride == 0 {
// If the pointer is not aligned, and the element is zero-sized, then no amount of
// elements will ever align the pointer.
return usize::MAX;
}

let smoda = stride & a_minus_one;
// SAFETY: a is power-of-two so cannot be 0. stride = 0 is handled above.
// SAFETY: a is power-of-two hence non-zero. stride == 0 case is handled above.
let gcdpow = unsafe { intrinsics::cttz_nonzero(stride).min(intrinsics::cttz_nonzero(a)) };
let gcd = 1usize << gcdpow;
// SAFETY: gcdpow has an upper-bound that’s at most the number of bits in an usize.
let gcd = unsafe { unchecked_shl(1usize, gcdpow) };

if p as usize & (gcd.wrapping_sub(1)) == 0 {
// SAFETY: gcd is always greater or equal to 1.
if p as usize & unsafe { unchecked_sub(gcd, 1) } == 0 {
// This branch solves for the following linear congruence equation:
//
// ` p + so = 0 mod a `
//
// `p` here is the pointer value, `s` - stride of `T`, `o` offset in `T`s, and `a` - the
// requested alignment.
//
// With `g = gcd(a, s)`, and the above asserting that `p` is also divisible by `g`, we can
// denote `a' = a/g`, `s' = s/g`, `p' = p/g`, then this becomes equivalent to:
// With `g = gcd(a, s)`, and the above condition asserting that `p` is also divisible by
// `g`, we can denote `a' = a/g`, `s' = s/g`, `p' = p/g`, then this becomes equivalent to:
//
// ` p' + s'o = 0 mod a' `
// ` o = (a' - (p' mod a')) * (s'^-1 mod a') `
Expand All @@ -1259,11 +1266,23 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
//
// Furthermore, the result produced by this solution is not "minimal", so it is necessary
// to take the result `o mod lcm(s, a)`. We can replace `lcm(s, a)` with just a `a'`.
let a2 = a >> gcdpow;
let a2minus1 = a2.wrapping_sub(1);
let s2 = smoda >> gcdpow;
let minusp2 = a2.wrapping_sub(pmoda >> gcdpow);
return (minusp2.wrapping_mul(mod_inv(s2, a2))) & a2minus1;

// SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
// `a`.
let a2 = unsafe { unchecked_shr(a, gcdpow) };
// SAFETY: `a2` is non-zero. Shifting `a` by `gcdpow` cannot shift out any of the set bits
// in `a` (of which it has exactly one).
let a2minus1 = unsafe { unchecked_sub(a2, 1) };
// SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
// `a`.
let s2 = unsafe { unchecked_shr(smoda, gcdpow) };
// SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
// `a`. Furthermore, the subtraction cannot overflow, because `a2 = a >> gcdpow` will
// always be strictly greater than `(p % a) >> gcdpow`.
let minusp2 = unsafe { unchecked_sub(a2, unchecked_shr(pmoda, gcdpow)) };
// SAFETY: `a2` is a power-of-two, as proven above. `s2` is strictly less than `a2`
// because `(s % a) >> gcdpow` is strictly less than `a >> gcdpow`.
return wrapping_mul(minusp2, unsafe { mod_inv(s2, a2) }) & a2minus1;
}

// Cannot be aligned at all.
Expand Down