Skip to content

Commit 2d9fe91

Browse files
committed
Fix inline const pattern unsafety checking in THIR
THIR unsafety checking was getting a cycle of function unsafety checking -> building THIR for the function -> evaluating pattern inline constants in the function -> building MIR for the inline constant -> checking unsafety of functions (so that THIR can be stolen) This is fixed by not stealing THIR when generating MIR but instead when unsafety checking. This leaves an issue with pattern inline constants not being unsafety checked because they are evaluated away when generating THIR. To fix that we now represent inline constants in THIR patterns and visit them in THIR unsafety checking.
1 parent eb0f3ed commit 2d9fe91

24 files changed

+240
-52
lines changed

compiler/rustc_interface/src/passes.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -775,12 +775,16 @@ fn analysis(tcx: TyCtxt<'_>, (): ()) -> Result<()> {
775775
rustc_hir_analysis::check_crate(tcx)?;
776776

777777
sess.time("MIR_borrow_checking", || {
778-
tcx.hir().par_body_owners(|def_id| tcx.ensure().mir_borrowck(def_id));
778+
tcx.hir().par_body_owners(|def_id| {
779+
// Run THIR unsafety check because it's responsible for stealing
780+
// and deallocating THIR when enabled.
781+
tcx.ensure().thir_check_unsafety(def_id);
782+
tcx.ensure().mir_borrowck(def_id)
783+
});
779784
});
780785

781786
sess.time("MIR_effect_checking", || {
782787
for def_id in tcx.hir().body_owners() {
783-
tcx.ensure().thir_check_unsafety(def_id);
784788
if !tcx.sess.opts.unstable_opts.thir_unsafeck {
785789
rustc_mir_transform::check_unsafety::check_unsafety(tcx, def_id);
786790
}

compiler/rustc_middle/src/thir.rs

+10-1
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,8 @@ impl<'tcx> Pat<'tcx> {
635635
Wild | Range(..) | Binding { subpattern: None, .. } | Constant { .. } => {}
636636
AscribeUserType { subpattern, .. }
637637
| Binding { subpattern: Some(subpattern), .. }
638-
| Deref { subpattern } => subpattern.walk_(it),
638+
| Deref { subpattern }
639+
| InlineConstant { subpattern, .. } => subpattern.walk_(it),
639640
Leaf { subpatterns } | Variant { subpatterns, .. } => {
640641
subpatterns.iter().for_each(|field| field.pattern.walk_(it))
641642
}
@@ -746,6 +747,11 @@ pub enum PatKind<'tcx> {
746747
value: mir::Const<'tcx>,
747748
},
748749

750+
InlineConstant {
751+
value: mir::UnevaluatedConst<'tcx>,
752+
subpattern: Box<Pat<'tcx>>,
753+
},
754+
749755
Range(Box<PatRange<'tcx>>),
750756

751757
/// Matches against a slice, checking the length and extracting elements.
@@ -901,6 +907,9 @@ impl<'tcx> fmt::Display for Pat<'tcx> {
901907
write!(f, "{subpattern}")
902908
}
903909
PatKind::Constant { value } => write!(f, "{value}"),
910+
PatKind::InlineConstant { value: _, ref subpattern } => {
911+
write!(f, "{} (from inline const)", subpattern)
912+
}
904913
PatKind::Range(box PatRange { lo, hi, end }) => {
905914
write!(f, "{lo}")?;
906915
write!(f, "{end}")?;

compiler/rustc_middle/src/thir/visit.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -233,16 +233,17 @@ pub fn walk_pat<'a, 'tcx: 'a, V: Visitor<'a, 'tcx>>(visitor: &mut V, pat: &Pat<'
233233
}
234234
}
235235
Constant { value: _ } => {}
236+
InlineConstant { value: _, subpattern } => visitor.visit_pat(subpattern),
236237
Range(_) => {}
237238
Slice { prefix, slice, suffix } | Array { prefix, slice, suffix } => {
238239
for subpattern in prefix.iter() {
239-
visitor.visit_pat(&subpattern);
240+
visitor.visit_pat(subpattern);
240241
}
241242
if let Some(pat) = slice {
242-
visitor.visit_pat(&pat);
243+
visitor.visit_pat(pat);
243244
}
244245
for subpattern in suffix.iter() {
245-
visitor.visit_pat(&subpattern);
246+
visitor.visit_pat(subpattern);
246247
}
247248
}
248249
Or { pats } => {

compiler/rustc_mir_build/src/build/matches/mod.rs

+4
Original file line numberDiff line numberDiff line change
@@ -842,6 +842,10 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
842842
self.visit_primary_bindings(subpattern, subpattern_user_ty, f)
843843
}
844844

845+
PatKind::InlineConstant { ref subpattern, .. } => {
846+
self.visit_primary_bindings(subpattern, pattern_user_ty.clone(), f)
847+
}
848+
845849
PatKind::Leaf { ref subpatterns } => {
846850
for subpattern in subpatterns {
847851
let subpattern_user_ty = pattern_user_ty.clone().leaf(subpattern.field);

compiler/rustc_mir_build/src/build/matches/simplify.rs

+20-4
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
204204
Err(match_pair)
205205
}
206206

207+
PatKind::InlineConstant { subpattern: ref pattern, value: _ } => {
208+
candidate.match_pairs.push(MatchPair::new(match_pair.place, pattern, self));
209+
210+
Ok(())
211+
}
212+
207213
PatKind::Range(box PatRange { lo, hi, end }) => {
208214
let (range, bias) = match *lo.ty().kind() {
209215
ty::Char => {
@@ -229,11 +235,21 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
229235
// correct the comparison. This is achieved by XORing with a bias (see
230236
// pattern/_match.rs for another pertinent example of this pattern).
231237
//
232-
// Also, for performance, it's important to only do the second `try_to_bits` if
233-
// necessary.
234-
let lo = lo.try_to_bits(sz).unwrap() ^ bias;
238+
// Also, for performance, it's important to only do the second
239+
// `try_eval_scalar_int` if necessary.
240+
let lo = lo
241+
.try_eval_scalar_int(self.tcx, self.param_env)
242+
.unwrap()
243+
.to_bits(sz)
244+
.unwrap()
245+
^ bias;
235246
if lo <= min {
236-
let hi = hi.try_to_bits(sz).unwrap() ^ bias;
247+
let hi = hi
248+
.try_eval_scalar_int(self.tcx, self.param_env)
249+
.unwrap()
250+
.to_bits(sz)
251+
.unwrap()
252+
^ bias;
237253
if hi > max || hi == max && end == RangeEnd::Included {
238254
// Irrefutable pattern match.
239255
return Ok(());

compiler/rustc_mir_build/src/build/matches/test.rs

+2
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
7373
PatKind::Or { .. } => bug!("or-patterns should have already been handled"),
7474

7575
PatKind::AscribeUserType { .. }
76+
| PatKind::InlineConstant { .. }
7677
| PatKind::Array { .. }
7778
| PatKind::Wild
7879
| PatKind::Binding { .. }
@@ -110,6 +111,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
110111
| PatKind::Or { .. }
111112
| PatKind::Binding { .. }
112113
| PatKind::AscribeUserType { .. }
114+
| PatKind::InlineConstant { .. }
113115
| PatKind::Leaf { .. }
114116
| PatKind::Deref { .. } => {
115117
// don't know how to add these patterns to a switch

compiler/rustc_mir_build/src/build/mod.rs

+15-10
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,7 @@ pub(crate) fn closure_saved_names_of_captured_variables<'tcx>(
5353
}
5454

5555
/// Construct the MIR for a given `DefId`.
56-
fn mir_build(tcx: TyCtxt<'_>, def: LocalDefId) -> Body<'_> {
57-
// Ensure unsafeck and abstract const building is ran before we steal the THIR.
58-
tcx.ensure_with_value()
59-
.thir_check_unsafety(tcx.typeck_root_def_id(def.to_def_id()).expect_local());
56+
fn mir_build<'tcx>(tcx: TyCtxt<'tcx>, def: LocalDefId) -> Body<'tcx> {
6057
tcx.ensure_with_value().thir_abstract_const(def);
6158
if let Err(e) = tcx.check_match(def) {
6259
return construct_error(tcx, def, e);
@@ -65,9 +62,10 @@ fn mir_build(tcx: TyCtxt<'_>, def: LocalDefId) -> Body<'_> {
6562
let body = match tcx.thir_body(def) {
6663
Err(error_reported) => construct_error(tcx, def, error_reported),
6764
Ok((thir, expr)) => {
68-
// We ran all queries that depended on THIR at the beginning
69-
// of `mir_build`, so now we can steal it
70-
let thir = thir.steal();
65+
let build_mir = |thir: &Thir<'tcx>| match thir.body_type {
66+
thir::BodyTy::Fn(fn_sig) => construct_fn(tcx, def, thir, expr, fn_sig),
67+
thir::BodyTy::Const(ty) => construct_const(tcx, def, thir, expr, ty),
68+
};
7169

7270
tcx.ensure().check_match(def);
7371
// this must run before MIR dump, because
@@ -76,9 +74,16 @@ fn mir_build(tcx: TyCtxt<'_>, def: LocalDefId) -> Body<'_> {
7674
// maybe move the check to a MIR pass?
7775
tcx.ensure().check_liveness(def);
7876

79-
match thir.body_type {
80-
thir::BodyTy::Fn(fn_sig) => construct_fn(tcx, def, &thir, expr, fn_sig),
81-
thir::BodyTy::Const(ty) => construct_const(tcx, def, &thir, expr, ty),
77+
if tcx.sess.opts.unstable_opts.thir_unsafeck {
78+
// Don't steal here if THIR unsafeck is being used. Instead
79+
// steal in unsafeck. This is so that pattern inline constants
80+
// can be evaluated as part of building the THIR of the parent
81+
// function without a cycle.
82+
build_mir(&thir.borrow())
83+
} else {
84+
// We ran all queries that depended on THIR at the beginning
85+
// of `mir_build`, so now we can steal it
86+
build_mir(&thir.steal())
8287
}
8388
}
8489
};

compiler/rustc_mir_build/src/check_unsafety.rs

+24-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use crate::errors::*;
33
use rustc_middle::thir::visit::{self, Visitor};
44

55
use rustc_hir as hir;
6-
use rustc_middle::mir::BorrowKind;
6+
use rustc_middle::mir::{BorrowKind, Const};
77
use rustc_middle::thir::*;
88
use rustc_middle::ty::print::with_no_trimmed_paths;
99
use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt};
@@ -124,7 +124,8 @@ impl<'tcx> UnsafetyVisitor<'_, 'tcx> {
124124
/// Handle closures/generators/inline-consts, which is unsafecked with their parent body.
125125
fn visit_inner_body(&mut self, def: LocalDefId) {
126126
if let Ok((inner_thir, expr)) = self.tcx.thir_body(def) {
127-
let inner_thir = &inner_thir.borrow();
127+
let _ = self.tcx.ensure_with_value().mir_built(def);
128+
let inner_thir = &inner_thir.steal();
128129
let hir_context = self.tcx.hir().local_def_id_to_hir_id(def);
129130
let mut inner_visitor = UnsafetyVisitor { thir: inner_thir, hir_context, ..*self };
130131
inner_visitor.visit_expr(&inner_thir[expr]);
@@ -224,6 +225,7 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
224225
PatKind::Wild |
225226
// these just wrap other patterns
226227
PatKind::Or { .. } |
228+
PatKind::InlineConstant { .. } |
227229
PatKind::AscribeUserType { .. } => {}
228230
}
229231
};
@@ -276,6 +278,24 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
276278
visit::walk_pat(self, pat);
277279
self.inside_adt = old_inside_adt;
278280
}
281+
PatKind::Range(range) => {
282+
if let Const::Unevaluated(c, _) = range.lo {
283+
if let hir::def::DefKind::InlineConst = self.tcx.def_kind(c.def) {
284+
let def_id = c.def.expect_local();
285+
self.visit_inner_body(def_id);
286+
}
287+
}
288+
if let Const::Unevaluated(c, _) = range.hi {
289+
if let hir::def::DefKind::InlineConst = self.tcx.def_kind(c.def) {
290+
let def_id = c.def.expect_local();
291+
self.visit_inner_body(def_id);
292+
}
293+
}
294+
}
295+
PatKind::InlineConstant { value, .. } => {
296+
let def_id = value.def.expect_local();
297+
self.visit_inner_body(def_id);
298+
}
279299
_ => {
280300
visit::walk_pat(self, pat);
281301
}
@@ -784,7 +804,8 @@ pub fn thir_check_unsafety(tcx: TyCtxt<'_>, def: LocalDefId) {
784804
}
785805

786806
let Ok((thir, expr)) = tcx.thir_body(def) else { return };
787-
let thir = &thir.borrow();
807+
let _ = tcx.ensure_with_value().mir_built(def);
808+
let thir = &thir.steal();
788809
// If `thir` is empty, a type error occurred, skip this body.
789810
if thir.exprs.is_empty() {
790811
return;

compiler/rustc_mir_build/src/thir/pattern/deconstruct_pat.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -1243,7 +1243,8 @@ impl<'p, 'tcx> DeconstructedPat<'p, 'tcx> {
12431243
let ctor;
12441244
let fields;
12451245
match &pat.kind {
1246-
PatKind::AscribeUserType { subpattern, .. } => return mkpat(subpattern),
1246+
PatKind::AscribeUserType { subpattern, .. }
1247+
| PatKind::InlineConstant { subpattern, .. } => return mkpat(subpattern),
12471248
PatKind::Binding { subpattern: Some(subpat), .. } => return mkpat(subpat),
12481249
PatKind::Binding { subpattern: None, .. } | PatKind::Wild => {
12491250
ctor = Wildcard;

compiler/rustc_mir_build/src/thir/pattern/mod.rs

+9-3
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
9090
expr: &'tcx hir::Expr<'tcx>,
9191
) -> (PatKind<'tcx>, Option<Ascription<'tcx>>) {
9292
match self.lower_lit(expr) {
93+
PatKind::InlineConstant { subpattern, value } => {
94+
(PatKind::Constant { value: Const::Unevaluated(value, subpattern.ty) }, None)
95+
}
9396
PatKind::AscribeUserType { ascription, subpattern: box Pat { kind, .. } } => {
9497
(kind, Some(ascription))
9598
}
@@ -631,13 +634,13 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
631634
if let Ok(Some(valtree)) =
632635
self.tcx.const_eval_resolve_for_typeck(self.param_env, ct, Some(span))
633636
{
634-
self.const_to_pat(
637+
let subpattern = self.const_to_pat(
635638
Const::Ty(ty::Const::new_value(self.tcx, valtree, ty)),
636639
id,
637640
span,
638641
None,
639-
)
640-
.kind
642+
);
643+
PatKind::InlineConstant { subpattern, value: uneval }
641644
} else {
642645
// If that fails, convert it to an opaque constant pattern.
643646
match tcx.const_eval_resolve(self.param_env, uneval, Some(span)) {
@@ -819,6 +822,9 @@ impl<'tcx> PatternFoldable<'tcx> for PatKind<'tcx> {
819822
PatKind::Deref { subpattern: subpattern.fold_with(folder) }
820823
}
821824
PatKind::Constant { value } => PatKind::Constant { value },
825+
PatKind::InlineConstant { value, subpattern: ref pattern } => {
826+
PatKind::InlineConstant { value, subpattern: pattern.fold_with(folder) }
827+
}
822828
PatKind::Range(ref range) => PatKind::Range(range.clone()),
823829
PatKind::Slice { ref prefix, ref slice, ref suffix } => PatKind::Slice {
824830
prefix: prefix.fold_with(folder),

compiler/rustc_mir_build/src/thir/print.rs

+7
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,13 @@ impl<'a, 'tcx> ThirPrinter<'a, 'tcx> {
701701
print_indented!(self, format!("value: {:?}", value), depth_lvl + 2);
702702
print_indented!(self, "}", depth_lvl + 1);
703703
}
704+
PatKind::InlineConstant { value, subpattern } => {
705+
print_indented!(self, "InlineConstant {", depth_lvl + 1);
706+
print_indented!(self, format!("value: {:?}", value), depth_lvl + 2);
707+
print_indented!(self, "subpattern: ", depth_lvl + 2);
708+
self.print_pat(subpattern, depth_lvl + 2);
709+
print_indented!(self, "}", depth_lvl + 1);
710+
}
704711
PatKind::Range(pat_range) => {
705712
print_indented!(self, format!("Range ( {:?} )", pat_range), depth_lvl + 1);
706713
}

tests/ui/async-await/async-unsafe-fn-call-in-safe.mir.stderr

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ LL | S::f();
2323
= note: consult the function's documentation for information on how to avoid undefined behavior
2424

2525
error[E0133]: call to unsafe function is unsafe and requires unsafe function or block
26-
--> $DIR/async-unsafe-fn-call-in-safe.rs:24:5
26+
--> $DIR/async-unsafe-fn-call-in-safe.rs:26:5
2727
|
2828
LL | f();
2929
| ^^^ call to unsafe function

tests/ui/async-await/async-unsafe-fn-call-in-safe.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ async fn g() {
2020
}
2121

2222
fn main() {
23-
S::f(); //[mir]~ ERROR call to unsafe function is unsafe
24-
f(); //[mir]~ ERROR call to unsafe function is unsafe
23+
S::f();
24+
//[mir]~^ ERROR call to unsafe function is unsafe
25+
//[thir]~^^ ERROR call to unsafe function `S::f` is unsafe
26+
f();
27+
//[mir]~^ ERROR call to unsafe function is unsafe
28+
//[thir]~^^ ERROR call to unsafe function `f` is unsafe
2529
}

tests/ui/async-await/async-unsafe-fn-call-in-safe.thir.stderr

+17-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,22 @@ LL | f();
1414
|
1515
= note: consult the function's documentation for information on how to avoid undefined behavior
1616

17-
error: aborting due to 2 previous errors
17+
error[E0133]: call to unsafe function `S::f` is unsafe and requires unsafe function or block
18+
--> $DIR/async-unsafe-fn-call-in-safe.rs:23:5
19+
|
20+
LL | S::f();
21+
| ^^^^^^ call to unsafe function
22+
|
23+
= note: consult the function's documentation for information on how to avoid undefined behavior
24+
25+
error[E0133]: call to unsafe function `f` is unsafe and requires unsafe function or block
26+
--> $DIR/async-unsafe-fn-call-in-safe.rs:26:5
27+
|
28+
LL | f();
29+
| ^^^ call to unsafe function
30+
|
31+
= note: consult the function's documentation for information on how to avoid undefined behavior
32+
33+
error: aborting due to 4 previous errors
1834

1935
For more information about this error, try `rustc --explain E0133`.

tests/ui/consts/const-extern-fn/const-extern-fn-requires-unsafe.rs

+1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@ fn main() {
1111
//[thir]~^^ call to unsafe function `foo` is unsafe and requires unsafe function or block
1212
foo();
1313
//[mir]~^ ERROR call to unsafe function is unsafe and requires unsafe function or block
14+
//[thir]~^^ ERROR call to unsafe function `foo` is unsafe and requires unsafe function or block
1415
}
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
error[E0133]: call to unsafe function `foo` is unsafe and requires unsafe function or block
2+
--> $DIR/const-extern-fn-requires-unsafe.rs:12:5
3+
|
4+
LL | foo();
5+
| ^^^^^ call to unsafe function
6+
|
7+
= note: consult the function's documentation for information on how to avoid undefined behavior
8+
19
error[E0133]: call to unsafe function `foo` is unsafe and requires unsafe function or block
210
--> $DIR/const-extern-fn-requires-unsafe.rs:9:17
311
|
@@ -6,6 +14,6 @@ LL | let a: [u8; foo()];
614
|
715
= note: consult the function's documentation for information on how to avoid undefined behavior
816

9-
error: aborting due to previous error
17+
error: aborting due to 2 previous errors
1018

1119
For more information about this error, try `rustc --explain E0133`.

0 commit comments

Comments
 (0)