Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
Clean up code and add comments.
Use InlineConstant to wrap range patterns.
  • Loading branch information
matthewjasper committed Oct 13, 2023
1 parent 98b4c1e commit 2d8ed99
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 38 deletions.
10 changes: 10 additions & 0 deletions compiler/rustc_middle/src/thir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -749,8 +749,18 @@ pub enum PatKind<'tcx> {
value: mir::Const<'tcx>,
},

/// Pattern lowered from an inline constant
InlineConstant {
/// Unevaluated version of the constant, we need this for:
/// 1. Having a reference that can be used by unsafety checking to visit nested
/// unevaluated constants.
/// 2. During THIR building we turn this back to a [Self::Constant] in range patterns.
value: mir::UnevaluatedConst<'tcx>,
/// If the inline constant is used in a range pattern, this subpattern represents the range
/// (if both ends are inline constants, there will be multiple InlineConstant wrappers).
///
/// Otherwise, the actual pattern that the constant lowered to. As with other constants, inline constants
/// are matched structurally where possible.
subpattern: Box<Pat<'tcx>>,
},

Expand Down
1 change: 0 additions & 1 deletion compiler/rustc_mir_build/src/build/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ fn mir_build<'tcx>(tcx: TyCtxt<'tcx>, def: LocalDefId) -> Body<'tcx> {
thir::BodyTy::Const(ty) => construct_const(tcx, def, thir, expr, ty),
};

tcx.ensure().check_match(def);
// this must run before MIR dump, because
// "not all control paths return a value" is reported here.
//
Expand Down
22 changes: 5 additions & 17 deletions compiler/rustc_mir_build/src/check_unsafety.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::errors::*;
use rustc_middle::thir::visit::{self, Visitor};

