diff --git a/compiler/rustc_middle/src/thir.rs b/compiler/rustc_middle/src/thir.rs index bbcd509c55869..6783bbf8bf42f 100644 --- a/compiler/rustc_middle/src/thir.rs +++ b/compiler/rustc_middle/src/thir.rs @@ -800,9 +800,9 @@ pub enum PatKind<'tcx> { }, /// One of the following: - /// * `&str`/`&[u8]` (represented as a valtree), which will be handled as a string/slice pattern - /// and thus exhaustiveness checking will detect if you use the same string/slice twice in - /// different patterns. + /// * `&str` (represented as a valtree), which will be handled as a string pattern and thus + /// exhaustiveness checking will detect if you use the same string twice in different + /// patterns. /// * integer, bool, char or float (represented as a valtree), which will be handled by /// exhaustiveness to cover exactly its own value, similar to `&str`, but these values are /// much simpler. diff --git a/compiler/rustc_mir_build/src/builder/matches/mod.rs b/compiler/rustc_mir_build/src/builder/matches/mod.rs index ea341b604e0be..710538ef4b867 100644 --- a/compiler/rustc_mir_build/src/builder/matches/mod.rs +++ b/compiler/rustc_mir_build/src/builder/matches/mod.rs @@ -1326,8 +1326,8 @@ enum TestKind<'tcx> { Eq { value: Const<'tcx>, // Integer types are handled by `SwitchInt`, and constants with ADT - // types are converted back into patterns, so this can only be `&str`, - // `&[T]`, `f32` or `f64`. + // types and `&[T]` types are converted back into patterns, so this can + // only be `&str`, `f32` or `f64`. ty: Ty<'tcx>, }, diff --git a/compiler/rustc_mir_build/src/builder/matches/test.rs b/compiler/rustc_mir_build/src/builder/matches/test.rs index e5d61bc9e556a..7ef9e48326f63 100644 --- a/compiler/rustc_mir_build/src/builder/matches/test.rs +++ b/compiler/rustc_mir_build/src/builder/matches/test.rs @@ -11,7 +11,6 @@ use std::sync::Arc; use rustc_data_structures::fx::FxIndexMap; use rustc_hir::{LangItem, RangeEnd}; use rustc_middle::mir::*; -use rustc_middle::ty::adjustment::PointerCoercion; use rustc_middle::ty::util::IntTypeExt; use rustc_middle::ty::{self, GenericArg, Ty, TyCtxt}; use rustc_middle::{bug, span_bug}; @@ -178,21 +177,30 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { _ => {} } + assert_eq!(expect_ty, ty); if !ty.is_scalar() { // Use `PartialEq::eq` instead of `BinOp::Eq` // (the binop can only handle primitives) - self.non_scalar_compare( + // Make sure that we do *not* call any user-defined code here. + // The only type that can end up here is string literals, which have their + // comparison defined in `core`. + // (Interestingly this means that exhaustiveness analysis relies, for soundness, + // on the `PartialEq` impl for `str` to b correct!) + match *ty.kind() { + ty::Ref(_, deref_ty, _) if deref_ty == self.tcx.types.str_ => {} + _ => { + span_bug!(source_info.span, "invalid type for non-scalar compare: {ty}") + } + }; + self.string_compare( block, success_block, fail_block, source_info, expect, - expect_ty, Operand::Copy(place), - ty, ); } else { - assert_eq!(expect_ty, ty); self.compare( block, success_block, @@ -370,97 +378,19 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { ); } - /// Compare two values using `::eq`. - /// If the values are already references, just call it directly, otherwise - /// take a reference to the values first and then call it. - fn non_scalar_compare( + /// Compare two values of type `&str` using `::eq`. + fn string_compare( &mut self, block: BasicBlock, success_block: BasicBlock, fail_block: BasicBlock, source_info: SourceInfo, - mut expect: Operand<'tcx>, - expect_ty: Ty<'tcx>, - mut val: Operand<'tcx>, - mut ty: Ty<'tcx>, + expect: Operand<'tcx>, + val: Operand<'tcx>, ) { - // If we're using `b"..."` as a pattern, we need to insert an - // unsizing coercion, as the byte string has the type `&[u8; N]`. - // - // We want to do this even when the scrutinee is a reference to an - // array, so we can call `<[u8]>::eq` rather than having to find an - // `<[u8; N]>::eq`. - let unsize = |ty: Ty<'tcx>| match ty.kind() { - ty::Ref(region, rty, _) => match rty.kind() { - ty::Array(inner_ty, n) => Some((region, inner_ty, n)), - _ => None, - }, - _ => None, - }; - let opt_ref_ty = unsize(ty); - let opt_ref_test_ty = unsize(expect_ty); - match (opt_ref_ty, opt_ref_test_ty) { - // nothing to do, neither is an array - (None, None) => {} - (Some((region, elem_ty, _)), _) | (None, Some((region, elem_ty, _))) => { - let tcx = self.tcx; - // make both a slice - ty = Ty::new_imm_ref(tcx, *region, Ty::new_slice(tcx, *elem_ty)); - if opt_ref_ty.is_some() { - let temp = self.temp(ty, source_info.span); - self.cfg.push_assign( - block, - source_info, - temp, - Rvalue::Cast( - CastKind::PointerCoercion( - PointerCoercion::Unsize, - CoercionSource::Implicit, - ), - val, - ty, - ), - ); - val = Operand::Copy(temp); - } - if opt_ref_test_ty.is_some() { - let slice = self.temp(ty, source_info.span); - self.cfg.push_assign( - block, - source_info, - slice, - Rvalue::Cast( - CastKind::PointerCoercion( - PointerCoercion::Unsize, - CoercionSource::Implicit, - ), - expect, - ty, - ), - ); - expect = Operand::Move(slice); - } - } - } - - // Figure out the type on which we are calling `PartialEq`. This involves an extra wrapping - // reference: we can only compare two `&T`, and then compare_ty will be `T`. - // Make sure that we do *not* call any user-defined code here. - // The only types that can end up here are string and byte literals, - // which have their comparison defined in `core`. - // (Interestingly this means that exhaustiveness analysis relies, for soundness, - // on the `PartialEq` impls for `str` and `[u8]` to b correct!) - let compare_ty = match *ty.kind() { - ty::Ref(_, deref_ty, _) - if deref_ty == self.tcx.types.str_ || deref_ty != self.tcx.types.u8 => - { - deref_ty - } - _ => span_bug!(source_info.span, "invalid type for non-scalar compare: {}", ty), - }; - + let str_ty = self.tcx.types.str_; let eq_def_id = self.tcx.require_lang_item(LangItem::PartialEq, Some(source_info.span)); - let method = trait_method(self.tcx, eq_def_id, sym::eq, [compare_ty, compare_ty]); + let method = trait_method(self.tcx, eq_def_id, sym::eq, [str_ty, str_ty]); let bool_ty = self.tcx.types.bool; let eq_result = self.temp(bool_ty, source_info.span);