Skip to content

Commit ec56537

Browse files
committed
Auto merge of #105356 - JakobDegen:more-custom-mir, r=oli-obk
Custom MIR: Many more improvements Commits are each atomic changes, best reviewed one at a time, with the exception that the last commit includes all the documentation. ### First commit Unsafetyck was not correctly disabled before for `dialect = "built"` custom MIR. This is fixed and a regression test is added. ### Second commit Implements `Discriminant`, `SetDiscriminant`, and `SwitchInt`. ### Third commit Implements indexing, field, and variant projections. ### Fourth commit Documents the previous commits and everything else. There is some amount of weirdness here due to having to beat Rust syntax into cooperating with MIR concepts, but it hopefully should not be too much. All of it is documented. r? `@oli-obk`
2 parents 4954a7e + b580f29 commit ec56537

20 files changed

+731
-32
lines changed

compiler/rustc_middle/src/mir/mod.rs

+5
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,11 @@ impl<'tcx> Body<'tcx> {
533533
};
534534
injection_phase > self.phase
535535
}
536+
537+
#[inline]
538+
pub fn is_custom_mir(&self) -> bool {
539+
self.injection_phase.is_some()
540+
}
536541
}
537542

538543
#[derive(Copy, Clone, PartialEq, Eq, Debug, TyEncodable, TyDecodable, HashStable)]

compiler/rustc_mir_build/src/build/custom/mod.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use rustc_index::vec::IndexVec;
2525
use rustc_middle::{
2626
mir::*,
2727
thir::*,
28-
ty::{Ty, TyCtxt},
28+
ty::{ParamEnv, Ty, TyCtxt},
2929
};
3030
use rustc_span::Span;
3131

@@ -78,6 +78,7 @@ pub(super) fn build_custom_mir<'tcx>(
7878

