Skip to content

Commit 9c0bc30

Browse files
committed
Auto merge of #104975 - JakobDegen:custom_mir_let, r=oli-obk
`#![custom_mir]`: Various improvements This PR makes a bunch of improvements to `#![custom_mir]`. Ideally this would be 4 PRs, one for each commit, but those would take forever to get merged and be a pain to juggle. Should still be reviewed one commit at a time though. ### Commit 1: Support arbitrary `let` Before this change, all locals used in the body need to be declared at the top of the `mir!` invocation, which is rather annoying. We attempt to change that. Unfortunately, we still have the requirement that the output of the `mir!` macro must resolve, typecheck, etc. Because of that, we can't just accept this in the THIR -> MIR parser because something like ```rust { let x = 0; Goto(other) } other = { RET = x; Return() } ``` will fail to resolve. Instead, the implementation does macro shenanigans to find the let declarations and extract them as part of the `mir!` macro. That *works*, but it is fairly complicated and degrades debuginfo by quite a bit. Specifically, the spans for any statements and declarations that are affected by this are completely wrong. My guess is that this is a net improvement though. One way to recover some of the debuginfo would be to not support type annotations in the `let` statements, which would allow us to parse like `let $stmt:stmt`. That seems quite surprising though. ### Commit 2: Parse consts Reuses most of the const parsing from regular Mir building for building custom mir ### Commit 3: Parse statics Statics are slightly weird because the Mir primitive associated with them is a reference/pointer to them, so this is factored out separately. ### Commit 4: Fix some spans A bunch of the spans were non-ideal, so we adjust them to be much more helpful. r? `@oli-obk`
2 parents d6c4de0 + 5a34dbf commit 9c0bc30

15 files changed

+454
-104
lines changed

