Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Match usize/isize exhaustively with half-open ranges #116692

Merged
merged 9 commits into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 239 additions & 10 deletions compiler/rustc_middle/src/thir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,19 @@ use rustc_hir::RangeEnd;
use rustc_index::newtype_index;
use rustc_index::IndexVec;
use rustc_middle::middle::region;
use rustc_middle::mir::interpret::AllocId;
use rustc_middle::mir::interpret::{AllocId, Scalar};
use rustc_middle::mir::{self, BinOp, BorrowKind, FakeReadCause, Mutability, UnOp};
use rustc_middle::ty::adjustment::PointerCoercion;
use rustc_middle::ty::layout::IntegerExt;
use rustc_middle::ty::{
self, AdtDef, CanonicalUserType, CanonicalUserTypeAnnotation, FnSig, GenericArgsRef, List, Ty,
UpvarArgs,
TyCtxt, UpvarArgs,
};
use rustc_span::def_id::LocalDefId;
use rustc_span::{sym, ErrorGuaranteed, Span, Symbol, DUMMY_SP};
use rustc_target::abi::{FieldIdx, VariantIdx};
use rustc_target::abi::{FieldIdx, Integer, Size, VariantIdx};
use rustc_target::asm::InlineAsmRegOrRegClass;
use std::cmp::Ordering;
use std::fmt;
use std::ops::Index;

Expand Down Expand Up @@ -810,12 +812,243 @@ pub enum PatKind<'tcx> {
Error(ErrorGuaranteed),
}

/// A range pattern.
/// The boundaries must be of the same type and that type must be numeric.
#[derive(Clone, Debug, PartialEq, HashStable, TypeVisitable)]
pub struct PatRange<'tcx> {
// NOTE(review): `lo` and `hi` are each listed twice below, once as `mir::Const` and once as
// `PatRangeBoundary`. This looks like the removed and the added lines of a diff shown together;
// the methods on `PatRange` match these fields against `PatRangeBoundary` variants, so the
// second pair is presumably the live one — confirm before relying on either.
pub lo: mir::Const<'tcx>,
pub hi: mir::Const<'tcx>,
/// Lower boundary of the range.
pub lo: PatRangeBoundary<'tcx>,
/// Upper boundary of the range; whether the boundary value itself matches is decided by `end`.
pub hi: PatRangeBoundary<'tcx>,
/// Whether a finite `hi` is itself part of the range (`RangeEnd::Included`) or not.
#[type_visitable(ignore)]
pub end: RangeEnd,
/// Type of the values this range is matched against.
pub ty: Ty<'tcx>,
}

impl<'tcx> PatRange<'tcx> {
    /// Reports whether this range spans every possible value of its type.
    ///
    /// This is best-effort: only `char` and integer types are analysed. Any other type
    /// (floats in particular) yields `None`.
    #[inline]
    pub fn is_full_range(&self, tcx: TyCtxt<'tcx>) -> Option<bool> {
        let (min, max, size, bias) = match *self.ty.kind() {
            ty::Char => (0, std::char::MAX as u128, Size::from_bits(32), 0),
            ty::Int(ity) => {
                let size = Integer::from_int_ty(&tcx, ity).size();
                (0, size.truncate(u128::MAX), size, 1u128 << (size.bits() - 1))
            }
            ty::Uint(uty) => {
                let size = Integer::from_uint_ty(&tcx, uty).size();
                (0, size.unsigned_int_max(), size, 0)
            }
            _ => return None,
        };

        // The bit patterns of signed integers do not sort in numeric order. XORing with the
        // sign-bit `bias` shifts them so that plain unsigned comparison agrees with the numeric
        // order (see pattern/deconstruct_pat.rs for another use of the same trick). For the
        // unsigned cases the bias is zero and the XOR is a no-op.
        let biased_bits = |value: mir::Const<'tcx>| value.try_to_bits(size).unwrap() ^ bias;

        let starts_at_min = match self.lo {
            PatRangeBoundary::Finite(value) => biased_bits(value) <= min,
            PatRangeBoundary::NegInfinity => true,
            PatRangeBoundary::PosInfinity => false,
        };
        // For performance, the upper bound is only evaluated once the lower bound qualifies.
        if !starts_at_min {
            return Some(false);
        }

        let ends_at_max = match self.hi {
            PatRangeBoundary::Finite(value) => {
                let hi = biased_bits(value);
                hi > max || (hi == max && self.end == RangeEnd::Included)
            }
            PatRangeBoundary::PosInfinity => true,
            PatRangeBoundary::NegInfinity => false,
        };
        Some(ends_at_max)
    }

