Skip to content

Commit 60cad20

Browse files
authored
Rollup merge of #73949 - wesleywiser:simplify_try_fixes, r=oli-obk
[mir-opt] Fix mis-optimization and other issues with the SimplifyArmIdentity pass This does not yet attempt re-enabling the pass, but it does resolve a number of issues with the pass. r? @oli-obk I believe this closes #73223.
2 parents 9d0ca38 + e16d6a6 commit 60cad20

12 files changed

+1411
-35
lines changed

Diff for: src/librustc_middle/mir/mod.rs

+12
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,18 @@ impl<'tcx> Body<'tcx> {
257257
(&mut self.basic_blocks, &mut self.local_decls)
258258
}
259259

260+
#[inline]
261+
pub fn basic_blocks_local_decls_mut_and_var_debug_info(
262+
&mut self,
263+
) -> (
264+
&mut IndexVec<BasicBlock, BasicBlockData<'tcx>>,
265+
&mut LocalDecls<'tcx>,
266+
&mut Vec<VarDebugInfo<'tcx>>,
267+
) {
268+
self.predecessor_cache.invalidate();
269+
(&mut self.basic_blocks, &mut self.local_decls, &mut self.var_debug_info)
270+
}
271+
260272
/// Returns `true` if a cycle exists in the control-flow graph that is reachable from the
261273
/// `START_BLOCK`.
262274
pub fn is_cfg_cyclic(&self) -> bool {

Diff for: src/librustc_mir/transform/simplify_try.rs

+106-9
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@
1111
1212
use crate::transform::{simplify, MirPass, MirSource};
1313
use itertools::Itertools as _;
14-
use rustc_index::vec::IndexVec;
14+
use rustc_index::{bit_set::BitSet, vec::IndexVec};
15+
use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor};
1516
use rustc_middle::mir::*;
16-
use rustc_middle::ty::{Ty, TyCtxt};
17+
use rustc_middle::ty::{List, Ty, TyCtxt};
1718
use rustc_target::abi::VariantIdx;
1819
use std::iter::{Enumerate, Peekable};
1920
use std::slice::Iter;
@@ -73,9 +74,20 @@ struct ArmIdentityInfo<'tcx> {
7374

7475
/// The statements that should be removed (turned into nops)
7576
stmts_to_remove: Vec<usize>,
77+
78+
/// Indices of debug variables that need to be adjusted to point to
79+
// `{local_0}.{dbg_projection}`.
80+
dbg_info_to_adjust: Vec<usize>,
81+
82+
/// The projection used to rewrite debug info.
83+
dbg_projection: &'tcx List<PlaceElem<'tcx>>,
7684
}
7785

78-
fn get_arm_identity_info<'a, 'tcx>(stmts: &'a [Statement<'tcx>]) -> Option<ArmIdentityInfo<'tcx>> {
86+
fn get_arm_identity_info<'a, 'tcx>(
87+
stmts: &'a [Statement<'tcx>],
88+
locals_count: usize,
89+
debug_info: &'a [VarDebugInfo<'tcx>],
90+
) -> Option<ArmIdentityInfo<'tcx>> {
7991
// This can't possibly match unless there are at least 3 statements in the block
8092
// so fail fast on tiny blocks.
8193
if stmts.len() < 3 {
@@ -187,7 +199,7 @@ fn get_arm_identity_info<'a, 'tcx>(stmts: &'a [Statement<'tcx>]) -> Option<ArmId
187199
try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
188200

189201
let (get_variant_field_stmt, stmt) = stmt_iter.next()?;
190-
let (local_tmp_s0, local_1, vf_s0) = match_get_variant_field(stmt)?;
202+
let (local_tmp_s0, local_1, vf_s0, dbg_projection) = match_get_variant_field(stmt)?;
191203

192204
try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
193205

@@ -228,6 +240,19 @@ fn get_arm_identity_info<'a, 'tcx>(stmts: &'a [Statement<'tcx>]) -> Option<ArmId
228240
let stmt_to_overwrite =
229241
nop_stmts.iter().find(|stmt_idx| live_idx < **stmt_idx && **stmt_idx < dead_idx);
230242

243+
let mut tmp_assigned_vars = BitSet::new_empty(locals_count);
244+
for (l, r) in &tmp_assigns {
245+
tmp_assigned_vars.insert(*l);
246+
tmp_assigned_vars.insert(*r);
247+
}
248+
249+
let mut dbg_info_to_adjust = Vec::new();
250+
for (i, var_info) in debug_info.iter().enumerate() {
251+
if tmp_assigned_vars.contains(var_info.place.local) {
252+
dbg_info_to_adjust.push(i);
253+
}
254+
}
255+
231256
Some(ArmIdentityInfo {
232257
local_temp_0: local_tmp_s0,
233258
local_1,
@@ -243,12 +268,16 @@ fn get_arm_identity_info<'a, 'tcx>(stmts: &'a [Statement<'tcx>]) -> Option<ArmId
243268
source_info: discr_stmt_source_info,
244269
storage_stmts,
245270
stmts_to_remove: nop_stmts,
271+
dbg_info_to_adjust,
272+
dbg_projection,
246273
})
247274
}
248275

249276
fn optimization_applies<'tcx>(
250277
opt_info: &ArmIdentityInfo<'tcx>,
251278
local_decls: &IndexVec<Local, LocalDecl<'tcx>>,
279+
local_uses: &IndexVec<Local, usize>,
280+
var_debug_info: &[VarDebugInfo<'tcx>],
252281
) -> bool {
253282
trace!("testing if optimization applies...");
254283

@@ -273,6 +302,7 @@ fn optimization_applies<'tcx>(
273302
// Verify the assigment chain consists of the form b = a; c = b; d = c; etc...
274303
if opt_info.field_tmp_assignments.is_empty() {
275304
trace!("NO: no assignments found");
305+
return false;
276306
}
277307
let mut last_assigned_to = opt_info.field_tmp_assignments[0].1;
278308
let source_local = last_assigned_to;
@@ -285,6 +315,35 @@ fn optimization_applies<'tcx>(
285315
last_assigned_to = *l;
286316
}
287317