7979
let mut pctxt = ParseCtxt {
8080
tcx,
81+
param_env: tcx.param_env(did),
8182
thir,
8283
source_scope: OUTERMOST_SOURCE_SCOPE,
8384
body: &mut body,
@@ -132,6 +133,7 @@ fn parse_attribute(attr: &Attribute) -> MirPhase {
132133

133134
struct ParseCtxt<'tcx, 'body> {
134135
tcx: TyCtxt<'tcx>,
136+
param_env: ParamEnv<'tcx>,
135137
thir: &'body Thir<'tcx>,
136138
source_scope: SourceScope,
137139

compiler/rustc_mir_build/src/build/custom/parse/instruction.rs

+106-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
use rustc_middle::mir::interpret::{ConstValue, Scalar};
2+
use rustc_middle::mir::tcx::PlaceTy;
23
use rustc_middle::{mir::*, thir::*, ty};
4+
use rustc_span::Span;
5+
use rustc_target::abi::VariantIdx;
6+
7+
use crate::build::custom::ParseError;
8+
use crate::build::expr::as_constant::as_constant_inner;
39

410
use super::{parse_by_kind, PResult, ParseCtxt};
511

@@ -12,6 +18,14 @@ impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
1218
@call("mir_retag_raw", args) => {
1319
Ok(StatementKind::Retag(RetagKind::Raw, Box::new(self.parse_place(args[0])?)))
1420
},
21+
@call("mir_set_discriminant", args) => {
22+
let place = self.parse_place(args[0])?;
23+
let var = self.parse_integer_literal(args[1])? as u32;
24+
Ok(StatementKind::SetDiscriminant {
25+
place: Box::new(place),
26+
variant_index: VariantIdx::from_u32(var),
27+
})
28+
},
1529
ExprKind::Assign { lhs, rhs } => {
1630
let lhs = self.parse_place(*lhs)?;
1731
let rhs = self.parse_rvalue(*rhs)?;
@@ -21,18 +35,60 @@ impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
2135
}
2236

2337
pub fn parse_terminator(&self, expr_id: ExprId) -> PResult<TerminatorKind<'tcx>> {
24-
parse_by_kind!(self, expr_id, _, "terminator",
38+
parse_by_kind!(self, expr_id, expr, "terminator",
2539
@call("mir_return", _args) => {
2640
Ok(TerminatorKind::Return)
2741
},
2842
@call("mir_goto", args) => {
2943
Ok(TerminatorKind::Goto { target: self.parse_block(args[0])? } )
3044
},
45+
ExprKind::Match { scrutinee, arms } => {
46+
let discr = self.parse_operand(*scrutinee)?;
47+
self.parse_match(arms, expr.span).map(|t| TerminatorKind::SwitchInt { discr, targets: t })
48+
},
3149
)
3250
}
3351

52+
fn parse_match(&self, arms: &[ArmId], span: Span) -> PResult<SwitchTargets> {
53+
let Some((otherwise, rest)) = arms.split_last() else {
54+
return Err(ParseError {
55+
span,
56+
item_description: format!("no arms"),
57+
expected: "at least one arm".to_string(),
58+
})
59+
};
60+
61+
let otherwise = &self.thir[*otherwise];
62+
let PatKind::Wild = otherwise.pattern.kind else {
63+
return Err(ParseError {
64+
span: otherwise.span,
65+
item_description: format!("{:?}", otherwise.pattern.kind),
66+
expected: "wildcard pattern".to_string(),
67+
})
68+
};
69+
let otherwise = self.parse_block(otherwise.body)?;
70+
71+
let mut values = Vec::new();
72+
let mut targets = Vec::new();
73+
for arm in rest {
74+
let arm = &self.thir[*arm];
75+
let PatKind::Constant { value } = arm.pattern.kind else {
76+
return Err(ParseError {
77+
span: arm.pattern.span,
78+
item_description: format!("{:?}", arm.pattern.kind),
79+
expected: "constant pattern".to_string(),
80+
})
81+
};
82+
values.push(value.eval_bits(self.tcx, self.param_env, arm.pattern.ty));
83+
targets.push(self.parse_block(arm.body)?);
84+
}
85+
86+
Ok(SwitchTargets::new(values.into_iter().zip(targets), otherwise))
87+
}
88+
3489
fn parse_rvalue(&self, expr_id: ExprId) -> PResult<Rvalue<'tcx>> {
3590
parse_by_kind!(self, expr_id, _, "rvalue",
91+
@call("mir_discriminant", args) => self.parse_place(args[0]).map(Rvalue::Discriminant),
3692
ExprKind::Borrow { borrow_kind, arg } => Ok(
3793
Rvalue::Ref(self.tcx.lifetimes.re_erased, *borrow_kind, self.parse_place(*arg)?)
3894
),
@@ -55,20 +111,50 @@ impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
55111
| ExprKind::ConstParam { .. }
56112
| ExprKind::ConstBlock { .. } => {
57113
Ok(Operand::Constant(Box::new(
58-
crate::build::expr::as_constant::as_constant_inner(expr, |_| None, self.tcx)
114+
as_constant_inner(expr, |_| None, self.tcx)
59115
)))
60116
},
61117
_ => self.parse_place(expr_id).map(Operand::Copy),
62118
)
63119
}
64120

65121
fn parse_place(&self, expr_id: ExprId) -> PResult<Place<'tcx>> {
66-
parse_by_kind!(self, expr_id, _, "place",
67-
ExprKind::Deref { arg } => Ok(
68-
self.parse_place(*arg)?.project_deeper(&[PlaceElem::Deref], self.tcx)
69-
),
70-
_ => self.parse_local(expr_id).map(Place::from),
71-
)
122+
self.parse_place_inner(expr_id).map(|(x, _)| x)
123+
}
124+
125+
fn parse_place_inner(&self, expr_id: ExprId) -> PResult<(Place<'tcx>, PlaceTy<'tcx>)> {
126+
let (parent, proj) = parse_by_kind!(self, expr_id, expr, "place",
127+
@call("mir_field", args) => {
128+
let (parent, ty) = self.parse_place_inner(args[0])?;
129+
let field = Field::from_u32(self.parse_integer_literal(args[1])? as u32);
130+
let field_ty = ty.field_ty(self.tcx, field);
131+
let proj = PlaceElem::Field(field, field_ty);
132+
let place = parent.project_deeper(&[proj], self.tcx);
133+
return Ok((place, PlaceTy::from_ty(field_ty)));
134+
},
135+
@call("mir_variant", args) => {
136+
(args[0], PlaceElem::Downcast(
137+
None,
138+
VariantIdx::from_u32(self.parse_integer_literal(args[1])? as u32)
139+
))
140+
},
141+
ExprKind::Deref { arg } => {
142+
parse_by_kind!(self, *arg, _, "does not matter",
143+
@call("mir_make_place", args) => return self.parse_place_inner(args[0]),
144+
_ => (*arg, PlaceElem::Deref),
145+
)
146+
},
147+
ExprKind::Index { lhs, index } => (*lhs, PlaceElem::Index(self.parse_local(*index)?)),
148+
ExprKind::Field { lhs, name: field, .. } => (*lhs, PlaceElem::Field(*field, expr.ty)),
149+
_ => {
150+
let place = self.parse_local(expr_id).map(Place::from)?;
151+
return Ok((place, PlaceTy::from_ty(expr.ty)))
152+
},
153+
);
154+
let (parent, ty) = self.parse_place_inner(parent)?;
155+
let place = parent.project_deeper(&[proj], self.tcx);
156+
let ty = ty.projection_ty(self.tcx, proj);
157+
Ok((place, ty))
72158
}
73159

74160
fn parse_local(&self, expr_id: ExprId) -> PResult<Local> {
@@ -102,4 +188,16 @@ impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
102188
},
103189
)
104190
}
191+
192+
fn parse_integer_literal(&self, expr_id: ExprId) -> PResult<u128> {
193+
parse_by_kind!(self, expr_id, expr, "constant",
194+
ExprKind::Literal { .. }
195+
| ExprKind::NamedConst { .. }
196+
| ExprKind::NonHirLiteral { .. }
197+
| ExprKind::ConstBlock { .. } => Ok({
198+
let value = as_constant_inner(expr, |_| None, self.tcx);
199+
value.literal.eval_bits(self.tcx, self.param_env, value.ty())
200+
}),
201+
)
202+
}
105203
}

compiler/rustc_mir_transform/src/check_unsafety.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,7 @@ fn unsafety_check_result<'tcx>(
500500
// `mir_built` force this.
501501
let body = &tcx.mir_built(def).borrow();
502502

503-
if body.should_skip() {
503+
if body.is_custom_mir() {
504504
return tcx.arena.alloc(UnsafetyCheckResult {
505505
violations: Vec::new(),
506506
used_unsafe_blocks: FxHashSet::default(),

0 commit comments

Comments
 (0)