    /// Reports whether `value` lies inside this range.
    ///
    /// `value` must have the same type as the range (checked in debug builds). Returns `None`
    /// when one of the underlying boundary comparisons has no answer.
    #[inline]
    pub fn contains(
        &self,
        value: mir::Const<'tcx>,
        tcx: TyCtxt<'tcx>,
        param_env: ty::ParamEnv<'tcx>,
    ) -> Option<bool> {
        debug_assert_eq!(self.ty, value.ty());
        let ty = self.ty;
        let value = PatRangeBoundary::Finite(value);

        // For performance, the comparison against `hi` only happens if `lo <= value` holds.
        if self.lo.compare_with(value, ty, tcx, param_env)? == Ordering::Greater {
            return Some(false);
        }
        let below_hi = match value.compare_with(self.hi, ty, tcx, param_env)? {
            Ordering::Less => true,
            Ordering::Equal => self.end == RangeEnd::Included,
            Ordering::Greater => false,
        };
        Some(below_hi)
    }

    /// Reports whether this range and `other` share at least one value.
    ///
    /// Both ranges must be over the same type (checked in debug builds). Returns `None` when
    /// one of the underlying boundary comparisons has no answer.
    #[inline]
    pub fn overlaps(
        &self,
        other: &Self,
        tcx: TyCtxt<'tcx>,
        param_env: ty::ParamEnv<'tcx>,
    ) -> Option<bool> {
        debug_assert_eq!(self.ty, other.ty);
        let ty = self.ty;

        // `other` must begin before `self` ends...
        let other_begins_in_time = match other.lo.compare_with(self.hi, ty, tcx, param_env)? {
            Ordering::Less => true,
            Ordering::Equal => self.end == RangeEnd::Included,
            Ordering::Greater => false,
        };
        // For performance, the second comparison only happens if the first one succeeded.
        if !other_begins_in_time {
            return Some(false);
        }

        // ...and `self` must begin before `other` ends.
        let self_begins_in_time = match self.lo.compare_with(other.hi, ty, tcx, param_env)? {
            Ordering::Less => true,
            Ordering::Equal => other.end == RangeEnd::Included,
            Ordering::Greater => false,
        };
        Some(self_begins_in_time)
    }
}

impl<'tcx> fmt::Display for PatRange<'tcx> {
    /// Writes the range in pattern syntax: an optional lower bound, the range operator, and an
    /// optional upper bound. Infinite boundaries are simply left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let PatRangeBoundary::Finite(lo) = &self.lo {
            write!(f, "{lo}")?;
        }
        match &self.hi {
            PatRangeBoundary::Finite(hi) => write!(f, "{}{hi}", self.end),
            // `0..` is parsed as an inclusive range, we must display it correctly: without an
            // upper bound the operator is always spelled `..`, whatever `end` says.
            PatRangeBoundary::NegInfinity | PatRangeBoundary::PosInfinity => write!(f, ".."),
        }
    }
}

/// A (possibly open) boundary of a range pattern.
/// If present, the const must be of a numeric type.
#[derive(Copy, Clone, Debug, PartialEq, HashStable, TypeVisitable)]
pub enum PatRangeBoundary<'tcx> {
/// A concrete boundary value.
Finite(mir::Const<'tcx>),
/// Open boundary on the low side; `to_const`/`eval_bits` resolve it to the minimum value of
/// the range's type.
NegInfinity,
/// Open boundary on the high side; `to_const`/`eval_bits` resolve it to the maximum value of
/// the range's type.
PosInfinity,
}