318+
// Check that the first and last used locals are only used twice
319+
// since they are of the form:
320+
//
321+
// ```
322+
// _first = ((_x as Variant).n: ty);
323+
// _n = _first;
324+
// ...
325+
// ((_y as Variant).n: ty) = _n;
326+
// discriminant(_y) = z;
327+
// ```
328+
for (l, r) in &opt_info.field_tmp_assignments {
329+
if local_uses[*l] != 2 {
330+
warn!("NO: FAILED assignment chain local {:?} was used more than twice", l);
331+
return false;
332+
} else if local_uses[*r] != 2 {
333+
warn!("NO: FAILED assignment chain local {:?} was used more than twice", r);
334+
return false;
335+
}
336+
}
337+
338+
// Check that debug info only points to full Locals and not projections.
339+
for dbg_idx in &opt_info.dbg_info_to_adjust {
340+
let dbg_info = &var_debug_info[*dbg_idx];
341+
if !dbg_info.place.projection.is_empty() {
342+
trace!("NO: debug info for {:?} had a projection {:?}", dbg_info.name, dbg_info.place);
343+
return false;
344+
}
345+
}
346+
288347
if source_local != opt_info.local_temp_0 {
289348
trace!(
290349
"NO: start of assignment chain does not match enum variant temp: {:?} != {:?}",
@@ -312,11 +371,15 @@ impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity {
312371
}
313372

314373
trace!("running SimplifyArmIdentity on {:?}", source);
315-
let (basic_blocks, local_decls) = body.basic_blocks_and_local_decls_mut();
374+
let local_uses = LocalUseCounter::get_local_uses(body);
375+
let (basic_blocks, local_decls, debug_info) =
376+
body.basic_blocks_local_decls_mut_and_var_debug_info();
316377
for bb in basic_blocks {
317-
if let Some(opt_info) = get_arm_identity_info(&bb.statements) {
378+
if let Some(opt_info) =
379+
get_arm_identity_info(&bb.statements, local_decls.len(), debug_info)
380+
{
318381
trace!("got opt_info = {:#?}", opt_info);
319-
if !optimization_applies(&opt_info, local_decls) {
382+
if !optimization_applies(&opt_info, local_decls, &local_uses, &debug_info) {
320383
debug!("optimization skipped for {:?}", source);
321384
continue;
322385
}
@@ -352,23 +415,57 @@ impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity {
352415

353416
bb.statements.retain(|stmt| stmt.kind != StatementKind::Nop);
354417

418+
// Fix the debug info to point to the right local
419+
for dbg_index in opt_info.dbg_info_to_adjust {
420+
let dbg_info = &mut debug_info[dbg_index];
421+
assert!(dbg_info.place.projection.is_empty());
422+
dbg_info.place.local = opt_info.local_0;
423+
dbg_info.place.projection = opt_info.dbg_projection;
424+
}
425+
355426
trace!("block is now {:?}", bb.statements);
356427
}
357428
}
358429
}
359430
}
360431

432+
struct LocalUseCounter {
433+
local_uses: IndexVec<Local, usize>,
434+
}
435+
436+
impl LocalUseCounter {
437+
fn get_local_uses<'tcx>(body: &Body<'tcx>) -> IndexVec<Local, usize> {
438+
let mut counter = LocalUseCounter { local_uses: IndexVec::from_elem(0, &body.local_decls) };
439+
counter.visit_body(body);
440+
counter.local_uses
441+
}
442+
}
443+
444+
impl<'tcx> Visitor<'tcx> for LocalUseCounter {
445+
fn visit_local(&mut self, local: &Local, context: PlaceContext, _location: Location) {
446+
if context.is_storage_marker()
447+
|| context == PlaceContext::NonUse(NonUseContext::VarDebugInfo)
448+
{
449+
return;
450+
}
451+
452+
self.local_uses[*local] += 1;
453+
}
454+
}
455+
361456
/// Match on:
362457
/// ```rust
363458
/// _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY);
364459
/// ```
365-
fn match_get_variant_field<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, Local, VarField<'tcx>)> {
460+
fn match_get_variant_field<'tcx>(
461+
stmt: &Statement<'tcx>,
462+
) -> Option<(Local, Local, VarField<'tcx>, &'tcx List<PlaceElem<'tcx>>)> {
366463
match &stmt.kind {
367464
StatementKind::Assign(box (place_into, rvalue_from)) => match rvalue_from {
368465
Rvalue::Use(Operand::Copy(pf) | Operand::Move(pf)) => {
369466
let local_into = place_into.as_local()?;
370467
let (local_from, vf) = match_variant_field_place(*pf)?;
371-
Some((local_into, local_from, vf))
468+
Some((local_into, local_from, vf, pf.projection))
372469
}
373470
_ => None,
374471
},

Diff for: src/test/mir-opt/issue-73223.rs

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
fn main() {
2+
let split = match Some(1) {
3+
Some(v) => v,
4+
None => return,
5+
};
6+
7+
let _prev = Some(split);
8+
assert_eq!(split, 1);
9+
}
10+
11+
// EMIT_MIR_FOR_EACH_BIT_WIDTH
12+
// EMIT_MIR rustc.main.SimplifyArmIdentity.diff
13+
// EMIT_MIR rustc.main.PreCodegen.diff

0 commit comments

Comments
 (0)