Skip to content

Optimize core::ptr::align_offset #68616

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 56 additions & 36 deletions src/libcore/ptr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1043,50 +1043,59 @@ pub unsafe fn write_volatile<T>(dst: *mut T, src: T) {
/// Any questions go to @nagisa.
#[lang = "align_offset"]
pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
/// Calculate multiplicative modular inverse of `x` modulo `m`.
/// Calculate multiplicative modular inverse of `x` modulo `m`, where
/// `m = 2^mpow` and `mask = m - 1`.
///
/// This implementation is tailored for align_offset and has the following preconditions:
///
/// * `m` is a power-of-two;
/// * `x < m`; (if `x ≥ m`, pass in `x % m` instead)
/// * The requested modulo `m` is a power-of-two, so `mpow` can be an argument;
/// * `x < m`; (if `x >= m`, pass in `x % m` instead)
///
/// It also sometimes leaves reducing the result modulo `m` to the caller, so the result may be
/// larger than `m`.
///
/// Implementation of this function shall not panic. Ever.
#[inline]
fn mod_inv(x: usize, m: usize) -> usize {
/// Multiplicative modular inverse table modulo 2⁴ = 16.
fn mod_pow_2_inv(x: usize, mpow: usize, mask: usize) -> usize {
/// Multiplicative modular inverse table modulo 2^4 = 16.
///
/// Note, that this table does not contain values where inverse does not exist (i.e., for
/// `0⁻¹ mod 16`, `2⁻¹ mod 16`, etc.)
/// `0^-1 mod 16`, `2^-1 mod 16`, etc.)
const INV_TABLE_MOD_16: [u8; 8] = [1, 11, 13, 7, 9, 3, 5, 15];
/// Modulo for which the `INV_TABLE_MOD_16` is intended.
const INV_TABLE_MOD: usize = 16;
/// INV_TABLE_MOD²
const INV_TABLE_MOD_SQUARED: usize = INV_TABLE_MOD * INV_TABLE_MOD;
/// `t` such that `2^t` is the modulus for which the `INV_TABLE_MOD_16` is intended.
const INV_TABLE_MOD_POW: usize = 4;
const INV_TABLE_MOD_POW_TIMES_2: usize = INV_TABLE_MOD_POW << 1;
const INV_TABLE_MOD: usize = 1 << INV_TABLE_MOD_POW;

let table_inverse = INV_TABLE_MOD_16[(x & (INV_TABLE_MOD - 1)) >> 1] as usize;
if m <= INV_TABLE_MOD {
table_inverse & (m - 1)

if mpow <= INV_TABLE_MOD_POW {
// This is explicitly left here, as benchmarking shows this improves performance.
table_inverse & mask
} else {
// We iterate "up" using the following formula:
//
// $$ xy ≡ 1 (mod 2ⁿ) → xy (2 - xy) ≡ 1 (mod 2²ⁿ) $$
// ` xy = 1 (mod 2^n) -> xy (2 - xy) = 1 (mod 2^(2n)) `
//
// until 2²ⁿ ≥ m. Then we can reduce to our desired `m` by taking the result `mod m`.
// until 2^(2n) ≥ m. Then we can reduce to our desired `m` by taking the result `mod m`.
//
// Running `k` iterations starting with a solution valid mod `2^t` will get us a
// solution valid mod `2^((2^k) * t)`, so we need to iterate until
// `2^k * t >= log2(m)` (where `log2(m)` is `mpow`).
let mut inverse = table_inverse;
let mut going_mod = INV_TABLE_MOD_SQUARED;
let mut going_modpow = INV_TABLE_MOD_POW_TIMES_2;
loop {
// y = y * (2 - xy) mod n
// y = y * (2 - xy)
//
// Note, that we use wrapping operations here intentionally – the original formula
// Note, that we use wrapping operations here intentionally - the original formula
// uses e.g., subtraction `mod n`. It is entirely fine to do them
// `mod (usize::max_value() + 1)` instead, because `n` is a power of two that divides
// that modulus and we take the result `mod n` at the end anyway.
inverse = inverse.wrapping_mul(2usize.wrapping_sub(x.wrapping_mul(inverse)))
& (going_mod - 1);
if going_mod > m {
return inverse & (m - 1);
inverse = inverse.wrapping_mul(2usize.wrapping_sub(x.wrapping_mul(inverse)));
if going_modpow >= mpow {
return inverse;
}
going_mod = going_mod.wrapping_mul(going_mod);
going_modpow <<= 1;
}
}
}
Expand All @@ -1112,29 +1121,40 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {

let smoda = stride & a_minus_one;
// a is power-of-two so cannot be 0. stride = 0 is handled above.
let gcdpow = intrinsics::cttz_nonzero(stride).min(intrinsics::cttz_nonzero(a));
let apow = intrinsics::cttz_nonzero(a);
let gcdpow = intrinsics::cttz_nonzero(stride).min(apow);
let gcd = 1usize << gcdpow;

if p as usize & (gcd - 1) == 0 {
if p as usize & (gcd.wrapping_sub(1)) == 0 {
// This branch solves for the following linear congruence equation:
//
// $$ p + so ≡ 0 mod a $$
// ` p + so = 0 mod a `
//
// $p$ here is the pointer value, $s$ – stride of `T`, $o$ – offset in `T`s, and $a$ – the
// `p` here is the pointer value, `s` - stride of `T`, `o` - offset in `T`s, and `a` - the
// requested alignment.
//
// g = gcd(a, s)
// o = (a - (p mod a))/g * ((s/g)⁻¹ mod a)
// With `g = gcd(a, s)`, and the above asserting that `p` is also divisible by `g`, we can
// denote `a' = a/g`, `s' = s/g`, `p' = p/g`, then this becomes equivalent to:
//
// ` p' + s'o = 0 mod a' `
// ` o = (a' - (p' mod a')) * (s'^-1 mod a') `
//
// The first term is “the relative alignment of p to a”, the second term is “how does
// incrementing p by s bytes change the relative alignment of p”. Division by `g` is
// necessary to make this equation well formed if $a$ and $s$ are not co-prime.
// The first term is "the relative alignment of `p` to `a`" (divided by `g`), the second
// term is "how does incrementing `p` by `s` bytes change the relative alignment of `p`" (again
// divided by `g`).
// Division by `g` is necessary to make the inverse well formed if `a` and `s` are not
// co-prime.
//
// Furthermore, the result produced by this solution is not “minimal”, so it is necessary
// to take the result $o mod lcm(s, a)$. We can replace $lcm(s, a)$ with just $a / g$.
let j = a.wrapping_sub(pmoda) >> gcdpow;
let k = smoda >> gcdpow;
return intrinsics::unchecked_rem(j.wrapping_mul(mod_inv(k, a)), a >> gcdpow);
// Furthermore, the result produced by this solution is not "minimal", so it is necessary
// to take the result `o mod lcm(s, a)`. We can replace `lcm(s, a)` with just `a'`.
let a2 = a >> gcdpow;
let a2minus1 = a2.wrapping_sub(1);
let s2 = smoda >> gcdpow;
let minusp2 = a2.wrapping_sub(pmoda >> gcdpow);
// mod_pow_2_inv returns a result which may be outside the range of `a'`, but it's fine to
// multiply modulo `usize::max_value() + 1` here, and then reduce modulo `a'` afterwards.
return (minusp2.wrapping_mul(mod_pow_2_inv(s2, apow.wrapping_sub(gcdpow), a2minus1)))
& a2minus1;
}

// Cannot be aligned at all.
Expand Down