impl<'tcx> PatRangeBoundary<'tcx> {
    /// Returns `true` if this boundary is a concrete value rather than an infinity.
    #[inline]
    pub fn is_finite(self) -> bool {
        matches!(self, Self::Finite(..))
    }

    /// Returns the boundary's constant, or `None` for either infinity.
    #[inline]
    pub fn as_finite(self) -> Option<mir::Const<'tcx>> {
        match self {
            Self::Finite(value) => Some(value),
            Self::NegInfinity | Self::PosInfinity => None,
        }
    }

    /// Turns the boundary into a constant of type `ty`, replacing an infinity with the minimum
    /// or maximum value of `ty`.
    ///
    /// Panics if `ty` is not numeric and the boundary is an infinity.
    #[inline]
    pub fn to_const(self, ty: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> mir::Const<'tcx> {
        match self {
            Self::Finite(value) => value,
            Self::NegInfinity => {
                // Unwrap is ok because the type is known to be numeric.
                let c = ty.numeric_min_val(tcx).unwrap();
                mir::Const::from_ty_const(c, tcx)
            }
            Self::PosInfinity => {
                // Unwrap is ok because the type is known to be numeric.
                let c = ty.numeric_max_val(tcx).unwrap();
                mir::Const::from_ty_const(c, tcx)
            }
        }
    }

    /// Returns the raw bits of the boundary, replacing an infinity with the bits of the minimum
    /// or maximum value of `ty`.
    ///
    /// Panics if `ty` is not numeric and the boundary is an infinity.
    pub fn eval_bits(self, ty: Ty<'tcx>, tcx: TyCtxt<'tcx>, param_env: ty::ParamEnv<'tcx>) -> u128 {
        match self {
            Self::Finite(value) => value.eval_bits(tcx, param_env),
            Self::NegInfinity => {
                // Unwrap is ok because the type is known to be numeric.
                ty.numeric_min_and_max_as_bits(tcx).unwrap().0
            }
            Self::PosInfinity => {
                // Unwrap is ok because the type is known to be numeric.
                ty.numeric_min_and_max_as_bits(tcx).unwrap().1
            }
        }
    }

    /// Orders `self` relative to `other`, both interpreted as boundaries over values of type
    /// `ty`.
    ///
    /// Integers are compared numerically (signed types are sign-extended first), `char`s and
    /// unsigned integers by their bits, and floats through `partial_cmp` on the decoded
    /// value — so for floats the result may be `None`. Any other `ty` is a compiler bug.
    #[instrument(skip(tcx, param_env), level = "debug", ret)]
    pub fn compare_with(
        self,
        other: Self,
        ty: Ty<'tcx>,
        tcx: TyCtxt<'tcx>,
        param_env: ty::ParamEnv<'tcx>,
    ) -> Option<Ordering> {
        use PatRangeBoundary::*;
        match (self, other) {
            // When comparing with infinities, we must remember that `0u8..` and `0u8..=255`
            // describe the same range. These two shortcuts are ok, but for the rest we must check
            // bit values.
            (PosInfinity, PosInfinity) => return Some(Ordering::Equal),
            (NegInfinity, NegInfinity) => return Some(Ordering::Equal),

            // This code is hot when compiling matches with many ranges. So we
            // special-case extraction of evaluated scalars for speed, for types where
            // raw data comparisons are appropriate. E.g. `unicode-normalization` has
            // many ranges such as '\u{037A}'..='\u{037F}', and chars can be compared
            // in this way.
            (Finite(mir::Const::Ty(a)), Finite(mir::Const::Ty(b)))
                if matches!(ty.kind(), ty::Uint(_) | ty::Char) =>
            {
                return Some(a.kind().cmp(&b.kind()));
            }
            (
                Finite(mir::Const::Val(mir::ConstValue::Scalar(Scalar::Int(a)), _)),
                Finite(mir::Const::Val(mir::ConstValue::Scalar(Scalar::Int(b)), _)),
            ) if matches!(ty.kind(), ty::Uint(_) | ty::Char) => return Some(a.cmp(&b)),
            _ => {}
        }

        // Slow path: resolve both sides (infinities included) to raw bits, then compare them
        // according to the type.
        let a = self.eval_bits(ty, tcx, param_env);
        let b = other.eval_bits(ty, tcx, param_env);

        match ty.kind() {
            ty::Float(ty::FloatTy::F32) => {
                use rustc_apfloat::Float;
                let a = rustc_apfloat::ieee::Single::from_bits(a);
                let b = rustc_apfloat::ieee::Single::from_bits(b);
                a.partial_cmp(&b)
            }
            ty::Float(ty::FloatTy::F64) => {
                use rustc_apfloat::Float;
                let a = rustc_apfloat::ieee::Double::from_bits(a);
                let b = rustc_apfloat::ieee::Double::from_bits(b);
                a.partial_cmp(&b)
            }
            ty::Int(ity) => {
                use rustc_middle::ty::layout::IntegerExt;
                let size = rustc_target::abi::Integer::from_int_ty(&tcx, *ity).size();
                // The raw bits of a signed integer do not sort numerically; sign-extend to
                // `i128` so that negative values compare below positive ones.
                let a = size.sign_extend(a) as i128;
                let b = size.sign_extend(b) as i128;
                Some(a.cmp(&b))
            }
            ty::Uint(_) | ty::Char => Some(a.cmp(&b)),
            _ => bug!("`PatRangeBoundary::compare_with` called with non-numeric type `{:?}`", ty),
        }
    }
}

