Skip to content

Commit

Permalink
ScalarInt: add methods to assert being a (u)int of given size
Browse files Browse the repository at this point in the history
  • Loading branch information
RalfJung committed Apr 19, 2024
1 parent 5e6184c commit 42220f0
Show file tree
Hide file tree
Showing 13 changed files with 78 additions and 72 deletions.
41 changes: 19 additions & 22 deletions compiler/rustc_codegen_cranelift/src/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ pub(crate) fn codegen_const_value<'tcx>(
if fx.clif_type(layout.ty).is_some() {
return CValue::const_val(fx, layout, int);
} else {
let raw_val = int.size().truncate(int.to_bits(int.size()).unwrap());
let raw_val = int.size().truncate(int.assert_bits(int.size()));
let val = match int.size().bytes() {
1 => fx.bcx.ins().iconst(types::I8, raw_val as i64),
2 => fx.bcx.ins().iconst(types::I16, raw_val as i64),
Expand Down Expand Up @@ -491,27 +491,24 @@ pub(crate) fn mir_operand_get_const_val<'tcx>(
return None;
}
let scalar_int = mir_operand_get_const_val(fx, operand)?;
let scalar_int = match fx
.layout_of(*ty)
.size
.cmp(&scalar_int.size())
{
Ordering::Equal => scalar_int,
Ordering::Less => match ty.kind() {
ty::Uint(_) => ScalarInt::try_from_uint(
scalar_int.try_to_uint(scalar_int.size()).unwrap(),
fx.layout_of(*ty).size,
)
.unwrap(),
ty::Int(_) => ScalarInt::try_from_int(
scalar_int.try_to_int(scalar_int.size()).unwrap(),
fx.layout_of(*ty).size,
)
.unwrap(),
_ => unreachable!(),
},
Ordering::Greater => return None,
};
let scalar_int =
match fx.layout_of(*ty).size.cmp(&scalar_int.size()) {
Ordering::Equal => scalar_int,
Ordering::Less => match ty.kind() {
ty::Uint(_) => ScalarInt::try_from_uint(
scalar_int.assert_uint(scalar_int.size()),
fx.layout_of(*ty).size,
)
.unwrap(),
ty::Int(_) => ScalarInt::try_from_int(
scalar_int.assert_int(scalar_int.size()),
fx.layout_of(*ty).size,
)
.unwrap(),
_ => unreachable!(),
},
Ordering::Greater => return None,
};
computed_scalar_int = Some(scalar_int);
}
Rvalue::Use(operand) => {
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_codegen_cranelift/src/value_and_place.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ impl<'tcx> CValue<'tcx> {

let val = match layout.ty.kind() {
ty::Uint(UintTy::U128) | ty::Int(IntTy::I128) => {
let const_val = const_val.to_bits(layout.size).unwrap();
let const_val = const_val.assert_bits(layout.size);
let lsb = fx.bcx.ins().iconst(types::I64, const_val as u64 as i64);
let msb = fx.bcx.ins().iconst(types::I64, (const_val >> 64) as u64 as i64);
fx.bcx.ins().iconcat(lsb, msb)
Expand All @@ -338,7 +338,7 @@ impl<'tcx> CValue<'tcx> {
| ty::Ref(..)
| ty::RawPtr(..)
| ty::FnPtr(..) => {
let raw_val = const_val.size().truncate(const_val.to_bits(layout.size).unwrap());
let raw_val = const_val.size().truncate(const_val.assert_bits(layout.size));
fx.bcx.ins().iconst(clif_ty, raw_val as i64)
}
ty::Float(FloatTy::F32) => {
Expand Down
3 changes: 1 addition & 2 deletions compiler/rustc_const_eval/src/interpret/discriminant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
&niche_start_val,
)?
.to_scalar()
.try_to_int()
.unwrap();
.assert_int();
Ok(Some((tag, tag_field)))
}
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_const_eval/src/interpret/operand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ impl<'tcx, Prov: Provenance> ImmTy<'tcx, Prov> {
}

/// Return the immediate as a `ScalarInt`. Ensures that it has the size that the layout of the
/// immediate indcates.
/// immediate indicates.
#[inline]
pub fn to_scalar_int(&self) -> InterpResult<'tcx, ScalarInt> {
let s = self.to_scalar().to_scalar_int()?;
Expand Down
8 changes: 4 additions & 4 deletions compiler/rustc_const_eval/src/interpret/operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,10 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
let l = left.to_scalar_int()?;
let r = right.to_scalar_int()?;
// Prepare to convert the values to signed or unsigned form.
let l_signed = || l.try_to_int(left.layout.size).unwrap();
let l_unsigned = || l.try_to_uint(left.layout.size).unwrap();
let r_signed = || r.try_to_int(right.layout.size).unwrap();
let r_unsigned = || r.try_to_uint(right.layout.size).unwrap();
let l_signed = || l.assert_int(left.layout.size);
let l_unsigned = || l.assert_uint(left.layout.size);
let r_signed = || r.assert_int(right.layout.size);
let r_unsigned = || r.assert_uint(right.layout.size);

let throw_ub_on_overflow = match bin_op {
AddUnchecked => Some(sym::unchecked_add),
Expand Down
6 changes: 3 additions & 3 deletions compiler/rustc_middle/src/mir/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl<'tcx> ConstValue<'tcx> {
}

pub fn try_to_bits(&self, size: Size) -> Option<u128> {
self.try_to_scalar_int()?.to_bits(size).ok()
self.try_to_scalar_int()?.try_to_bits(size).ok()
}

pub fn try_to_bool(&self) -> Option<bool> {
Expand Down Expand Up @@ -260,7 +260,7 @@ impl<'tcx> Const<'tcx> {

#[inline]
pub fn try_to_bits(self, size: Size) -> Option<u128> {
self.try_to_scalar_int()?.to_bits(size).ok()
self.try_to_scalar_int()?.try_to_bits(size).ok()
}

#[inline]
Expand Down Expand Up @@ -334,7 +334,7 @@ impl<'tcx> Const<'tcx> {
let int = self.try_eval_scalar_int(tcx, param_env)?;
let size =
tcx.layout_of(param_env.with_reveal_all_normalized(tcx).and(self.ty())).ok()?.size;
int.to_bits(size).ok()
int.try_to_bits(size).ok()
}

/// Panics if the value cannot be evaluated or doesn't contain a valid integer of the given type.
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_middle/src/mir/interpret/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ impl<Prov> Scalar<Prov> {
) -> Result<Either<u128, Pointer<Prov>>, ScalarSizeMismatch> {
assert_ne!(target_size.bytes(), 0, "you should never look at the bits of a ZST");
Ok(match self {
Scalar::Int(int) => Left(int.to_bits(target_size).map_err(|size| {
Scalar::Int(int) => Left(int.try_to_bits(target_size).map_err(|size| {
ScalarSizeMismatch { target_size: target_size.bytes(), data_size: size.bytes() }
})?),
Scalar::Ptr(ptr, sz) => {
Expand Down Expand Up @@ -316,7 +316,7 @@ impl<'tcx, Prov: Provenance> Scalar<Prov> {
#[inline]
pub fn to_bits(self, target_size: Size) -> InterpResult<'tcx, u128> {
assert_ne!(target_size.bytes(), 0, "you should never look at the bits of a ZST");
self.to_scalar_int()?.to_bits(target_size).map_err(|size| {
self.to_scalar_int()?.try_to_bits(target_size).map_err(|size| {
err_ub!(ScalarSizeMismatch(ScalarSizeMismatch {
target_size: target_size.bytes(),
data_size: size.bytes(),
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_middle/src/ty/consts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ impl<'tcx> Const<'tcx> {
let size =
tcx.layout_of(param_env.with_reveal_all_normalized(tcx).and(self.ty())).ok()?.size;
// if `ty` does not depend on generic parameters, use an empty param_env
int.to_bits(size).ok()
int.try_to_bits(size).ok()
}

#[inline]
Expand Down
67 changes: 39 additions & 28 deletions compiler/rustc_middle/src/ty/consts/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,14 +247,7 @@ impl ScalarInt {
}

#[inline]
pub fn assert_bits(self, target_size: Size) -> u128 {
self.to_bits(target_size).unwrap_or_else(|size| {
bug!("expected int of size {}, but got size {}", target_size.bytes(), size.bytes())
})
}

#[inline]
pub fn to_bits(self, target_size: Size) -> Result<u128, Size> {
pub fn try_to_bits(self, target_size: Size) -> Result<u128, Size> {
assert_ne!(target_size.bytes(), 0, "you should never look at the bits of a ZST");
if target_size.bytes() == u64::from(self.size.get()) {
self.check_data();
Expand All @@ -264,48 +257,60 @@ impl ScalarInt {
}
}

#[inline]
pub fn assert_bits(self, target_size: Size) -> u128 {
self.try_to_bits(target_size).unwrap_or_else(|size| {
bug!("expected int of size {}, but got size {}", target_size.bytes(), size.bytes())
})
}

/// Tries to convert the `ScalarInt` to an unsigned integer of the given size.
/// Fails if the size of the `ScalarInt` is not equal to `size` and returns the
/// `ScalarInt`s size in that case.
#[inline]
pub fn try_to_uint(self, size: Size) -> Result<u128, Size> {
self.to_bits(size)
self.try_to_bits(size)
}

#[inline]
pub fn assert_uint(self, size: Size) -> u128 {
self.assert_bits(size)
}

// Tries to convert the `ScalarInt` to `u8`. Fails if the `size` of the `ScalarInt`
// in not equal to `Size { raw: 1 }` and returns the `size` value of the `ScalarInt` in
// in not equal to 1 byte and returns the `size` value of the `ScalarInt` in
// that case.
#[inline]
pub fn try_to_u8(self) -> Result<u8, Size> {
self.try_to_uint(Size::from_bits(8)).map(|v| u8::try_from(v).unwrap())
}

/// Tries to convert the `ScalarInt` to `u16`. Fails if the size of the `ScalarInt`
/// in not equal to `Size { raw: 2 }` and returns the `size` value of the `ScalarInt` in
/// in not equal to 2 bytes and returns the `size` value of the `ScalarInt` in
/// that case.
#[inline]
pub fn try_to_u16(self) -> Result<u16, Size> {
self.try_to_uint(Size::from_bits(16)).map(|v| u16::try_from(v).unwrap())
}

/// Tries to convert the `ScalarInt` to `u32`. Fails if the `size` of the `ScalarInt`
/// in not equal to `Size { raw: 4 }` and returns the `size` value of the `ScalarInt` in
/// in not equal to 4 bytes and returns the `size` value of the `ScalarInt` in
/// that case.
#[inline]
pub fn try_to_u32(self) -> Result<u32, Size> {
self.try_to_uint(Size::from_bits(32)).map(|v| u32::try_from(v).unwrap())
}

/// Tries to convert the `ScalarInt` to `u64`. Fails if the `size` of the `ScalarInt`
/// in not equal to `Size { raw: 8 }` and returns the `size` value of the `ScalarInt` in
/// in not equal to 8 bytes and returns the `size` value of the `ScalarInt` in
/// that case.
#[inline]
pub fn try_to_u64(self) -> Result<u64, Size> {
self.try_to_uint(Size::from_bits(64)).map(|v| u64::try_from(v).unwrap())
}

/// Tries to convert the `ScalarInt` to `u128`. Fails if the `size` of the `ScalarInt`
/// in not equal to `Size { raw: 16 }` and returns the `size` value of the `ScalarInt` in
/// in not equal to 16 bytes and returns the `size` value of the `ScalarInt` in
/// that case.
#[inline]
pub fn try_to_u128(self) -> Result<u128, Size> {
Expand All @@ -318,7 +323,7 @@ impl ScalarInt {
}

// Tries to convert the `ScalarInt` to `bool`. Fails if the `size` of the `ScalarInt`
// in not equal to `Size { raw: 1 }` or if the value is not 0 or 1 and returns the `size`
// in not equal to 1 byte or if the value is not 0 or 1 and returns the `size`
// value of the `ScalarInt` in that case.
#[inline]
pub fn try_to_bool(self) -> Result<bool, Size> {
Expand All @@ -334,40 +339,46 @@ impl ScalarInt {
/// `ScalarInt`s size in that case.
#[inline]
pub fn try_to_int(self, size: Size) -> Result<i128, Size> {
let b = self.to_bits(size)?;
let b = self.try_to_bits(size)?;
Ok(size.sign_extend(b) as i128)
}

#[inline]
pub fn assert_int(self, size: Size) -> i128 {
let b = self.assert_bits(size);
size.sign_extend(b) as i128
}

/// Tries to convert the `ScalarInt` to i8.
/// Fails if the size of the `ScalarInt` is not equal to `Size { raw: 1 }`
/// Fails if the size of the `ScalarInt` is not equal to 1 byte
/// and returns the `ScalarInt`s size in that case.
pub fn try_to_i8(self) -> Result<i8, Size> {
self.try_to_int(Size::from_bits(8)).map(|v| i8::try_from(v).unwrap())
}

/// Tries to convert the `ScalarInt` to i16.
/// Fails if the size of the `ScalarInt` is not equal to `Size { raw: 2 }`
/// Fails if the size of the `ScalarInt` is not equal to 2 bytes
/// and returns the `ScalarInt`s size in that case.
pub fn try_to_i16(self) -> Result<i16, Size> {
self.try_to_int(Size::from_bits(16)).map(|v| i16::try_from(v).unwrap())
}

/// Tries to convert the `ScalarInt` to i32.
/// Fails if the size of the `ScalarInt` is not equal to `Size { raw: 4 }`
/// Fails if the size of the `ScalarInt` is not equal to 4 bytes
/// and returns the `ScalarInt`s size in that case.
pub fn try_to_i32(self) -> Result<i32, Size> {
self.try_to_int(Size::from_bits(32)).map(|v| i32::try_from(v).unwrap())
}

/// Tries to convert the `ScalarInt` to i64.
/// Fails if the size of the `ScalarInt` is not equal to `Size { raw: 8 }`
/// Fails if the size of the `ScalarInt` is not equal to 8 bytes
/// and returns the `ScalarInt`s size in that case.
pub fn try_to_i64(self) -> Result<i64, Size> {
self.try_to_int(Size::from_bits(64)).map(|v| i64::try_from(v).unwrap())
}

/// Tries to convert the `ScalarInt` to i128.
/// Fails if the size of the `ScalarInt` is not equal to `Size { raw: 16 }`
/// Fails if the size of the `ScalarInt` is not equal to 16 bytes
/// and returns the `ScalarInt`s size in that case.
pub fn try_to_i128(self) -> Result<i128, Size> {
self.try_to_int(Size::from_bits(128))
Expand All @@ -381,7 +392,7 @@ impl ScalarInt {
#[inline]
pub fn try_to_float<F: Float>(self) -> Result<F, Size> {
// Going through `to_uint` to check size and truncation.
Ok(F::from_bits(self.to_bits(Size::from_bits(F::BITS))?))
Ok(F::from_bits(self.try_to_bits(Size::from_bits(F::BITS))?))
}

#[inline]
Expand Down Expand Up @@ -430,7 +441,7 @@ macro_rules! try_from {
fn try_from(int: ScalarInt) -> Result<Self, Size> {
// The `unwrap` cannot fail because to_bits (if it succeeds)
// is guaranteed to return a value that fits into the size.
int.to_bits(Size::from_bytes(std::mem::size_of::<$ty>()))
int.try_to_bits(Size::from_bytes(std::mem::size_of::<$ty>()))
.map(|u| u.try_into().unwrap())
}
}
Expand Down Expand Up @@ -465,7 +476,7 @@ impl TryFrom<ScalarInt> for char {

#[inline]
fn try_from(int: ScalarInt) -> Result<Self, Self::Error> {
let Ok(bits) = int.to_bits(Size::from_bytes(std::mem::size_of::<char>())) else {
let Ok(bits) = int.try_to_bits(Size::from_bytes(std::mem::size_of::<char>())) else {
return Err(CharTryFromScalarInt);
};
match char::from_u32(bits.try_into().unwrap()) {
Expand All @@ -487,7 +498,7 @@ impl TryFrom<ScalarInt> for Half {
type Error = Size;
#[inline]
fn try_from(int: ScalarInt) -> Result<Self, Size> {
int.to_bits(Size::from_bytes(2)).map(Self::from_bits)
int.try_to_bits(Size::from_bytes(2)).map(Self::from_bits)
}
}

Expand All @@ -503,7 +514,7 @@ impl TryFrom<ScalarInt> for Single {
type Error = Size;
#[inline]
fn try_from(int: ScalarInt) -> Result<Self, Size> {
int.to_bits(Size::from_bytes(4)).map(Self::from_bits)
int.try_to_bits(Size::from_bytes(4)).map(Self::from_bits)
}
}

Expand All @@ -519,7 +530,7 @@ impl TryFrom<ScalarInt> for Double {
type Error = Size;
#[inline]
fn try_from(int: ScalarInt) -> Result<Self, Size> {
int.to_bits(Size::from_bytes(8)).map(Self::from_bits)
int.try_to_bits(Size::from_bytes(8)).map(Self::from_bits)
}
}

Expand All @@ -535,7 +546,7 @@ impl TryFrom<ScalarInt> for Quad {
type Error = Size;
#[inline]
fn try_from(int: ScalarInt) -> Result<Self, Size> {
int.to_bits(Size::from_bytes(16)).map(Self::from_bits)
int.try_to_bits(Size::from_bytes(16)).map(Self::from_bits)
}
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_mir_transform/src/known_panics_lint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ impl<'tcx> Visitor<'tcx> for ConstPropagator<'_, 'tcx> {
if let Some(ref value) = self.eval_operand(discr)
&& let Some(value_const) = self.use_ecx(|this| this.ecx.read_scalar(value))
&& let Ok(constant) = value_const.try_to_int()
&& let Ok(constant) = constant.to_bits(constant.size())
&& let Ok(constant) = constant.try_to_bits(constant.size())
{
// We managed to evaluate the discriminant, so we know we only need to visit
// one target.
Expand Down
3 changes: 1 addition & 2 deletions compiler/rustc_mir_transform/src/match_branches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
}

fn int_equal(l: ScalarInt, r: impl Into<u128>, size: Size) -> bool {
l.try_to_int(l.size()).unwrap()
== ScalarInt::try_from_uint(r, size).unwrap().try_to_int(size).unwrap()
l.assert_int(l.size()) == ScalarInt::try_from_uint(r, size).unwrap().assert_int(size)
}

// We first compare the two branches, and then the other branches need to fulfill the same conditions.
Expand Down
Loading

0 comments on commit 42220f0

Please sign in to comment.