Add add/sub methods that only panic with debug assertions to rustc #123175

Merged · 1 commit · Apr 13, 2024

38 changes: 21 additions & 17 deletions compiler/rustc_data_structures/src/sip128.rs
@@ -1,5 +1,8 @@
 //! This is a copy of `core::hash::sip` adapted to providing 128 bit hashes.

+// This code is very hot and uses lots of arithmetic, avoid overflow checks for performance.
+// See https://github.com/rust-lang/rust/pull/119440#issuecomment-1874255727
+use rustc_serialize::int_overflow::{DebugStrictAdd, DebugStrictSub};
 use std::hash::Hasher;
 use std::mem::{self, MaybeUninit};
 use std::ptr;
@@ -103,19 +106,19 @@ unsafe fn copy_nonoverlapping_small(src: *const u8, dst: *mut u8, count: usize)
     }

     let mut i = 0;
-    if i + 3 < count {
+    if i.debug_strict_add(3) < count {
         ptr::copy_nonoverlapping(src.add(i), dst.add(i), 4);
-        i += 4;
+        i = i.debug_strict_add(4);
     }

-    if i + 1 < count {
+    if i.debug_strict_add(1) < count {
         ptr::copy_nonoverlapping(src.add(i), dst.add(i), 2);
-        i += 2
+        i = i.debug_strict_add(2)
     }

     if i < count {
         *dst.add(i) = *src.add(i);
-        i += 1;
+        i = i.debug_strict_add(1);
     }

     debug_assert_eq!(i, count);

@@ -211,14 +214,14 @@ impl SipHasher128 {
         debug_assert!(nbuf < BUFFER_SIZE);
         debug_assert!(nbuf + LEN < BUFFER_WITH_SPILL_SIZE);

-        if nbuf + LEN < BUFFER_SIZE {
+        if nbuf.debug_strict_add(LEN) < BUFFER_SIZE {
             unsafe {
                 // The memcpy call is optimized away because the size is known.
                 let dst = (self.buf.as_mut_ptr() as *mut u8).add(nbuf);
                 ptr::copy_nonoverlapping(bytes.as_ptr(), dst, LEN);
             }

-            self.nbuf = nbuf + LEN;
+            self.nbuf = nbuf.debug_strict_add(LEN);

             return;
         }

@@ -265,8 +268,9 @@ impl SipHasher128 {
         // This function should only be called when the write fills the buffer.
         // Therefore, when LEN == 1, the new `self.nbuf` must be zero.
         // LEN is statically known, so the branch is optimized away.
-        self.nbuf = if LEN == 1 { 0 } else { nbuf + LEN - BUFFER_SIZE };
-        self.processed += BUFFER_SIZE;
+        self.nbuf =
+            if LEN == 1 { 0 } else { nbuf.debug_strict_add(LEN).debug_strict_sub(BUFFER_SIZE) };
+        self.processed = self.processed.debug_strict_add(BUFFER_SIZE);
     }

@@ -277,7 +281,7 @@ impl SipHasher128 {
         let nbuf = self.nbuf;
         debug_assert!(nbuf < BUFFER_SIZE);

-        if nbuf + length < BUFFER_SIZE {
+        if nbuf.debug_strict_add(length) < BUFFER_SIZE {
             unsafe {
                 let dst = (self.buf.as_mut_ptr() as *mut u8).add(nbuf);

@@ -289,7 +293,7 @@ impl SipHasher128 {
                 }
             }

-            self.nbuf = nbuf + length;
+            self.nbuf = nbuf.debug_strict_add(length);

             return;
         }

@@ -315,7 +319,7 @@ impl SipHasher128 {
         // This function should only be called when the write fills the buffer,
         // so we know that there is enough input to fill the current element.
         let valid_in_elem = nbuf % ELEM_SIZE;
-        let needed_in_elem = ELEM_SIZE - valid_in_elem;
+        let needed_in_elem = ELEM_SIZE.debug_strict_sub(valid_in_elem);

         let src = msg.as_ptr();
         let dst = (self.buf.as_mut_ptr() as *mut u8).add(nbuf);

@@ -327,7 +331,7 @@ impl SipHasher128 {
         // ELEM_SIZE` to show the compiler that this loop's upper bound is > 0.
         // We know that is true, because last step ensured we have a full
         // element in the buffer.
-        let last = nbuf / ELEM_SIZE + 1;
+        let last = (nbuf / ELEM_SIZE).debug_strict_add(1);

         for i in 0..last {
             let elem = self.buf.get_unchecked(i).assume_init().to_le();

@@ -338,7 +342,7 @@ impl SipHasher128 {

         // Process the remaining element-sized chunks of input.
         let mut processed = needed_in_elem;
-        let input_left = length - processed;
+        let input_left = length.debug_strict_sub(processed);
         let elems_left = input_left / ELEM_SIZE;
         let extra_bytes_left = input_left % ELEM_SIZE;

@@ -347,7 +351,7 @@ impl SipHasher128 {
             self.state.v3 ^= elem;
             Sip13Rounds::c_rounds(&mut self.state);
             self.state.v0 ^= elem;
-            processed += ELEM_SIZE;
+            processed = processed.debug_strict_add(ELEM_SIZE);
         }

         // Copy remaining input into start of buffer.
@@ -356,7 +360,7 @@ impl SipHasher128 {
         copy_nonoverlapping_small(src, dst, extra_bytes_left);

         self.nbuf = extra_bytes_left;
-        self.processed += nbuf + processed;
+        self.processed = self.processed.debug_strict_add(nbuf.debug_strict_add(processed));
     }
 }

@@ -394,7 +398,7 @@ impl SipHasher128 {
         };

         // Finalize the hash.
-        let length = self.processed + self.nbuf;
+        let length = self.processed.debug_strict_add(self.nbuf);
         let b: u64 = ((length as u64 & 0xff) << 56) | elem;

         state.v3 ^= b;

65 changes: 65 additions & 0 deletions compiler/rustc_serialize/src/int_overflow.rs
@@ -0,0 +1,65 @@
+// This would belong to `rustc_data_structures`, but `rustc_serialize` needs it too.
+
+/// Addition, but only overflow checked when `cfg(debug_assertions)` is set
+/// instead of respecting `-Coverflow-checks`.
+///
+/// This exists for performance reasons, as we ship rustc with overflow checks.
+/// While overflow checks are perf neutral in almost all of the compiler, there
+/// are a few particularly hot areas where we don't want overflow checks in our
+/// dist builds. Overflow is still a bug there, so we want overflow checks for
+/// builds with debug assertions.
+///
+/// That's a long way to say that this should be used in areas where overflow
+/// is a bug but overflow checking is too slow.
+pub trait DebugStrictAdd {
+    /// See [`DebugStrictAdd`].
+    fn debug_strict_add(self, other: Self) -> Self;
+}
+
+macro_rules! impl_debug_strict_add {
+    ($( $ty:ty )*) => {
+        $(
+            impl DebugStrictAdd for $ty {
+                fn debug_strict_add(self, other: Self) -> Self {
+                    if cfg!(debug_assertions) {
+                        self + other
+                    } else {
+                        self.wrapping_add(other)
+                    }
+                }
+            }
+        )*
+    };
+}
+
+/// See [`DebugStrictAdd`].
+pub trait DebugStrictSub {
+    /// See [`DebugStrictAdd`].
+    fn debug_strict_sub(self, other: Self) -> Self;
+}
+
+macro_rules! impl_debug_strict_sub {
+    ($( $ty:ty )*) => {
+        $(
+            impl DebugStrictSub for $ty {
+                fn debug_strict_sub(self, other: Self) -> Self {
+                    if cfg!(debug_assertions) {
+                        self - other
+                    } else {
+                        self.wrapping_sub(other)
+                    }
+                }
+            }
+        )*
+    };
+}
+
+impl_debug_strict_add! {
+    u8 u16 u32 u64 u128 usize
+    i8 i16 i32 i64 i128 isize
+}
+
+impl_debug_strict_sub! {
+    u8 u16 u32 u64 u128 usize
+    i8 i16 i32 i64 i128 isize
+}
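
The new helpers boil down to a single branch that is resolved at build time. A minimal usage sketch (standalone and illustrative, not part of the PR):

```rust
use rustc_serialize::int_overflow::DebugStrictAdd;

fn bump(i: usize) -> usize {
    // With debug assertions: a plain `i + 1`, which panics on overflow since
    // rustc dist builds ship with overflow checks enabled.
    // Without debug assertions: lowers to `wrapping_add`, removing the
    // overflow-check branch from this hot path. Overflow would still be a
    // bug, just an unchecked one.
    i.debug_strict_add(1)
}

fn main() {
    assert_eq!(bump(41), 42);
}
```
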
14 changes: 9 additions & 5 deletions compiler/rustc_serialize/src/leb128.rs
@@ -1,6 +1,10 @@
 use crate::opaque::MemDecoder;
 use crate::serialize::Decoder;

+// This code is very hot and uses lots of arithmetic, avoid overflow checks for performance.
+// See https://github.com/rust-lang/rust/pull/119440#issuecomment-1874255727
+use crate::int_overflow::DebugStrictAdd;
+
 /// Returns the length of the longest LEB128 encoding for `T`, assuming `T` is an integer type
 pub const fn max_leb128_len<T>() -> usize {
     // The longest LEB128 encoding for an integer uses 7 bits per byte.

@@ -24,15 +28,15 @@ macro_rules! impl_write_unsigned_leb128 {
                         *out.get_unchecked_mut(i) = value as u8;
                     }

-                    i += 1;
+                    i = i.debug_strict_add(1);
                     break;
                 } else {
                     unsafe {
                         *out.get_unchecked_mut(i) = ((value & 0x7f) | 0x80) as u8;
                     }

                     value >>= 7;
-                    i += 1;
+                    i = i.debug_strict_add(1);
                 }
             }

@@ -69,7 +73,7 @@ macro_rules! impl_read_unsigned_leb128 {
                 } else {
                     result |= ((byte & 0x7F) as $int_ty) << shift;
                 }
-                shift += 7;
+                shift = shift.debug_strict_add(7);
             }
         }
     };

@@ -101,7 +105,7 @@ macro_rules! impl_write_signed_leb128 {
                     *out.get_unchecked_mut(i) = byte;
                 }

-                i += 1;
+                i = i.debug_strict_add(1);

                 if !more {
                     break;

@@ -130,7 +134,7 @@ macro_rules! impl_read_signed_leb128 {
             loop {
                 byte = decoder.read_u8();
                 result |= <$int_ty>::from(byte & 0x7F) << shift;
-                shift += 7;
+                shift = shift.debug_strict_add(7);

                 if (byte & 0x80) == 0 {
                     break;
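
For context on the loops being patched here, this is a self-contained sketch of the LEB128 scheme the macros implement: 7 payload bits per byte, high bit set while more bytes follow. Helper names are illustrative, not the `rustc_serialize` API.

```rust
fn write_u32_leb128(mut value: u32, out: &mut Vec<u8>) {
    loop {
        if value < 0x80 {
            // Last 7 bits: emit with the continuation bit clear.
            out.push(value as u8);
            break;
        }
        // Emit the low 7 bits with the continuation bit set, then shift them out.
        out.push(((value & 0x7f) | 0x80) as u8);
        value >>= 7;
    }
}

fn read_u32_leb128(bytes: &[u8]) -> u32 {
    let mut result = 0u32;
    let mut shift = 0;
    for &byte in bytes {
        result |= ((byte & 0x7f) as u32) << shift;
        if byte & 0x80 == 0 {
            break;
        }
        // This is the `shift += 7` the PR rewrites as `shift.debug_strict_add(7)`;
        // for well-formed u32 input it never exceeds 28, so the overflow check
        // is pure cost in the hot path.
        shift += 7;
    }
    result
}

fn main() {
    let mut buf = Vec::new();
    write_u32_leb128(624_485, &mut buf);
    assert_eq!(buf, [0xe5, 0x8e, 0x26]); // the classic LEB128 example value
    assert_eq!(read_u32_leb128(&buf), 624_485);
}
```
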
1 change: 1 addition & 0 deletions compiler/rustc_serialize/src/lib.rs
@@ -23,5 +23,6 @@ pub use self::serialize::{Decodable, Decoder, Encodable, Encoder};

 mod serialize;

+pub mod int_overflow;
 pub mod leb128;
 pub mod opaque;
10 changes: 7 additions & 3 deletions compiler/rustc_serialize/src/opaque.rs
@@ -7,6 +7,10 @@ use std::ops::Range;
 use std::path::Path;
 use std::path::PathBuf;

+// This code is very hot and uses lots of arithmetic, avoid overflow checks for performance.
+// See https://github.com/rust-lang/rust/pull/119440#issuecomment-1874255727
+use crate::int_overflow::DebugStrictAdd;
+
 // -----------------------------------------------------------------------------
 // Encoder
 // -----------------------------------------------------------------------------

@@ -65,7 +69,7 @@ impl FileEncoder {
         // Tracking position this way instead of having a `self.position` field
         // means that we only need to update `self.buffered` on a write call,
         // as opposed to updating `self.position` and `self.buffered`.
-        self.flushed + self.buffered
+        self.flushed.debug_strict_add(self.buffered)
     }

     #[cold]

@@ -119,7 +123,7 @@ impl FileEncoder {
         }
         if let Some(dest) = self.buffer_empty().get_mut(..buf.len()) {
             dest.copy_from_slice(buf);
-            self.buffered += buf.len();
+            self.buffered = self.buffered.debug_strict_add(buf.len());
         } else {
             self.write_all_cold_path(buf);
         }

@@ -158,7 +162,7 @@ impl FileEncoder {
         if written > N {
             Self::panic_invalid_write::<N>(written);
         }
-        self.buffered += written;
+        self.buffered = self.buffered.debug_strict_add(written);
     }

     #[cold]
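
The comment in the `position` hunk above records a design choice worth spelling out: rather than maintaining a third `self.position` counter that every write has to keep in sync, `FileEncoder` derives the position from the two counters it already maintains. A minimal sketch of the idea (the `Sink` type is illustrative, not the real `FileEncoder`):

```rust
struct Sink {
    flushed: usize,  // bytes already flushed to the underlying file
    buffered: usize, // bytes currently sitting in the in-memory buffer
}

impl Sink {
    fn position(&self) -> usize {
        // Derived on demand: writes bump only `buffered`, and flushes only
        // move counts from `buffered` to `flushed`. This is the sum the PR
        // guards with `debug_strict_add`.
        self.flushed + self.buffered
    }
}

fn main() {
    let s = Sink { flushed: 4096, buffered: 10 };
    assert_eq!(s.position(), 4106);
}
```
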
8 changes: 6 additions & 2 deletions compiler/rustc_span/src/span_encoding.rs
@@ -5,6 +5,10 @@ use crate::{BytePos, SpanData};

 use rustc_data_structures::fx::FxIndexSet;

+// This code is very hot and uses lots of arithmetic, avoid overflow checks for performance.
+// See https://github.com/rust-lang/rust/pull/119440#issuecomment-1874255727
+use rustc_serialize::int_overflow::DebugStrictAdd;
+
 /// A compressed span.
 ///
 /// [`SpanData`] is 16 bytes, which is too big to stick everywhere. `Span` only

@@ -166,7 +170,7 @@ impl Span {
             debug_assert!(len <= MAX_LEN);
             SpanData {
                 lo: BytePos(self.lo_or_index),
-                hi: BytePos(self.lo_or_index + len),
+                hi: BytePos(self.lo_or_index.debug_strict_add(len)),
                 ctxt: SyntaxContext::from_u32(self.ctxt_or_parent_or_marker as u32),
                 parent: None,
             }

@@ -179,7 +183,7 @@ impl Span {
             };
             SpanData {
                 lo: BytePos(self.lo_or_index),
-                hi: BytePos(self.lo_or_index + len),
+                hi: BytePos(self.lo_or_index.debug_strict_add(len)),
                 ctxt: SyntaxContext::root(),
                 parent: Some(parent),
             }