Fix mod_inv termination for the last iteration #103378

Merged 1 commit on Nov 19, 2022
library/core/src/ptr/mod.rs: 54 changes (28 additions, 26 deletions)
@@ -1571,8 +1571,8 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
     // FIXME(#75598): Direct use of these intrinsics improves codegen significantly at opt-level <=
     // 1, where the method versions of these operations are not inlined.
     use intrinsics::{
-        cttz_nonzero, exact_div, unchecked_rem, unchecked_shl, unchecked_shr, unchecked_sub,
-        wrapping_add, wrapping_mul, wrapping_sub,
+        cttz_nonzero, exact_div, mul_with_overflow, unchecked_rem, unchecked_shl, unchecked_shr,
+        unchecked_sub, wrapping_add, wrapping_mul, wrapping_sub,
     };

     /// Calculate multiplicative modular inverse of `x` modulo `m`.

@@ -1592,36 +1592,38 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
         const INV_TABLE_MOD_16: [u8; 8] = [1, 11, 13, 7, 9, 3, 5, 15];
         /// Modulo for which the `INV_TABLE_MOD_16` is intended.
         const INV_TABLE_MOD: usize = 16;
-        /// INV_TABLE_MOD²
-        const INV_TABLE_MOD_SQUARED: usize = INV_TABLE_MOD * INV_TABLE_MOD;
 
-        let table_inverse = INV_TABLE_MOD_16[(x & (INV_TABLE_MOD - 1)) >> 1] as usize;
         // SAFETY: `m` is required to be a power-of-two, hence non-zero.
         let m_minus_one = unsafe { unchecked_sub(m, 1) };
[Review comment (Member)]
Sadly unchecked is useless here today -- LLVM turns sub nuw %m, 1 into add %m, -1 during normalization :(

(Doesn't need to change here, though. I'm just sad about llvm/llvm-project#53377.)
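(Aside, not part of the PR: the final reduction in this function, inverse & m_minus_one, only works because `m` is a power of two, so masking with m - 1 is the same as reducing mod m. A minimal standalone check of that identity, with illustrative values of my own choosing:)

    fn main() {
        // For any power-of-two m, masking with m - 1 equals taking the value mod m.
        let m: usize = 1 << 47;
        for x in [0usize, 1, 5, 12_345, usize::MAX] {
            assert_eq!(x & (m - 1), x % m);
        }
        println!("x & (m - 1) == x % m holds for power-of-two m");
    }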

-        if m <= INV_TABLE_MOD {
-            table_inverse & m_minus_one
-        } else {
-            // We iterate "up" using the following formula:
-            //
-            // $$ xy ≡ 1 (mod 2ⁿ) → xy (2 - xy) ≡ 1 (mod 2²ⁿ) $$
-            //
-            // until 2²ⁿ ≥ m. Then we can reduce to our desired `m` by taking the result `mod m`.
-            let mut inverse = table_inverse;
-            let mut going_mod = INV_TABLE_MOD_SQUARED;
-            loop {
-                // y = y * (2 - xy) mod n
-                //
-                // Note, that we use wrapping operations here intentionally – the original formula
-                // uses e.g., subtraction `mod n`. It is entirely fine to do them `mod
-                // usize::MAX` instead, because we take the result `mod n` at the end
-                // anyway.
-                inverse = wrapping_mul(inverse, wrapping_sub(2usize, wrapping_mul(x, inverse)));
-                if going_mod >= m {
-                    return inverse & m_minus_one;
-                }
-                going_mod = wrapping_mul(going_mod, going_mod);
-            }
-        }
+        let mut inverse = INV_TABLE_MOD_16[(x & (INV_TABLE_MOD - 1)) >> 1] as usize;
+        let mut mod_gate = INV_TABLE_MOD;
+        // We iterate "up" using the following formula:
+        //
+        // $$ xy ≡ 1 (mod 2ⁿ) → xy (2 - xy) ≡ 1 (mod 2²ⁿ) $$
+        //
+        // This application needs to be applied at least until `2²ⁿ ≥ m`, at which point we can
+        // finally reduce the computation to our desired `m` by taking `inverse mod m`.
+        //
+        // This computation is `O(log log m)`, which is to say, that on 64-bit machines this loop
+        // will always finish in at most 4 iterations.
+        loop {
+            // y = y * (2 - xy) mod n
+            //
+            // Note, that we use wrapping operations here intentionally – the original formula
+            // uses e.g., subtraction `mod n`. It is entirely fine to do them `mod
+            // usize::MAX` instead, because we take the result `mod n` at the end
+            // anyway.
+            if mod_gate >= m {
+                break;
+            }
+            inverse = wrapping_mul(inverse, wrapping_sub(2usize, wrapping_mul(x, inverse)));
+            let (new_gate, overflow) = mul_with_overflow(mod_gate, mod_gate);
+            if overflow {
+                break;
+            }
+            mod_gate = new_gate;
+        }
+        inverse & m_minus_one
     }

     let addr = p.addr();
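To see why the gate change matters: in the removed code above, going_mod starts at 16² = 256 and is squared after every Newton step, so for a large power-of-two m such as 1 << 47 on a 64-bit target it takes the values 256, 65536, 2^32 and then wraps to 0, after which going_mod >= m can never hold and the loop never returns. The new code checks the gate before stepping and also stops as soon as squaring the gate would overflow. Below is a minimal standalone model of the fixed loop (a sketch of my own, not the library code: the intrinsics are replaced by the equivalent stable usize methods, and mod_inv_model/main are illustrative names):

    /// Inverses of the odd residues modulo 16, indexed by (x & 15) >> 1.
    const INV_TABLE_MOD_16: [u8; 8] = [1, 11, 13, 7, 9, 3, 5, 15];

    /// Post-fix shape of mod_inv: `x` must be odd, `m` a power of two.
    fn mod_inv_model(x: usize, m: usize) -> usize {
        assert!(m.is_power_of_two() && x % 2 == 1);
        let m_minus_one = m - 1;
        let mut inverse = INV_TABLE_MOD_16[(x & 15) >> 1] as usize;
        let mut mod_gate = 16usize;
        loop {
            // Stop once the inverse is already correct modulo at least `m`.
            if mod_gate >= m {
                break;
            }
            // Newton step: doubles the number of correct low bits of the inverse.
            inverse = inverse.wrapping_mul(2usize.wrapping_sub(x.wrapping_mul(inverse)));
            // Square the gate; if that overflows, the inverse is already correct
            // modulo 2^BITS >= m, so stop instead of wrapping the gate to 0.
            let (new_gate, overflow) = mod_gate.overflowing_mul(mod_gate);
            if overflow {
                break;
            }
            mod_gate = new_gate;
        }
        inverse & m_minus_one
    }

    fn main() {
        // The shape of the regression from issue #103361: a 2^47-byte alignment on 64-bit.
        let m = 1usize << 47;
        let x = 0x1234_5671usize; // any odd value
        let inv = mod_inv_model(x, m);
        assert_eq!(x.wrapping_mul(inv) & (m - 1), 1);
    }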
library/core/tests/ptr.rs: 12 changes (12 additions, 0 deletions)
@@ -455,6 +455,18 @@ fn align_offset_various_strides() {
     assert!(!x);
 }
 
+#[test]
+fn align_offset_issue_103361() {
+    #[cfg(target_pointer_width = "64")]
+    const SIZE: usize = 1 << 47;
+    #[cfg(target_pointer_width = "32")]
+    const SIZE: usize = 1 << 30;
+    #[cfg(target_pointer_width = "16")]
+    const SIZE: usize = 1 << 13;
+    struct HugeSize([u8; SIZE - 1]);
+    let _ = (SIZE as *const HugeSize).align_offset(SIZE);
[Review comment (Member)]
We usually prefer the strict provenance APIs in libcore -- #104632
Not sure if the lint against int2ptr casts ever got implemented? If yes we should probably enable it here.

[Reply (Member, Author)]
Ah, I basically just copy-pasted over the reproducer from the issue…

+}

 #[test]
 fn offset_from() {
     let mut a = [0; 5];
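Following up on the strict-provenance review comment above: a variant of the reproducer that avoids the int-to-pointer cast could look roughly like the sketch below. This is an assumption on my part, not part of the PR; the constructor is core::ptr::without_provenance on current toolchains (the unstable core::ptr::invalid at the time this PR landed), and an address-only pointer is enough here because align_offset never dereferences it.

    #[test]
    fn align_offset_issue_103361_strict_provenance() {
        #[cfg(target_pointer_width = "64")]
        const SIZE: usize = 1 << 47;
        #[cfg(target_pointer_width = "32")]
        const SIZE: usize = 1 << 30;
        #[cfg(target_pointer_width = "16")]
        const SIZE: usize = 1 << 13;
        struct HugeSize([u8; SIZE - 1]);
        // Build an address-only pointer instead of casting the integer with `as`.
        let p: *const HugeSize = core::ptr::without_provenance(SIZE);
        let _ = p.align_offset(SIZE);
    }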