Skip to content

Faster fmt::Display of 128-bit integers, without unsafe pointer #136594

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
286 changes: 132 additions & 154 deletions library/core/src/fmt/num.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ macro_rules! impl_Display {
// Format per two digits from the lookup table.
if remain > 9 {
// SAFETY: All of the decimals fit in buf due to MAX_DEC_N
// and the while condition ensures at least 2 more decimals.
// and the if condition ensures at least 2 more decimals.
unsafe { core::hint::assert_unchecked(offset >= 2) }
// SAFETY: The offset counts down from its initial buf.len()
// without underflow due to the previous precondition.
Expand Down Expand Up @@ -555,93 +555,6 @@ mod imp {
}
impl_Exp!(i128, u128 as u128 via to_u128 named exp_u128);

/// Helper function for writing a u64 into `buf` going from last to first, with `curr`.
fn parse_u64_into<const N: usize>(mut n: u64, buf: &mut [MaybeUninit<u8>; N], curr: &mut usize) {
let buf_ptr = MaybeUninit::slice_as_mut_ptr(buf);
let lut_ptr = DEC_DIGITS_LUT.as_ptr();
assert!(*curr > 19);

// SAFETY:
// Writes at most 19 characters into the buffer. Guaranteed that any ptr into LUT is at most
// 198, so will never OOB. There is a check above that there are at least 19 characters
// remaining.
unsafe {
if n >= 1e16 as u64 {
let to_parse = n % 1e16 as u64;
n /= 1e16 as u64;

// Some of these are nops but it looks more elegant this way.
let d1 = ((to_parse / 1e14 as u64) % 100) << 1;
let d2 = ((to_parse / 1e12 as u64) % 100) << 1;
let d3 = ((to_parse / 1e10 as u64) % 100) << 1;
let d4 = ((to_parse / 1e8 as u64) % 100) << 1;
let d5 = ((to_parse / 1e6 as u64) % 100) << 1;
let d6 = ((to_parse / 1e4 as u64) % 100) << 1;
let d7 = ((to_parse / 1e2 as u64) % 100) << 1;
let d8 = ((to_parse / 1e0 as u64) % 100) << 1;

*curr -= 16;

ptr::copy_nonoverlapping(lut_ptr.add(d1 as usize), buf_ptr.add(*curr + 0), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d2 as usize), buf_ptr.add(*curr + 2), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d3 as usize), buf_ptr.add(*curr + 4), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d4 as usize), buf_ptr.add(*curr + 6), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d5 as usize), buf_ptr.add(*curr + 8), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d6 as usize), buf_ptr.add(*curr + 10), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d7 as usize), buf_ptr.add(*curr + 12), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d8 as usize), buf_ptr.add(*curr + 14), 2);
}
if n >= 1e8 as u64 {
let to_parse = n % 1e8 as u64;
n /= 1e8 as u64;

// Some of these are nops but it looks more elegant this way.
let d1 = ((to_parse / 1e6 as u64) % 100) << 1;
let d2 = ((to_parse / 1e4 as u64) % 100) << 1;
let d3 = ((to_parse / 1e2 as u64) % 100) << 1;
let d4 = ((to_parse / 1e0 as u64) % 100) << 1;
*curr -= 8;

ptr::copy_nonoverlapping(lut_ptr.add(d1 as usize), buf_ptr.add(*curr + 0), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d2 as usize), buf_ptr.add(*curr + 2), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d3 as usize), buf_ptr.add(*curr + 4), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d4 as usize), buf_ptr.add(*curr + 6), 2);
}
// `n` < 1e8 < (1 << 32)
let mut n = n as u32;
if n >= 1e4 as u32 {
let to_parse = n % 1e4 as u32;
n /= 1e4 as u32;

let d1 = (to_parse / 100) << 1;
let d2 = (to_parse % 100) << 1;
*curr -= 4;

ptr::copy_nonoverlapping(lut_ptr.add(d1 as usize), buf_ptr.add(*curr + 0), 2);
ptr::copy_nonoverlapping(lut_ptr.add(d2 as usize), buf_ptr.add(*curr + 2), 2);
}

// `n` < 1e4 < (1 << 16)
let mut n = n as u16;
if n >= 100 {
let d1 = (n % 100) << 1;
n /= 100;
*curr -= 2;
ptr::copy_nonoverlapping(lut_ptr.add(d1 as usize), buf_ptr.add(*curr), 2);
}

// decode last 1 or 2 chars
if n < 10 {
*curr -= 1;
*buf_ptr.add(*curr) = (n as u8) + b'0';
} else {
let d1 = n << 1;
*curr -= 2;
ptr::copy_nonoverlapping(lut_ptr.add(d1 as usize), buf_ptr.add(*curr), 2);
}
}
}