use rustc_hir as hir;
use rustc_middle::mir::{BorrowKind, Const};
use rustc_middle::mir::BorrowKind;
use rustc_middle::thir::*;
use rustc_middle::ty::print::with_no_trimmed_paths;
use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt};
Expand Down Expand Up @@ -124,7 +124,8 @@ impl<'tcx> UnsafetyVisitor<'_, 'tcx> {
/// Handle closures/generators/inline-consts, which is unsafecked with their parent body.
fn visit_inner_body(&mut self, def: LocalDefId) {
if let Ok((inner_thir, expr)) = self.tcx.thir_body(def) {
let _ = self.tcx.ensure_with_value().mir_built(def);
// Runs all other queries that depend on THIR.
self.tcx.ensure_with_value().mir_built(def);
let inner_thir = &inner_thir.steal();
let hir_context = self.tcx.hir().local_def_id_to_hir_id(def);
let mut inner_visitor = UnsafetyVisitor { thir: inner_thir, hir_context, ..*self };
Expand Down Expand Up @@ -278,20 +279,6 @@ impl<'a, 'tcx> Visitor<'a, 'tcx> for UnsafetyVisitor<'a, 'tcx> {
visit::walk_pat(self, pat);
self.inside_adt = old_inside_adt;
}
PatKind::Range(range) => {
if let Const::Unevaluated(c, _) = range.lo {
if let hir::def::DefKind::InlineConst = self.tcx.def_kind(c.def) {
let def_id = c.def.expect_local();
self.visit_inner_body(def_id);
}
}
if let Const::Unevaluated(c, _) = range.hi {
if let hir::def::DefKind::InlineConst = self.tcx.def_kind(c.def) {
let def_id = c.def.expect_local();
self.visit_inner_body(def_id);
}
}
}
PatKind::InlineConstant { value, .. } => {
let def_id = value.def.expect_local();
self.visit_inner_body(def_id);
Expand Down Expand Up @@ -804,7 +791,8 @@ pub fn thir_check_unsafety(tcx: TyCtxt<'_>, def: LocalDefId) {
}

let Ok((thir, expr)) = tcx.thir_body(def) else { return };
let _ = tcx.ensure_with_value().mir_built(def);
// Runs all other queries that depend on THIR.
tcx.ensure_with_value().mir_built(def);
let thir = &thir.steal();
// If `thir` is empty, a type error occurred, skip this body.
if thir.exprs.is_empty() {
Expand Down
44 changes: 26 additions & 18 deletions compiler/rustc_mir_build/src/thir/pattern/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ use rustc_index::Idx;
use rustc_middle::mir::interpret::{
ErrorHandled, GlobalId, LitToConstError, LitToConstInput, Scalar,
};
use rustc_middle::mir::{self, BorrowKind, Const, Mutability, UserTypeProjection};
use rustc_middle::mir::{
self, BorrowKind, Const, Mutability, UnevaluatedConst, UserTypeProjection,
};
use rustc_middle::thir::{Ascription, BindingMode, FieldPat, LocalVarId, Pat, PatKind, PatRange};
use rustc_middle::ty::layout::IntegerExt;
use rustc_middle::ty::{
Expand Down Expand Up @@ -88,19 +90,21 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
fn lower_pattern_range_endpoint(
&mut self,
expr: Option<&'tcx hir::Expr<'tcx>>,
) -> Result<(Option<mir::Const<'tcx>>, Option<Ascription<'tcx>>), ErrorGuaranteed> {
) -> Result<
(Option<mir::Const<'tcx>>, Option<Ascription<'tcx>>, Option<UnevaluatedConst<'tcx>>),
ErrorGuaranteed,
> {
match expr {
None => Ok((None, None)),
None => Ok((None, None, None)),
Some(expr) => {
let (kind, ascr) = match self.lower_lit(expr) {
PatKind::InlineConstant { subpattern, value } => (
PatKind::Constant { value: Const::Unevaluated(value, subpattern.ty) },
None,
),
let (kind, ascr, inline_const) = match self.lower_lit(expr) {
PatKind::InlineConstant { subpattern, value } => {
(subpattern.kind, None, Some(value))
}
PatKind::AscribeUserType { ascription, subpattern: box Pat { kind, .. } } => {
(kind, Some(ascription))
(kind, Some(ascription), None)
}
kind => (kind, None),
kind => (kind, None, None),
};
let value = if let PatKind::Constant { value } = kind {
value
Expand All @@ -110,7 +114,7 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
);
return Err(self.tcx.sess.delay_span_bug(expr.span, msg));
};
Ok((Some(value), ascr))
Ok((Some(value), ascr, inline_const))
}
}
}
Expand Down Expand Up @@ -181,8 +185,8 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
return Err(self.tcx.sess.delay_span_bug(span, msg));
}

let (lo, lo_ascr) = self.lower_pattern_range_endpoint(lo_expr)?;
let (hi, hi_ascr) = self.lower_pattern_range_endpoint(hi_expr)?;
let (lo, lo_ascr, lo_inline) = self.lower_pattern_range_endpoint(lo_expr)?;
let (hi, hi_ascr, hi_inline) = self.lower_pattern_range_endpoint(hi_expr)?;

let lo = lo.unwrap_or_else(|| {
// Unwrap is ok because the type is known to be numeric.
Expand Down Expand Up @@ -241,6 +245,12 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
};
}
}
for inline_const in [lo_inline, hi_inline] {
if let Some(value) = inline_const {
kind =
PatKind::InlineConstant { value, subpattern: Box::new(Pat { span, ty, kind }) };
}
}
Ok(kind)
}

Expand Down Expand Up @@ -606,11 +616,9 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
// const eval path below.
// FIXME: investigate the performance impact of removing this.
let lit_input = match expr.kind {
hir::ExprKind::Lit(ref lit) => Some(LitToConstInput { lit: &lit.node, ty, neg: false }),
hir::ExprKind::Unary(hir::UnOp::Neg, ref expr) => match expr.kind {
hir::ExprKind::Lit(ref lit) => {
Some(LitToConstInput { lit: &lit.node, ty, neg: true })
}
hir::ExprKind::Lit(lit) => Some(LitToConstInput { lit: &lit.node, ty, neg: false }),
hir::ExprKind::Unary(hir::UnOp::Neg, expr) => match expr.kind {
hir::ExprKind::Lit(lit) => Some(LitToConstInput { lit: &lit.node, ty, neg: true }),
_ => None,
},
_ => None,
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_mir_build/src/thir/print.rs
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ impl<'a, 'tcx> ThirPrinter<'a, 'tcx> {
}
PatKind::Deref { subpattern } => {
print_indented!(self, "Deref { ", depth_lvl + 1);
print_indented!(self, "subpattern: ", depth_lvl + 2);
print_indented!(self, "subpattern:", depth_lvl + 2);
self.print_pat(subpattern, depth_lvl + 2);
print_indented!(self, "}", depth_lvl + 1);
}
Expand All @@ -704,7 +704,7 @@ impl<'a, 'tcx> ThirPrinter<'a, 'tcx> {
PatKind::InlineConstant { value, subpattern } => {
print_indented!(self, "InlineConstant {", depth_lvl + 1);
print_indented!(self, format!("value: {:?}", value), depth_lvl + 2);
print_indented!(self, "subpattern: ", depth_lvl + 2);
print_indented!(self, "subpattern:", depth_lvl + 2);
self.print_pat(subpattern, depth_lvl + 2);
print_indented!(self, "}", depth_lvl + 1);
}
Expand Down

0 comments on commit 2d8ed99

Please sign in to comment.