compiler/rustc_mir_build/src/build/custom/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ pub(super) fn build_custom_mir<'tcx>(
7474
let mut pctxt = ParseCtxt {
7575
tcx,
7676
thir,
77-
source_info: SourceInfo { span, scope: OUTERMOST_SOURCE_SCOPE },
77+
source_scope: OUTERMOST_SOURCE_SCOPE,
7878
body: &mut body,
7979
local_map: FxHashMap::default(),
8080
block_map: FxHashMap::default(),
@@ -128,7 +128,7 @@ fn parse_attribute(attr: &Attribute) -> MirPhase {
128128
struct ParseCtxt<'tcx, 'body> {
129129
tcx: TyCtxt<'tcx>,
130130
thir: &'body Thir<'tcx>,
131-
source_info: SourceInfo,
131+
source_scope: SourceScope,
132132

133133
body: &'body mut Body<'tcx>,
134134
local_map: FxHashMap<LocalVarId, Local>,

compiler/rustc_mir_build/src/build/custom/parse.rs

+18-7
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ macro_rules! parse_by_kind {
2323
(
2424
$self:ident,
2525
$expr_id:expr,
26+
$expr_name:pat,
2627
$expected:literal,
2728
$(
2829
@call($name:literal, $args:ident) => $call_expr:expr,
@@ -33,6 +34,8 @@ macro_rules! parse_by_kind {
3334
) => {{
3435
let expr_id = $self.preparse($expr_id);
3536
let expr = &$self.thir[expr_id];
37+
debug!("Trying to parse {:?} as {}", expr.kind, $expected);
38+
let $expr_name = expr;
3639
match &expr.kind {
3740
$(
3841
ExprKind::Call { ty, fun: _, args: $args, .. } if {
@@ -137,26 +140,26 @@ impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
137140
/// This allows us to easily parse the basic blocks declarations, local declarations, and
138141
/// basic block definitions in order.
139142
pub fn parse_body(&mut self, expr_id: ExprId) -> PResult<()> {
140-
let body = parse_by_kind!(self, expr_id, "whole body",
143+
let body = parse_by_kind!(self, expr_id, _, "whole body",
141144
ExprKind::Block { block } => self.thir[*block].expr.unwrap(),
142145
);
143-
let (block_decls, rest) = parse_by_kind!(self, body, "body with block decls",
146+
let (block_decls, rest) = parse_by_kind!(self, body, _, "body with block decls",
144147
ExprKind::Block { block } => {
145148
let block = &self.thir[*block];
146149
(&block.stmts, block.expr.unwrap())
147150
},
148151
);
149152
self.parse_block_decls(block_decls.iter().copied())?;
150153

151-
let (local_decls, rest) = parse_by_kind!(self, rest, "body with local decls",
154+
let (local_decls, rest) = parse_by_kind!(self, rest, _, "body with local decls",
152155
ExprKind::Block { block } => {
153156
let block = &self.thir[*block];
154157
(&block.stmts, block.expr.unwrap())
155158
},
156159
);
157160
self.parse_local_decls(local_decls.iter().copied())?;
158161

159-
let block_defs = parse_by_kind!(self, rest, "body with block defs",
162+
let block_defs = parse_by_kind!(self, rest, _, "body with block defs",
160163
ExprKind::Block { block } => &self.thir[*block].stmts,
161164
);
162165
for (i, block_def) in block_defs.iter().enumerate() {
@@ -223,22 +226,30 @@ impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
223226
}
224227

225228
fn parse_block_def(&self, expr_id: ExprId) -> PResult<BasicBlockData<'tcx>> {
226-
let block = parse_by_kind!(self, expr_id, "basic block",
229+
let block = parse_by_kind!(self, expr_id, _, "basic block",
227230
ExprKind::Block { block } => &self.thir[*block],
228231
);
229232

230233
let mut data = BasicBlockData::new(None);
231234
for stmt_id in &*block.stmts {
232235
let stmt = self.statement_as_expr(*stmt_id)?;
236+
let span = self.thir[stmt].span;
233237
let statement = self.parse_statement(stmt)?;
234-
data.statements.push(Statement { source_info: self.source_info, kind: statement });
238+
data.statements.push(Statement {
239+
source_info: SourceInfo { span, scope: self.source_scope },
240+
kind: statement,
241+
});
235242
}
236243

237244
let Some(trailing) = block.expr else {
238245
return Err(self.expr_error(expr_id, "terminator"))
239246
};
247+
let span = self.thir[trailing].span;
240248
let terminator = self.parse_terminator(trailing)?;
241-
data.terminator = Some(Terminator { source_info: self.source_info, kind: terminator });
249+
data.terminator = Some(Terminator {
250+
source_info: SourceInfo { span, scope: self.source_scope },
251+
kind: terminator,
252+
});
242253

243254
Ok(data)
244255
}

compiler/rustc_mir_build/src/build/custom/parse/instruction.rs

+40-7
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
use rustc_middle::mir::interpret::{ConstValue, Scalar};
12
use rustc_middle::{mir::*, thir::*, ty};
23

34
use super::{parse_by_kind, PResult, ParseCtxt};
45

56
impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
67
pub fn parse_statement(&self, expr_id: ExprId) -> PResult<StatementKind<'tcx>> {
7-
parse_by_kind!(self, expr_id, "statement",
8+
parse_by_kind!(self, expr_id, _, "statement",
89
@call("mir_retag", args) => {
910
Ok(StatementKind::Retag(RetagKind::Default, Box::new(self.parse_place(args[0])?)))
1011
},
@@ -20,7 +21,7 @@ impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
2021
}
2122

2223
pub fn parse_terminator(&self, expr_id: ExprId) -> PResult<TerminatorKind<'tcx>> {
23-
parse_by_kind!(self, expr_id, "terminator",
24+
parse_by_kind!(self, expr_id, _, "terminator",
2425
@call("mir_return", _args) => {
2526
Ok(TerminatorKind::Return)
2627
},
@@ -31,7 +32,7 @@ impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
3132
}
3233

3334
fn parse_rvalue(&self, expr_id: ExprId) -> PResult<Rvalue<'tcx>> {
34-
parse_by_kind!(self, expr_id, "rvalue",
35+
parse_by_kind!(self, expr_id, _, "rvalue",
3536
ExprKind::Borrow { borrow_kind, arg } => Ok(
3637
Rvalue::Ref(self.tcx.lifetimes.re_erased, *borrow_kind, self.parse_place(*arg)?)
3738
),
@@ -43,14 +44,26 @@ impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
4344
}
4445

4546
fn parse_operand(&self, expr_id: ExprId) -> PResult<Operand<'tcx>> {
46-
parse_by_kind!(self, expr_id, "operand",
47+
parse_by_kind!(self, expr_id, expr, "operand",
4748
@call("mir_move", args) => self.parse_place(args[0]).map(Operand::Move),
49+
@call("mir_static", args) => self.parse_static(args[0]),
50+
@call("mir_static_mut", args) => self.parse_static(args[0]),
51+
ExprKind::Literal { .. }
52+
| ExprKind::NamedConst { .. }
53+
| ExprKind::NonHirLiteral { .. }
54+
| ExprKind::ZstLiteral { .. }
55+
| ExprKind::ConstParam { .. }
56+
| ExprKind::ConstBlock { .. } => {
57+
Ok(Operand::Constant(Box::new(
58+
crate::build::expr::as_constant::as_constant_inner(expr, |_| None, self.tcx)
59+
)))
60+
},
4861
_ => self.parse_place(expr_id).map(Operand::Copy),
4962
)
5063
}
5164

5265
fn parse_place(&self, expr_id: ExprId) -> PResult<Place<'tcx>> {
53-
parse_by_kind!(self, expr_id, "place",
66+
parse_by_kind!(self, expr_id, _, "place",
5467
ExprKind::Deref { arg } => Ok(
5568
self.parse_place(*arg)?.project_deeper(&[PlaceElem::Deref], self.tcx)
5669
),
@@ -59,14 +72,34 @@ impl<'tcx, 'body> ParseCtxt<'tcx, 'body> {
5972
}
6073

6174
fn parse_local(&self, expr_id: ExprId) -> PResult<Local> {
62-
parse_by_kind!(self, expr_id, "local",
75+
parse_by_kind!(self, expr_id, _, "local",
6376
ExprKind::VarRef { id } => Ok(self.local_map[id]),
6477
)
6578
}
6679

6780
fn parse_block(&self, expr_id: ExprId) -> PResult<BasicBlock> {
68-
parse_by_kind!(self, expr_id, "basic block",
81+
parse_by_kind!(self, expr_id, _, "basic block",
6982
ExprKind::VarRef { id } => Ok(self.block_map[id]),
7083
)
7184
}
85+
86+
fn parse_static(&self, expr_id: ExprId) -> PResult<Operand<'tcx>> {
87+
let expr_id = parse_by_kind!(self, expr_id, _, "static",
88+
ExprKind::Deref { arg } => *arg,
89+
);
90+
91+
parse_by_kind!(self, expr_id, expr, "static",
92+
ExprKind::StaticRef { alloc_id, ty, .. } => {
93+
let const_val =
94+
ConstValue::Scalar(Scalar::from_pointer((*alloc_id).into(), &self.tcx));
95+
let literal = ConstantKind::Val(const_val, *ty);
96+
97+
Ok(Operand::Constant(Box::new(Constant {
98+
span: expr.span,
99+
user_ty: None,
100+
literal
101+
})))
102+
},
103+
)
104+
}
72105
}

compiler/rustc_mir_build/src/build/expr/as_constant.rs

+71-66
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ use rustc_middle::mir::interpret::{
88
};
99
use rustc_middle::mir::*;
1010
use rustc_middle::thir::*;
11-
use rustc_middle::ty::{self, CanonicalUserTypeAnnotation, TyCtxt};
11+
use rustc_middle::ty::{
12+
self, CanonicalUserType, CanonicalUserTypeAnnotation, TyCtxt, UserTypeAnnotationIndex,
13+
};
1214
use rustc_span::DUMMY_SP;
1315
use rustc_target::abi::Size;
1416

@@ -19,84 +21,87 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
1921
let this = self;
2022
let tcx = this.tcx;
2123
let Expr { ty, temp_lifetime: _, span, ref kind } = *expr;
22-
match *kind {
24+
match kind {
2325
ExprKind::Scope { region_scope: _, lint_level: _, value } => {
24-
this.as_constant(&this.thir[value])
25-
}
26-
ExprKind::Literal { lit, neg } => {
27-
let literal =
28-
match lit_to_mir_constant(tcx, LitToConstInput { lit: &lit.node, ty, neg }) {
29-
Ok(c) => c,
30-
Err(LitToConstError::Reported(guar)) => {
31-
ConstantKind::Ty(tcx.const_error_with_guaranteed(ty, guar))
32-
}
33-
Err(LitToConstError::TypeError) => {
34-
bug!("encountered type error in `lit_to_mir_constant")
35-
}
36-
};
37-
38-
Constant { span, user_ty: None, literal }
26+
this.as_constant(&this.thir[*value])
3927
}
40-
ExprKind::NonHirLiteral { lit, ref user_ty } => {
41-
let user_ty = user_ty.as_ref().map(|user_ty| {
42-
this.canonical_user_type_annotations.push(CanonicalUserTypeAnnotation {
28+
_ => as_constant_inner(
29+
expr,
30+
|user_ty| {
31+
Some(this.canonical_user_type_annotations.push(CanonicalUserTypeAnnotation {
4332
span,
4433
user_ty: user_ty.clone(),
4534
inferred_ty: ty,
46-
})
47-
});
48-
let literal = ConstantKind::Val(ConstValue::Scalar(Scalar::Int(lit)), ty);
35+
}))
36+
},
37+
tcx,
38+
),
39+
}
40+
}
41+
}
4942

50-
Constant { span, user_ty: user_ty, literal }
51-
}
52-
ExprKind::ZstLiteral { ref user_ty } => {
53-
let user_ty = user_ty.as_ref().map(|user_ty| {
54-
this.canonical_user_type_annotations.push(CanonicalUserTypeAnnotation {
55-
span,
56-
user_ty: user_ty.clone(),
57-
inferred_ty: ty,
58-
})
59-
});
60-
let literal = ConstantKind::Val(ConstValue::ZeroSized, ty);
43+
pub fn as_constant_inner<'tcx>(
44+
expr: &Expr<'tcx>,
45+
push_cuta: impl FnMut(&Box<CanonicalUserType<'tcx>>) -> Option<UserTypeAnnotationIndex>,
46+
tcx: TyCtxt<'tcx>,
47+
) -> Constant<'tcx> {
48+
let Expr { ty, temp_lifetime: _, span, ref kind } = *expr;
49+
match *kind {
50+
ExprKind::Literal { lit, neg } => {
51+
let literal =
52+
match lit_to_mir_constant(tcx, LitToConstInput { lit: &lit.node, ty, neg }) {
53+
Ok(c) => c,
54+
Err(LitToConstError::Reported(guar)) => {
55+
ConstantKind::Ty(tcx.const_error_with_guaranteed(ty, guar))
56+
}
57+
Err(LitToConstError::TypeError) => {
58+
bug!("encountered type error in `lit_to_mir_constant")
59+
}
60+
};
6161

62-
Constant { span, user_ty: user_ty, literal }
63-
}
64-
ExprKind::NamedConst { def_id, substs, ref user_ty } => {
65-
let user_ty = user_ty.as_ref().map(|user_ty| {
66-
this.canonical_user_type_annotations.push(CanonicalUserTypeAnnotation {
67-
span,
68-
user_ty: user_ty.clone(),
69-
inferred_ty: ty,
70-
})
71-
});
62+
Constant { span, user_ty: None, literal }
63+
}
64+
ExprKind::NonHirLiteral { lit, ref user_ty } => {
65+
let user_ty = user_ty.as_ref().map(push_cuta).flatten();
7266

73-
let uneval =
74-
mir::UnevaluatedConst::new(ty::WithOptConstParam::unknown(def_id), substs);
75-
let literal = ConstantKind::Unevaluated(uneval, ty);
67+
let literal = ConstantKind::Val(ConstValue::Scalar(Scalar::Int(lit)), ty);
7668

77-
Constant { user_ty, span, literal }
78-
}
79-
ExprKind::ConstParam { param, def_id: _ } => {
80-
let const_param = tcx.mk_const(param, expr.ty);
81-
let literal = ConstantKind::Ty(const_param);
69+
Constant { span, user_ty: user_ty, literal }
70+
}
71+
ExprKind::ZstLiteral { ref user_ty } => {
72+
let user_ty = user_ty.as_ref().map(push_cuta).flatten();
8273

83-
Constant { user_ty: None, span, literal }
84-
}
85-
ExprKind::ConstBlock { did: def_id, substs } => {
86-
let uneval =
87-
mir::UnevaluatedConst::new(ty::WithOptConstParam::unknown(def_id), substs);
88-
let literal = ConstantKind::Unevaluated(uneval, ty);
74+
let literal = ConstantKind::Val(ConstValue::ZeroSized, ty);
8975

90-
Constant { user_ty: None, span, literal }
91-
}
92-
ExprKind::StaticRef { alloc_id, ty, .. } => {
93-
let const_val = ConstValue::Scalar(Scalar::from_pointer(alloc_id.into(), &tcx));
94-
let literal = ConstantKind::Val(const_val, ty);
76+
Constant { span, user_ty: user_ty, literal }
77+
}
78+
ExprKind::NamedConst { def_id, substs, ref user_ty } => {
79+
let user_ty = user_ty.as_ref().map(push_cuta).flatten();
9580

96-
Constant { span, user_ty: None, literal }
97-
}
98-
_ => span_bug!(span, "expression is not a valid constant {:?}", kind),
81+
let uneval = mir::UnevaluatedConst::new(ty::WithOptConstParam::unknown(def_id), substs);
82+
let literal = ConstantKind::Unevaluated(uneval, ty);
83+
84+
Constant { user_ty, span, literal }
85+
}
86+
ExprKind::ConstParam { param, def_id: _ } => {
87+
let const_param = tcx.mk_const(ty::ConstKind::Param(param), expr.ty);
88+
let literal = ConstantKind::Ty(const_param);
89+
90+
Constant { user_ty: None, span, literal }
91+
}
92+
ExprKind::ConstBlock { did: def_id, substs } => {
93+
let uneval = mir::UnevaluatedConst::new(ty::WithOptConstParam::unknown(def_id), substs);
94+
let literal = ConstantKind::Unevaluated(uneval, ty);
95+
96+
Constant { user_ty: None, span, literal }
97+
}
98+
ExprKind::StaticRef { alloc_id, ty, .. } => {
99+
let const_val = ConstValue::Scalar(Scalar::from_pointer(alloc_id.into(), &tcx));
100+
let literal = ConstantKind::Val(const_val, ty);
101+
102+
Constant { span, user_ty: None, literal }
99103
}
104+
_ => span_bug!(span, "expression is not a valid constant {:?}", kind),
100105
}
101106
}
102107

compiler/rustc_mir_build/src/build/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ fn construct_fn<'tcx>(
492492
arguments,
493493
return_ty,
494494
return_ty_span,
495-
span,
495+
span_with_body,
496496
custom_mir_attr,
497497
);
498498
}

0 commit comments

Comments
 (0)