#[stable(feature = "rust1", since = "1.0.0")]
impl fmt::Display for u128 {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Expand All @@ -652,90 +565,155 @@ impl fmt::Display for u128 {
#[stable(feature = "rust1", since = "1.0.0")]
impl fmt::Display for i128 {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let is_nonnegative = *self >= 0;
let n = if is_nonnegative {
self.to_u128()
} else {
// convert the negative num to positive by summing 1 to its 2s complement
(!self.to_u128()).wrapping_add(1)
};
fmt_u128(n, is_nonnegative, f)
fmt_u128(self.unsigned_abs(), *self >= 0, f)
}
}

/// Specialized optimization for u128. Instead of taking two items at a time, it splits
/// into at most 2 u64s, and then chunks by 10^16, 10^8, 10^4, 10^2, and then 10^1.
/// It also has to handle 1 last item, as 10^40 > 2^128 > 10^39, whereas
/// 10^20 > 2^64 > 10^19.
/// Format optimized for u128. Computation of 128 bits is limited by processing
/// in batches of 16 decimals at a time.
fn fmt_u128(n: u128, is_nonnegative: bool, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// 2^128 is about 3.4*10^38, which has 39 decimal digits, so 39 bytes are exactly enough
let mut buf = [MaybeUninit::<u8>::uninit(); 39];
let mut curr = buf.len();

let (n, rem) = udiv_1e19(n);
parse_u64_into(rem, &mut buf, &mut curr);

if n != 0 {
// 0 pad up to point
let target = buf.len() - 19;
// SAFETY: Guaranteed that we wrote at most 19 bytes, and there must be space
// remaining since it has length 39
unsafe {
ptr::write_bytes(
MaybeUninit::slice_as_mut_ptr(&mut buf).add(target),
b'0',
curr - target,
);
}
curr = target;

let (n, rem) = udiv_1e19(n);
parse_u64_into(rem, &mut buf, &mut curr);
// Should this following branch be annotated with unlikely?
if n != 0 {
let target = buf.len() - 38;
// The raw `buf_ptr` pointer is only valid until `buf` is used the next time,
// but `buf` is not used in this scope so we are good.
let buf_ptr = MaybeUninit::slice_as_mut_ptr(&mut buf);
// SAFETY: At this point we wrote at most 38 bytes, pad up to that point,
// There can only be at most 1 digit remaining.
unsafe {
ptr::write_bytes(buf_ptr.add(target), b'0', curr - target);
curr = target - 1;
*buf_ptr.add(curr) = (n as u8) + b'0';
}
// Optimize common-case zero, which would also need special treatment due to
// its "leading" zero.
if n == 0 {
return f.pad_integral(true, "", "0");
}

// U128::MAX has 39 significant-decimals.
const MAX_DEC_N: usize = 39;
// Buffer decimals with right alignment.
let mut buf = [MaybeUninit::<u8>::uninit(); MAX_DEC_N];
// Count the number of bytes in buf that are not initialized.
let mut offset = buf.len();

// Take the 16 least-significant decimals.
let (n, mod_1e16) = div_rem_1e16(n);
let mut remain = if n == 0 {
mod_1e16
} else {
// write buf[23..39]
enc_16lsd::<23>(&mut buf, mod_1e16);
offset = 23;

// Take another 16 decimals.
let (n, mod_1e16) = div_rem_1e16(n);
if n == 0 {
mod_1e16
} else {
// write buf[7..23]
enc_16lsd::<7>(&mut buf, mod_1e16);
offset = 7;

debug_assert!(n < 10);
n as u64
}
};

// Format per four digits from the lookup table.
while remain > 999 {
// SAFETY: All of the decimals fit in buf due to MAX_DEC_N
// and the while condition ensures at least 4 more decimals.
unsafe { core::hint::assert_unchecked(offset >= 4) }
// SAFETY: The offset counts down from its initial buf.len()
// without underflow due to the previous precondition.
unsafe { core::hint::assert_unchecked(offset <= buf.len()) }
offset -= 4;

// pull two pairs
let quad = remain % 1_00_00;
remain /= 1_00_00;
let pair1 = (quad / 100) as usize;
let pair2 = (quad % 100) as usize;
buf[offset + 0].write(DEC_DIGITS_LUT[pair1 * 2 + 0]);
buf[offset + 1].write(DEC_DIGITS_LUT[pair1 * 2 + 1]);
buf[offset + 2].write(DEC_DIGITS_LUT[pair2 * 2 + 0]);
buf[offset + 3].write(DEC_DIGITS_LUT[pair2 * 2 + 1]);
}

// Format per two digits from the lookup table.
if remain > 9 {
// SAFETY: All of the decimals fit in buf due to MAX_DEC_N
// and the if condition ensures at least 2 more decimals.
unsafe { core::hint::assert_unchecked(offset >= 2) }
// SAFETY: The offset counts down from its initial buf.len()
// without underflow due to the previous precondition.
unsafe { core::hint::assert_unchecked(offset <= buf.len()) }
offset -= 2;

let pair = (remain % 100) as usize;
remain /= 100;
buf[offset + 0].write(DEC_DIGITS_LUT[pair * 2 + 0]);
buf[offset + 1].write(DEC_DIGITS_LUT[pair * 2 + 1]);
}

// SAFETY: `curr` > 0 (since we made `buf` large enough), and all the chars are valid
// UTF-8 since `DEC_DIGITS_LUT` is
let buf_slice = unsafe {
// Format the last remaining digit, if any.
if remain != 0 {
// SAFETY: All of the decimals fit in buf due to MAX_DEC_N
// and the if condition ensures (at least) 1 more decimal.
unsafe { core::hint::assert_unchecked(offset >= 1) }
// SAFETY: The offset counts down from its initial buf.len()
// without underflow due to the previous precondition.
unsafe { core::hint::assert_unchecked(offset <= buf.len()) }
offset -= 1;

// Either the compiler sees that remain < 10, or it prevents
// a boundary check up next.
let last = (remain & 15) as usize;
buf[offset].write(DEC_DIGITS_LUT[last * 2 + 1]);
// not used: remain = 0;
}

// SAFETY: All buf content since offset is set.
let written = unsafe { buf.get_unchecked(offset..) };
// SAFETY: Writes use ASCII from the lookup table exclusively.
let as_str = unsafe {
str::from_utf8_unchecked(slice::from_raw_parts(
MaybeUninit::slice_as_mut_ptr(&mut buf).add(curr),
buf.len() - curr,
MaybeUninit::slice_as_ptr(written),
written.len(),
))
};
f.pad_integral(is_nonnegative, "", buf_slice)
f.pad_integral(is_nonnegative, "", as_str)
}

/// Writes the 16 least significant decimal digits of `n` as ASCII into
/// `buf[OFFSET..OFFSET + 16]`, zero-padded on the left.
///
/// Any digits of `n` above the sixteenth are ignored.
fn enc_16lsd<const OFFSET: usize>(buf: &mut [MaybeUninit<u8>; 39], n: u64) {
    // Digits still to be emitted; the low end is consumed first.
    let mut rest = n;

    // Four groups of four digits, filled from the right edge of the window
    // towards its left edge.
    let mut quads_left = 4;
    while quads_left > 0 {
        quads_left -= 1;
        let base = quads_left * 4 + OFFSET;

        // Split the next four digits into a high and a low pair.
        let quad = rest % 1_00_00;
        rest /= 1_00_00;
        let hi = (quad / 100) as usize;
        let lo = (quad % 100) as usize;

        // Look up both characters of each pair, then store them left to right.
        let digits = [
            DEC_DIGITS_LUT[hi * 2],
            DEC_DIGITS_LUT[hi * 2 + 1],
            DEC_DIGITS_LUT[lo * 2],
            DEC_DIGITS_LUT[lo * 2 + 1],
        ];
        for (step, digit) in digits.into_iter().enumerate() {
            buf[base + step].write(digit);
        }
    }
}

/// Partition of `n` into the quotient `n / 1e19` and the remainder `rem < 1e19`.
/// Euclidean division plus remainder with constant 1E16 basically consumes 16
/// decimals from n.
///
/// Integer division algorithm is based on the following paper:
/// The integer division algorithm is based on the following paper:
///
/// T. Granlund and P. Montgomery, “Division by Invariant Integers Using Multiplication”
/// in Proc. of the SIGPLAN94 Conference on Programming Language Design and
/// Implementation, 1994, pp. 61–72
///
fn udiv_1e19(n: u128) -> (u128, u64) {
const DIV: u64 = 1e19 as u64;
const FACTOR: u128 = 156927543384667019095894735580191660403;
#[inline]
fn div_rem_1e16(n: u128) -> (u128, u64) {
const D: u128 = 1_0000_0000_0000_0000;
if n < D {
return (0, n as u64);
}

let quot = if n < 1 << 83 {
((n >> 19) as u64 / (DIV >> 19)) as u128
} else {
n.widening_mul(FACTOR).1 >> 62
};
// These constant values are computed with the CHOOSE_MULTIPLIER procedure.
const M_HIGH: u128 = 76624777043294442917917351357515459181;
const SH_POST: u8 = 51;

let rem = (n - quot * DIV as u128) as u64;
(quot, rem)
let quot = n.widening_mul(M_HIGH).1 >> SH_POST;
let rem = n - quot * D;
(quot, rem as u64)
}
Loading
Loading