diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 2139f9776b736..7986d1d9cb283 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -462,7 +462,6 @@ impl BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { load: &'ll Value, scalar: &abi::Scalar, ) { - let vr = scalar.valid_range.clone(); match scalar.value { abi::Int(..) => { let range = scalar.valid_range_exclusive(bx); @@ -470,7 +469,7 @@ impl BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> { bx.range_metadata(load, range); } } - abi::Pointer if vr.start() < vr.end() && !vr.contains(&0) => { + abi::Pointer if !scalar.valid_range.contains_zero() => { bx.nonnull_metadata(load); } _ => {} diff --git a/compiler/rustc_codegen_llvm/src/consts.rs b/compiler/rustc_codegen_llvm/src/consts.rs index e1baf95e1d9e5..ec92bd686d2df 100644 --- a/compiler/rustc_codegen_llvm/src/consts.rs +++ b/compiler/rustc_codegen_llvm/src/consts.rs @@ -16,7 +16,9 @@ use rustc_middle::mir::interpret::{ use rustc_middle::mir::mono::MonoItem; use rustc_middle::ty::{self, Instance, Ty}; use rustc_middle::{bug, span_bug}; -use rustc_target::abi::{AddressSpace, Align, HasDataLayout, LayoutOf, Primitive, Scalar, Size}; +use rustc_target::abi::{ + AddressSpace, Align, HasDataLayout, LayoutOf, Primitive, Scalar, Size, WrappingRange, +}; use tracing::debug; pub fn const_alloc_to_llvm(cx: &CodegenCx<'ll, '_>, alloc: &Allocation) -> &'ll Value { @@ -59,7 +61,7 @@ pub fn const_alloc_to_llvm(cx: &CodegenCx<'ll, '_>, alloc: &Allocation) -> &'ll Pointer::new(alloc_id, Size::from_bytes(ptr_offset)), &cx.tcx, ), - &Scalar { value: Primitive::Pointer, valid_range: 0..=!0 }, + &Scalar { value: Primitive::Pointer, valid_range: WrappingRange { start: 0, end: !0 } }, cx.type_i8p_ext(address_space), )); next_offset = offset + pointer_size; diff --git a/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs b/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs index 81e905b1b5f57..f0b32c96309d6 100644 --- a/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs +++ b/compiler/rustc_codegen_ssa/src/debuginfo/type_names.rs @@ -406,11 +406,11 @@ fn push_debuginfo_type_name<'tcx>( let dataful_discriminant_range = &dataful_variant_layout.largest_niche.as_ref().unwrap().scalar.valid_range; - let min = dataful_discriminant_range.start(); - let min = tag.value.size(&tcx).truncate(*min); + let min = dataful_discriminant_range.start; + let min = tag.value.size(&tcx).truncate(min); - let max = dataful_discriminant_range.end(); - let max = tag.value.size(&tcx).truncate(*max); + let max = dataful_discriminant_range.end; + let max = tag.value.size(&tcx).truncate(max); let dataful_variant_name = def.variants[*dataful_variant].ident.as_str(); diff --git a/compiler/rustc_codegen_ssa/src/mir/rvalue.rs b/compiler/rustc_codegen_ssa/src/mir/rvalue.rs index 7e432d2740224..90a29f24b8e0a 100644 --- a/compiler/rustc_codegen_ssa/src/mir/rvalue.rs +++ b/compiler/rustc_codegen_ssa/src/mir/rvalue.rs @@ -310,15 +310,15 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { let er = scalar.valid_range_exclusive(bx.cx()); if er.end != er.start - && scalar.valid_range.end() >= scalar.valid_range.start() + && scalar.valid_range.end >= scalar.valid_range.start { // We want `table[e as usize ± k]` to not // have bound checks, and this is the most // convenient place to put the `assume`s. - if *scalar.valid_range.start() > 0 { + if scalar.valid_range.start > 0 { let enum_value_lower_bound = bx .cx() - .const_uint_big(ll_t_in, *scalar.valid_range.start()); + .const_uint_big(ll_t_in, scalar.valid_range.start); let cmp_start = bx.icmp( IntPredicate::IntUGE, llval, @@ -328,7 +328,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { } let enum_value_upper_bound = - bx.cx().const_uint_big(ll_t_in, *scalar.valid_range.end()); + bx.cx().const_uint_big(ll_t_in, scalar.valid_range.end); let cmp_end = bx.icmp( IntPredicate::IntULE, llval, diff --git a/compiler/rustc_lint/src/types.rs b/compiler/rustc_lint/src/types.rs index 34d342e66945e..82a23da3ed9ca 100644 --- a/compiler/rustc_lint/src/types.rs +++ b/compiler/rustc_lint/src/types.rs @@ -797,7 +797,7 @@ crate fn repr_nullable_ptr<'tcx>( // Return the nullable type this Option-like enum can be safely represented with. let field_ty_abi = &cx.layout_of(field_ty).unwrap().abi; if let Abi::Scalar(field_ty_scalar) = field_ty_abi { - match (field_ty_scalar.valid_range.start(), field_ty_scalar.valid_range.end()) { + match (field_ty_scalar.valid_range.start, field_ty_scalar.valid_range.end) { (0, _) => unreachable!("Non-null optimisation extended to a non-zero value."), (1, _) => { return Some(get_nullable_type(cx, field_ty).unwrap()); diff --git a/compiler/rustc_middle/src/ty/layout.rs b/compiler/rustc_middle/src/ty/layout.rs index 3caca313ffddd..dcb56d5b2ba1b 100644 --- a/compiler/rustc_middle/src/ty/layout.rs +++ b/compiler/rustc_middle/src/ty/layout.rs @@ -499,7 +499,7 @@ impl<'tcx> LayoutCx<'tcx, TyCtxt<'tcx>> { let scalar_unit = |value: Primitive| { let bits = value.size(dl).bits(); assert!(bits <= 128); - Scalar { value, valid_range: 0..=(!0 >> (128 - bits)) } + Scalar { value, valid_range: WrappingRange { start: 0, end: (!0 >> (128 - bits)) } } }; let scalar = |value: Primitive| tcx.intern_layout(Layout::scalar(self, scalar_unit(value))); @@ -512,11 +512,14 @@ impl<'tcx> LayoutCx<'tcx, TyCtxt<'tcx>> { // Basic scalars. ty::Bool => tcx.intern_layout(Layout::scalar( self, - Scalar { value: Int(I8, false), valid_range: 0..=1 }, + Scalar { value: Int(I8, false), valid_range: WrappingRange { start: 0, end: 1 } }, )), ty::Char => tcx.intern_layout(Layout::scalar( self, - Scalar { value: Int(I32, false), valid_range: 0..=0x10FFFF }, + Scalar { + value: Int(I32, false), + valid_range: WrappingRange { start: 0, end: 0x10FFFF }, + }, )), ty::Int(ity) => scalar(Int(Integer::from_int_ty(dl, ity), true)), ty::Uint(ity) => scalar(Int(Integer::from_uint_ty(dl, ity), false)), @@ -526,7 +529,7 @@ impl<'tcx> LayoutCx<'tcx, TyCtxt<'tcx>> { }), ty::FnPtr(_) => { let mut ptr = scalar_unit(Pointer); - ptr.valid_range = 1..=*ptr.valid_range.end(); + ptr.valid_range = ptr.valid_range.with_start(1); tcx.intern_layout(Layout::scalar(self, ptr)) } @@ -544,7 +547,7 @@ impl<'tcx> LayoutCx<'tcx, TyCtxt<'tcx>> { ty::Ref(_, pointee, _) | ty::RawPtr(ty::TypeAndMut { ty: pointee, .. }) => { let mut data_ptr = scalar_unit(Pointer); if !ty.is_unsafe_ptr() { - data_ptr.valid_range = 1..=*data_ptr.valid_range.end(); + data_ptr.valid_range = data_ptr.valid_range.with_start(1); } let pointee = tcx.normalize_erasing_regions(param_env, pointee); @@ -560,7 +563,7 @@ impl<'tcx> LayoutCx<'tcx, TyCtxt<'tcx>> { ty::Slice(_) | ty::Str => scalar_unit(Int(dl.ptr_sized_integer(), false)), ty::Dynamic(..) => { let mut vtable = scalar_unit(Pointer); - vtable.valid_range = 1..=*vtable.valid_range.end(); + vtable.valid_range = vtable.valid_range.with_start(1); vtable } _ => return Err(LayoutError::Unknown(unsized_part)), @@ -933,14 +936,14 @@ impl<'tcx> LayoutCx<'tcx, TyCtxt<'tcx>> { if let Bound::Included(start) = start { // FIXME(eddyb) this might be incorrect - it doesn't // account for wrap-around (end < start) ranges. - assert!(*scalar.valid_range.start() <= start); - scalar.valid_range = start..=*scalar.valid_range.end(); + assert!(scalar.valid_range.start <= start); + scalar.valid_range.start = start; } if let Bound::Included(end) = end { // FIXME(eddyb) this might be incorrect - it doesn't // account for wrap-around (end < start) ranges. - assert!(*scalar.valid_range.end() >= end); - scalar.valid_range = *scalar.valid_range.start()..=end; + assert!(scalar.valid_range.end >= end); + scalar.valid_range.end = end; } // Update `largest_niche` if we have introduced a larger niche. @@ -1256,7 +1259,10 @@ impl<'tcx> LayoutCx<'tcx, TyCtxt<'tcx>> { let tag_mask = !0u128 >> (128 - ity.size().bits()); let tag = Scalar { value: Int(ity, signed), - valid_range: (min as u128 & tag_mask)..=(max as u128 & tag_mask), + valid_range: WrappingRange { + start: (min as u128 & tag_mask), + end: (max as u128 & tag_mask), + }, }; let mut abi = Abi::Aggregate { sized: true }; if tag.value.size(dl) == size { @@ -1535,7 +1541,10 @@ impl<'tcx> LayoutCx<'tcx, TyCtxt<'tcx>> { let max_discr = (info.variant_fields.len() - 1) as u128; let discr_int = Integer::fit_unsigned(max_discr); let discr_int_ty = discr_int.to_ty(tcx, false); - let tag = Scalar { value: Primitive::Int(discr_int, false), valid_range: 0..=max_discr }; + let tag = Scalar { + value: Primitive::Int(discr_int, false), + valid_range: WrappingRange { start: 0, end: max_discr }, + }; let tag_layout = self.tcx.intern_layout(Layout::scalar(self, tag.clone())); let tag_layout = TyAndLayout { ty: discr_int_ty, layout: tag_layout }; @@ -2846,10 +2855,8 @@ where return; } - if scalar.valid_range.start() < scalar.valid_range.end() { - if *scalar.valid_range.start() > 0 { - attrs.set(ArgAttribute::NonNull); - } + if !scalar.valid_range.contains_zero() { + attrs.set(ArgAttribute::NonNull); } if let Some(pointee) = layout.pointee_info_at(cx, offset) { diff --git a/compiler/rustc_mir/src/interpret/validity.rs b/compiler/rustc_mir/src/interpret/validity.rs index 0c7f89c1a36ba..3ff149d6a7a25 100644 --- a/compiler/rustc_mir/src/interpret/validity.rs +++ b/compiler/rustc_mir/src/interpret/validity.rs @@ -7,7 +7,6 @@ use std::convert::TryFrom; use std::fmt::Write; use std::num::NonZeroUsize; -use std::ops::RangeInclusive; use rustc_data_structures::fx::FxHashSet; use rustc_hir as hir; @@ -15,7 +14,9 @@ use rustc_middle::mir::interpret::InterpError; use rustc_middle::ty; use rustc_middle::ty::layout::TyAndLayout; use rustc_span::symbol::{sym, Symbol}; -use rustc_target::abi::{Abi, LayoutOf, Scalar as ScalarAbi, Size, VariantIdx, Variants}; +use rustc_target::abi::{ + Abi, LayoutOf, Scalar as ScalarAbi, Size, VariantIdx, Variants, WrappingRange, +}; use std::hash::Hash; @@ -181,22 +182,10 @@ fn write_path(out: &mut String, path: &[PathElem]) { } } -// Test if a range that wraps at overflow contains `test` -fn wrapping_range_contains(r: &RangeInclusive, test: u128) -> bool { - let (lo, hi) = r.clone().into_inner(); - if lo > hi { - // Wrapped - (..=hi).contains(&test) || (lo..).contains(&test) - } else { - // Normal - r.contains(&test) - } -} - // Formats such that a sentence like "expected something {}" to mean // "expected something " makes sense. -fn wrapping_range_format(r: &RangeInclusive, max_hi: u128) -> String { - let (lo, hi) = r.clone().into_inner(); +fn wrapping_range_format(r: WrappingRange, max_hi: u128) -> String { + let WrappingRange { start: lo, end: hi } = r; assert!(hi <= max_hi); if lo > hi { format!("less or equal to {}, or greater or equal to {}", hi, lo) @@ -634,8 +623,8 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, ' scalar_layout: &ScalarAbi, ) -> InterpResult<'tcx> { let value = self.read_scalar(op)?; - let valid_range = &scalar_layout.valid_range; - let (lo, hi) = valid_range.clone().into_inner(); + let valid_range = scalar_layout.valid_range.clone(); + let WrappingRange { start: lo, end: hi } = valid_range; // Determine the allowed range // `max_hi` is as big as the size fits let max_hi = u128::MAX >> (128 - op.layout.size.bits()); @@ -684,7 +673,7 @@ impl<'rt, 'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> ValidityVisitor<'rt, 'mir, ' Ok(int) => int.assert_bits(op.layout.size), }; // Now compare. This is slightly subtle because this is a special "wrap-around" range. - if wrapping_range_contains(&valid_range, bits) { + if valid_range.contains(bits) { Ok(()) } else { throw_validation_failure!(self.path, diff --git a/compiler/rustc_target/src/abi/mod.rs b/compiler/rustc_target/src/abi/mod.rs index 8ef6e142caecf..d206df461200a 100644 --- a/compiler/rustc_target/src/abi/mod.rs +++ b/compiler/rustc_target/src/abi/mod.rs @@ -677,32 +677,80 @@ impl Primitive { } } +/// Inclusive wrap-around range of valid values, that is, if +/// start > end, it represents `start..=MAX`, +/// followed by `0..=end`. +/// +/// That is, for an i8 primitive, a range of `254..=2` means following +/// sequence: +/// +/// 254 (-2), 255 (-1), 0, 1, 2 +/// +/// This is intended specifically to mirror LLVM’s `!range` metadata, +/// semantics. +#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(HashStable_Generic)] +pub struct WrappingRange { + pub start: u128, + pub end: u128, +} + +impl WrappingRange { + /// Returns `true` if `v` is contained in the range. + #[inline(always)] + pub fn contains(&self, v: u128) -> bool { + if self.start <= self.end { + self.start <= v && v <= self.end + } else { + self.start <= v || v <= self.end + } + } + + /// Returns `true` if zero is contained in the range. + /// Equal to `range.contains(0)` but should be faster. + #[inline(always)] + pub fn contains_zero(&self) -> bool { + self.start > self.end || self.start == 0 + } + + /// Returns `self` with replaced `start` + #[inline(always)] + pub fn with_start(mut self, start: u128) -> Self { + self.start = start; + self + } + + /// Returns `self` with replaced `end` + #[inline(always)] + pub fn with_end(mut self, end: u128) -> Self { + self.end = end; + self + } +} + +impl fmt::Debug for WrappingRange { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "{}..={}", self.start, self.end)?; + Ok(()) + } +} + /// Information about one scalar component of a Rust type. #[derive(Clone, PartialEq, Eq, Hash, Debug)] #[derive(HashStable_Generic)] pub struct Scalar { pub value: Primitive, - /// Inclusive wrap-around range of valid values, that is, if - /// start > end, it represents `start..=MAX`, - /// followed by `0..=end`. - /// - /// That is, for an i8 primitive, a range of `254..=2` means following - /// sequence: - /// - /// 254 (-2), 255 (-1), 0, 1, 2 - /// - /// This is intended specifically to mirror LLVM’s `!range` metadata, - /// semantics. // FIXME(eddyb) always use the shortest range, e.g., by finding // the largest space between two consecutive valid values and // taking everything else as the (shortest) valid range. - pub valid_range: RangeInclusive, + pub valid_range: WrappingRange, } impl Scalar { pub fn is_bool(&self) -> bool { - matches!(self.value, Int(I8, false)) && self.valid_range == (0..=1) + matches!(self.value, Int(I8, false)) + && matches!(self.valid_range, WrappingRange { start: 0, end: 1 }) } /// Returns the valid range as a `x..y` range. @@ -715,8 +763,8 @@ impl Scalar { let bits = self.value.size(cx).bits(); assert!(bits <= 128); let mask = !0u128 >> (128 - bits); - let start = *self.valid_range.start(); - let end = *self.valid_range.end(); + let start = self.valid_range.start; + let end = self.valid_range.end; assert_eq!(start, start & mask); assert_eq!(end, end & mask); start..(end.wrapping_add(1) & mask) @@ -971,14 +1019,14 @@ impl Niche { let max_value = !0u128 >> (128 - bits); // Find out how many values are outside the valid range. - let niche = v.end().wrapping_add(1)..*v.start(); + let niche = v.end.wrapping_add(1)..v.start; niche.end.wrapping_sub(niche.start) & max_value } pub fn reserve(&self, cx: &C, count: u128) -> Option<(u128, Scalar)> { assert!(count > 0); - let Scalar { value, valid_range: ref v } = self.scalar; + let Scalar { value, valid_range: v } = self.scalar.clone(); let bits = value.size(cx).bits(); assert!(bits <= 128); let max_value = !0u128 >> (128 - bits); @@ -988,24 +1036,14 @@ impl Niche { } // Compute the range of invalid values being reserved. - let start = v.end().wrapping_add(1) & max_value; - let end = v.end().wrapping_add(count) & max_value; - - // If the `end` of our range is inside the valid range, - // then we ran out of invalid values. - // FIXME(eddyb) abstract this with a wraparound range type. - let valid_range_contains = |x| { - if v.start() <= v.end() { - *v.start() <= x && x <= *v.end() - } else { - *v.start() <= x || x <= *v.end() - } - }; - if valid_range_contains(end) { + let start = v.end.wrapping_add(1) & max_value; + let end = v.end.wrapping_add(count) & max_value; + + if v.contains(end) { return None; } - Some((start, Scalar { value, valid_range: *v.start()..=end })) + Some((start, Scalar { value, valid_range: v.with_end(end) })) } } @@ -1212,9 +1250,8 @@ impl<'a, Ty> TyAndLayout<'a, Ty> { { let scalar_allows_raw_init = move |s: &Scalar| -> bool { if zero { - let range = &s.valid_range; // The range must contain 0. - range.contains(&0) || (*range.start() > *range.end()) // wrap-around allows 0 + s.valid_range.contains_zero() } else { // The range must include all values. `valid_range_exclusive` handles // the wrap-around using target arithmetic; with wrap-around then the full