impl<'tcx> fmt::Display for Pat<'tcx> {
Expand Down Expand Up @@ -944,11 +1177,7 @@ impl<'tcx> fmt::Display for Pat<'tcx> {
PatKind::InlineConstant { def: _, ref subpattern } => {
write!(f, "{} (from inline const)", subpattern)
}
PatKind::Range(box PatRange { lo, hi, end }) => {
write!(f, "{lo}")?;
write!(f, "{end}")?;
write!(f, "{hi}")
}
PatKind::Range(ref range) => write!(f, "{range}"),
PatKind::Slice { ref prefix, ref slice, ref suffix }
| PatKind::Array { ref prefix, ref slice, ref suffix } => {
write!(f, "[")?;
Expand Down
78 changes: 43 additions & 35 deletions compiler/rustc_middle/src/ty/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use rustc_index::bit_set::GrowableBitSet;
use rustc_macros::HashStable;
use rustc_session::Limit;
use rustc_span::sym;
use rustc_target::abi::{Integer, IntegerType, Size};
use rustc_target::abi::{Integer, IntegerType, Primitive, Size};
use rustc_target::spec::abi::Abi;
use smallvec::SmallVec;
use std::{fmt, iter};
Expand Down Expand Up @@ -917,54 +917,62 @@ impl<'tcx> TypeFolder<TyCtxt<'tcx>> for OpaqueTypeExpander<'tcx> {
}

impl<'tcx> Ty<'tcx> {
/// Returns the `Size` of a primitive type (bool, uint, int, char, float).
///
/// Calling this on any other kind of type is a compiler bug.
pub fn primitive_size(self, tcx: TyCtxt<'tcx>) -> Size {
    let cx = &tcx;
    match *self.kind() {
        ty::Int(int_ty) => Integer::from_int_ty(cx, int_ty).size(),
        ty::Uint(uint_ty) => Integer::from_uint_ty(cx, uint_ty).size(),
        ty::Float(ty::FloatTy::F32) => Primitive::F32.size(cx),
        ty::Float(ty::FloatTy::F64) => Primitive::F64.size(cx),
        ty::Char => Size::from_bytes(4),
        ty::Bool => Size::from_bytes(1),
        _ => bug!("non primitive type"),
    }
}

pub fn int_size_and_signed(self, tcx: TyCtxt<'tcx>) -> (Size, bool) {
let (int, signed) = match *self.kind() {
ty::Int(ity) => (Integer::from_int_ty(&tcx, ity), true),
ty::Uint(uty) => (Integer::from_uint_ty(&tcx, uty), false),
match *self.kind() {
ty::Int(ity) => (Integer::from_int_ty(&tcx, ity).size(), true),
ty::Uint(uty) => (Integer::from_uint_ty(&tcx, uty).size(), false),
_ => bug!("non integer discriminant"),
};
(int.size(), signed)
}
}

/// Returns the maximum value for the given numeric type (including `char`s)
/// or returns `None` if the type is not numeric.
pub fn numeric_max_val(self, tcx: TyCtxt<'tcx>) -> Option<ty::Const<'tcx>> {
let val = match self.kind() {
/// Returns the minimum and maximum values for the given numeric type (including `char`s) or
/// returns `None` if the type is not numeric.
pub fn numeric_min_and_max_as_bits(self, tcx: TyCtxt<'tcx>) -> Option<(u128, u128)> {
use rustc_apfloat::ieee::{Double, Single};
Some(match self.kind() {
ty::Int(_) | ty::Uint(_) => {
let (size, signed) = self.int_size_and_signed(tcx);
let val =
let min = if signed { size.truncate(size.signed_int_min() as u128) } else { 0 };
let max =
if signed { size.signed_int_max() as u128 } else { size.unsigned_int_max() };
Some(val)
(min, max)
}
ty::Char => Some(std::char::MAX as u128),
ty::Float(fty) => Some(match fty {
ty::FloatTy::F32 => rustc_apfloat::ieee::Single::INFINITY.to_bits(),
ty::FloatTy::F64 => rustc_apfloat::ieee::Double::INFINITY.to_bits(),
}),
_ => None,
};
ty::Char => (0, std::char::MAX as u128),
ty::Float(ty::FloatTy::F32) => {
((-Single::INFINITY).to_bits(), Single::INFINITY.to_bits())
}
ty::Float(ty::FloatTy::F64) => {
((-Double::INFINITY).to_bits(), Double::INFINITY.to_bits())
}
_ => return None,
})
}

val.map(|v| ty::Const::from_bits(tcx, v, ty::ParamEnv::empty().and(self)))
/// Returns the maximum value for the given numeric type (including `char`s)
/// or returns `None` if the type is not numeric.
pub fn numeric_max_val(self, tcx: TyCtxt<'tcx>) -> Option<ty::Const<'tcx>> {
self.numeric_min_and_max_as_bits(tcx)
.map(|(_, max)| ty::Const::from_bits(tcx, max, ty::ParamEnv::empty().and(self)))
}

/// Returns the minimum value for the given numeric type (including `char`s)
/// or returns `None` if the type is not numeric.
pub fn numeric_min_val(self, tcx: TyCtxt<'tcx>) -> Option<ty::Const<'tcx>> {
let val = match self.kind() {
ty::Int(_) | ty::Uint(_) => {
let (size, signed) = self.int_size_and_signed(tcx);
let val = if signed { size.truncate(size.signed_int_min() as u128) } else { 0 };
Some(val)
}
ty::Char => Some(0),
ty::Float(fty) => Some(match fty {
ty::FloatTy::F32 => (-::rustc_apfloat::ieee::Single::INFINITY).to_bits(),
ty::FloatTy::F64 => (-::rustc_apfloat::ieee::Double::INFINITY).to_bits(),
}),
_ => None,
};

val.map(|v| ty::Const::from_bits(tcx, v, ty::ParamEnv::empty().and(self)))
self.numeric_min_and_max_as_bits(tcx)
.map(|(min, _)| ty::Const::from_bits(tcx, min, ty::ParamEnv::empty().and(self)))
}

/// Checks whether values of this type `T` are *moved* or *copied*
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_mir_build/src/build/matches/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1027,7 +1027,7 @@ enum TestKind<'tcx> {
ty: Ty<'tcx>,
},

/// Test whether the value falls within an inclusive or exclusive range
/// Test whether the value falls within an inclusive or exclusive range.
Range(Box<PatRange<'tcx>>),

/// Test that the length of the slice is equal to `len`.
Expand Down
Loading
Loading