Skip to content

Commit 11637c9

Browse files
committed
pattern lowering: make sure we never call user-defined PartialEq instances
1 parent 719ecb9 commit 11637c9

File tree

3 files changed

+33
-37
lines changed

3 files changed

+33
-37
lines changed

compiler/rustc_middle/src/thir.rs

+2-5
Original file line numberDiff line numberDiff line change
@@ -783,16 +783,13 @@ pub enum PatKind<'tcx> {
783783
},
784784

785785
/// One of the following:
786-
/// * `&str` (represented as a valtree), which will be handled as a string pattern and thus
786+
/// * `&str`/`&[u8]` (represented as a valtree), which will be handled as a string pattern and thus
787787
/// exhaustiveness checking will detect if you use the same string twice in different
788788
/// patterns.
789789
/// * integer, bool, char or float (represented as a valtree), which will be handled by
790790
/// exhaustiveness to cover exactly its own value, similar to `&str`, but these values are
791791
/// much simpler.
792-
/// * Opaque constants (represented as `mir::ConstValue`), that must not be matched
793-
/// structurally. So anything that does not derive `PartialEq` and `Eq`.
794-
///
795-
/// These are always compared with the matched place using (the semantics of) `PartialEq`.
792+
/// * `String`, if `string_deref_patterns` is enabled.
796793
Constant {
797794
value: mir::Const<'tcx>,
798795
},

compiler/rustc_mir_build/src/build/matches/test.rs

+17-28
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
144144
&& tcx.is_lang_item(def.did(), LangItem::String)
145145
{
146146
if !tcx.features().string_deref_patterns {
147-
bug!(
147+
span_bug!(
148+
test.span,
148149
"matching on `String` went through without enabling string_deref_patterns"
149150
);
150151
}
@@ -432,40 +433,28 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
432433
}
433434
}
434435

435-
match *ty.kind() {
436-
ty::Ref(_, deref_ty, _) => ty = deref_ty,
437-
_ => {
438-
// non_scalar_compare called on non-reference type
439-
let temp = self.temp(ty, source_info.span);
440-
self.cfg.push_assign(block, source_info, temp, Rvalue::Use(expect));
441-
let ref_ty = Ty::new_imm_ref(self.tcx, self.tcx.lifetimes.re_erased, ty);
442-
let ref_temp = self.temp(ref_ty, source_info.span);
443-
444-
self.cfg.push_assign(
445-
block,
446-
source_info,
447-
ref_temp,
448-
Rvalue::Ref(self.tcx.lifetimes.re_erased, BorrowKind::Shared, temp),
449-
);
450-
expect = Operand::Move(ref_temp);
451-
452-
let ref_temp = self.temp(ref_ty, source_info.span);
453-
self.cfg.push_assign(
454-
block,
455-
source_info,
456-
ref_temp,
457-
Rvalue::Ref(self.tcx.lifetimes.re_erased, BorrowKind::Shared, val),
458-
);
459-
val = ref_temp;
436+
// Figure out the type on which we are calling `PartialEq`. This involves an extra wrapping
437+
// reference: we can only compare two `&T`, and then compare_ty will be `T`.
438+
// Make sure that we do *not* call any user-defined code here.
439+
// The only types that can end up here are string and byte literals,
440+
// which have their comparison defined in `core`.
441+
// (Interestingly this means that exhaustiveness analysis relies, for soundness,
442+
// on the `PartialEq` impls for `str` and `[u8]` to b correct!)
443+
let compare_ty = match *ty.kind() {
444+
ty::Ref(_, deref_ty, _)
445+
if deref_ty == self.tcx.types.str_ || deref_ty != self.tcx.types.u8 =>
446+
{
447+
deref_ty
460448
}
461-
}
449+
_ => span_bug!(source_info.span, "invalid type for non-scalar compare: {}", ty),
450+
};
462451

463452
let eq_def_id = self.tcx.require_lang_item(LangItem::PartialEq, Some(source_info.span));
464453
let method = trait_method(
465454
self.tcx,
466455
eq_def_id,
467456
sym::eq,
468-
self.tcx.with_opt_host_effect_param(self.def_id, eq_def_id, [ty, ty]),
457+
self.tcx.with_opt_host_effect_param(self.def_id, eq_def_id, [compare_ty, compare_ty]),
469458
);
470459

471460
let bool_ty = self.tcx.types.bool;

compiler/rustc_pattern_analysis/src/rustc.rs

+14-4
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,12 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
462462
// This is a box pattern.
463463
ty::Adt(adt, ..) if adt.is_box() => Struct,
464464
ty::Ref(..) => Ref,
465-
_ => bug!("pattern has unexpected type: pat: {:?}, ty: {:?}", pat, ty),
465+
_ => span_bug!(
466+
pat.span,
467+
"pattern has unexpected type: pat: {:?}, ty: {:?}",
468+
pat.kind,
469+
ty.inner()
470+
),
466471
};
467472
}
468473
PatKind::DerefPattern { .. } => {
@@ -518,7 +523,12 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
518523
.map(|ipat| self.lower_pat(&ipat.pattern).at_index(ipat.field.index()))
519524
.collect();
520525
}
521-
_ => bug!("pattern has unexpected type: pat: {:?}, ty: {:?}", pat, ty),
526+
_ => span_bug!(
527+
pat.span,
528+
"pattern has unexpected type: pat: {:?}, ty: {}",
529+
pat.kind,
530+
ty.inner()
531+
),
522532
}
523533
}
524534
PatKind::Constant { value } => {
@@ -663,7 +673,7 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
663673
}
664674
}
665675
}
666-
_ => bug!("invalid type for range pattern: {}", ty.inner()),
676+
_ => span_bug!(pat.span, "invalid type for range pattern: {}", ty.inner()),
667677
};
668678
fields = vec![];
669679
arity = 0;
@@ -674,7 +684,7 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
674684
Some(length.eval_target_usize(cx.tcx, cx.param_env) as usize)
675685
}
676686
ty::Slice(_) => None,
677-
_ => span_bug!(pat.span, "bad ty {:?} for slice pattern", ty),
687+
_ => span_bug!(pat.span, "bad ty {} for slice pattern", ty.inner()),
678688
};
679689
let kind = if slice.is_some() {
680690
SliceKind::VarLen(prefix.len(), suffix.len())

0 commit comments

Comments
 (0)