Skip to content

Commit 293756c

Browse files
simonvandelSimon Vandel Sillesen
authored and
Simon Vandel Sillesen
committed
Implement 'considered equal' for statements, so that for example _0 = _1 and discriminant(_0) = discriminant(0) are considered equal if 0 is a fieldless variant of an enum
1 parent 009551f commit 293756c

10 files changed

+381
-267
lines changed

src/librustc_mir/transform/simplify_try.rs

+223-36
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use itertools::Itertools as _;
1414
use rustc_index::{bit_set::BitSet, vec::IndexVec};
1515
use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor};
1616
use rustc_middle::mir::*;
17-
use rustc_middle::ty::{List, Ty, TyCtxt};
17+
use rustc_middle::ty::{self, List, Ty, TyCtxt};
1818
use rustc_target::abi::VariantIdx;
1919
use std::iter::{Enumerate, Peekable};
2020
use std::slice::Iter;
@@ -527,52 +527,239 @@ fn match_variant_field_place<'tcx>(place: Place<'tcx>) -> Option<(Local, VarFiel
527527
pub struct SimplifyBranchSame;
528528

529529
impl<'tcx> MirPass<'tcx> for SimplifyBranchSame {
530-
fn run_pass(&self, _: TyCtxt<'tcx>, _: MirSource<'tcx>, body: &mut Body<'tcx>) {
531-
let mut did_remove_blocks = false;
532-
let bbs = body.basic_blocks_mut();
533-
for bb_idx in bbs.indices() {
534-
let targets = match &bbs[bb_idx].terminator().kind {
535-
TerminatorKind::SwitchInt { targets, .. } => targets,
536-
_ => continue,
537-
};
530+
fn run_pass(&self, tcx: TyCtxt<'tcx>, source: MirSource<'tcx>, body: &mut Body<'tcx>) {
531+
trace!("Running SimplifyBranchSame on {:?}", source);
532+
let finder = SimplifyBranchSameOptimizationFinder { body, tcx };
533+
let opts = finder.find();
534+
535+
let did_remove_blocks = opts.len() > 0;
536+
for opt in opts.iter() {
537+
trace!("SUCCESS: Applying optimization {:?}", opt);
538+
// Replace `SwitchInt(..) -> [bb_first, ..];` with a `goto -> bb_first;`.
539+
body.basic_blocks_mut()[opt.bb_to_opt_terminator].terminator_mut().kind =
540+
TerminatorKind::Goto { target: opt.bb_to_goto };
541+
}
542+
543+
if did_remove_blocks {
544+
// We have dead blocks now, so remove those.
545+
simplify::remove_dead_blocks(body);
546+
}
547+
}
548+
}
549+
550+
#[derive(Debug)]
551+
struct SimplifyBranchSameOptimization {
552+
/// All basic blocks are equal so go to this one
553+
bb_to_goto: BasicBlock,
554+
/// Basic block where the terminator can be simplified to a goto
555+
bb_to_opt_terminator: BasicBlock,
556+
}
557+
558+
struct SimplifyBranchSameOptimizationFinder<'a, 'tcx> {
559+
body: &'a Body<'tcx>,
560+
tcx: TyCtxt<'tcx>,
561+
}
538562

539-
let mut iter_bbs_reachable = targets
540-
.iter()
541-
.map(|idx| (*idx, &bbs[*idx]))
542-
.filter(|(_, bb)| {
543-
// Reaching `unreachable` is UB so assume it doesn't happen.
544-
bb.terminator().kind != TerminatorKind::Unreachable
563+
impl<'a, 'tcx> SimplifyBranchSameOptimizationFinder<'a, 'tcx> {
564+
fn find(&self) -> Vec<SimplifyBranchSameOptimization> {
565+
self.body
566+
.basic_blocks()
567+
.iter_enumerated()
568+
.filter_map(|(bb_idx, bb)| {
569+
let (discr_switched_on, targets) = match &bb.terminator().kind {
570+
TerminatorKind::SwitchInt { targets, discr, .. } => (discr, targets),
571+
_ => return None,
572+
};
573+
574+
// find the adt that has its discriminant read
575+
// assuming this must be the last statement of the block
576+
let adt_matched_on = match &bb.statements.last()?.kind {
577+
StatementKind::Assign(box (place, rhs))
578+
if Some(*place) == discr_switched_on.place() =>
579+
{
580+
match rhs {
581+
Rvalue::Discriminant(adt_place) if adt_place.ty(self.body, self.tcx).ty.is_enum() => adt_place,
582+
_ => {
583+
trace!("NO: expected a discriminant read of an enum instead of: {:?}", rhs);
584+
return None;
585+
}
586+
}
587+
}
588+
other => {
589+
trace!("NO: expected an assignment of a discriminant read to a place. Found: {:?}", other);
590+
return None
591+
},
592+
};
593+
594+
let mut iter_bbs_reachable = targets
595+
.iter()
596+
.map(|idx| (*idx, &self.body.basic_blocks()[*idx]))
597+
.filter(|(_, bb)| {
598+
// Reaching `unreachable` is UB so assume it doesn't happen.
599+
bb.terminator().kind != TerminatorKind::Unreachable
545600
// But `asm!(...)` could abort the program,
546601
// so we cannot assume that the `unreachable` terminator itself is reachable.
547602
// FIXME(Centril): use a normalization pass instead of a check.
548603
|| bb.statements.iter().any(|stmt| match stmt.kind {
549604
StatementKind::LlvmInlineAsm(..) => true,
550605
_ => false,
551606
})
552-
})
553-
.peekable();
554-
555-
// We want to `goto -> bb_first`.
556-
let bb_first = iter_bbs_reachable.peek().map(|(idx, _)| *idx).unwrap_or(targets[0]);
557-
558-
// All successor basic blocks should have the exact same form.
559-
let all_successors_equivalent =
560-
iter_bbs_reachable.map(|(_, bb)| bb).tuple_windows().all(|(bb_l, bb_r)| {
561-
bb_l.is_cleanup == bb_r.is_cleanup
562-
&& bb_l.terminator().kind == bb_r.terminator().kind
563-
&& bb_l.statements.iter().eq_by(&bb_r.statements, |x, y| x.kind == y.kind)
564-
});
565-
566-
if all_successors_equivalent {
567-
// Replace `SwitchInt(..) -> [bb_first, ..];` with a `goto -> bb_first;`.
568-
bbs[bb_idx].terminator_mut().kind = TerminatorKind::Goto { target: bb_first };
569-
did_remove_blocks = true;
607+
})
608+
.peekable();
609+
610+
let bb_first = iter_bbs_reachable.peek().map(|(idx, _)| *idx).unwrap_or(targets[0]);
611+
let mut all_successors_equivalent = StatementEquality::TrivialEqual;
612+
613+
// All successor basic blocks must be equal or contain statements that are pairwise considered equal.
614+
for ((bb_l_idx,bb_l), (bb_r_idx,bb_r)) in iter_bbs_reachable.tuple_windows() {
615+
let trivial_checks = bb_l.is_cleanup == bb_r.is_cleanup
616+
&& bb_l.terminator().kind == bb_r.terminator().kind;
617+
let statement_check = || {
618+
bb_l.statements.iter().zip(&bb_r.statements).try_fold(StatementEquality::TrivialEqual, |acc,(l,r)| {
619+
let stmt_equality = self.statement_equality(*adt_matched_on, &l, bb_l_idx, &r, bb_r_idx);
620+
if matches!(stmt_equality, StatementEquality::NotEqual) {
621+
// short circuit
622+
None
623+
} else {
624+
Some(acc.combine(&stmt_equality))
625+
}
626+
})
627+
.unwrap_or(StatementEquality::NotEqual)
628+
};
629+
if !trivial_checks {
630+
all_successors_equivalent = StatementEquality::NotEqual;
631+
break;
632+
}
633+
all_successors_equivalent = all_successors_equivalent.combine(&statement_check());
634+
};
635+
636+
match all_successors_equivalent{
637+
StatementEquality::TrivialEqual => {
638+
// statements are trivially equal, so just take first
639+
trace!("Statements are trivially equal");
640+
Some(SimplifyBranchSameOptimization {
641+
bb_to_goto: bb_first,
642+
bb_to_opt_terminator: bb_idx,
643+
})
644+
}
645+
StatementEquality::ConsideredEqual(bb_to_choose) => {
646+
trace!("Statements are considered equal");
647+
Some(SimplifyBranchSameOptimization {
648+
bb_to_goto: bb_to_choose,
649+
bb_to_opt_terminator: bb_idx,
650+
})
651+
}
652+
StatementEquality::NotEqual => {
653+
trace!("NO: not all successors of basic block {:?} were equivalent", bb_idx);
654+
None
655+
}
656+
}
657+
})
658+
.collect()
659+
}
660+
661+
/// Tests if two statements can be considered equal
662+
///
663+
/// Statements can be trivially equal if the kinds match.
664+
/// But they can also be considered equal in the following case A:
665+
/// ```
666+
/// discriminant(_0) = 0; // bb1
667+
/// _0 = move _1; // bb2
668+
/// ```
669+
/// In this case the two statements are equal iff
670+
/// 1: _0 is an enum where the variant index 0 is fieldless, and
671+
/// 2: bb1 was targeted by a switch where the discriminant of _1 was switched on
672+
fn statement_equality(
673+
&self,
674+
adt_matched_on: Place<'tcx>,
675+
x: &Statement<'tcx>,
676+
x_bb_idx: BasicBlock,
677+
y: &Statement<'tcx>,
678+
y_bb_idx: BasicBlock,
679+
) -> StatementEquality {
680+
let helper = |rhs: &Rvalue<'tcx>,
681+
place: &Box<Place<'tcx>>,
682+
variant_index: &VariantIdx,
683+
side_to_choose| {
684+
let place_type = place.ty(self.body, self.tcx).ty;
685+
let adt = match place_type.kind {
686+
ty::Adt(adt, _) if adt.is_enum() => adt,
687+
_ => return StatementEquality::NotEqual,
688+
};
689+
let variant_is_fieldless = adt.variants[*variant_index].fields.is_empty();
690+
if !variant_is_fieldless {
691+
trace!("NO: variant {:?} was not fieldless", variant_index);
692+
return StatementEquality::NotEqual;
693+
}
694+
695+
match rhs {
696+
Rvalue::Use(operand) if operand.place() == Some(adt_matched_on) => {
697+
StatementEquality::ConsideredEqual(side_to_choose)
698+
}
699+
_ => {
700+
trace!(
701+
"NO: RHS of assignment was {:?}, but expected it to match the adt being matched on in the switch, which is {:?}",
702+
rhs,
703+
adt_matched_on
704+
);
705+
StatementEquality::NotEqual
706+
}
707+
}
708+
};
709+
match (&x.kind, &y.kind) {
710+
// trivial case
711+
(x, y) if x == y => StatementEquality::TrivialEqual,
712+
713+
// check for case A
714+
(
715+
StatementKind::Assign(box (_, rhs)),
716+
StatementKind::SetDiscriminant { place, variant_index },
717+
) => {
718+
// choose basic block of x, as that has the assign
719+
helper(rhs, place, variant_index, x_bb_idx)
720+
}
721+
(
722+
StatementKind::SetDiscriminant { place, variant_index },
723+
StatementKind::Assign(box (_, rhs)),
724+
) => {
725+
// choose basic block of y, as that has the assign
726+
helper(rhs, place, variant_index, y_bb_idx)
727+
}
728+
_ => {
729+
trace!("NO: statements `{:?}` and `{:?}` not considered equal", x, y);
730+
StatementEquality::NotEqual
570731
}
571732
}
733+
}
734+
}
572735

573-
if did_remove_blocks {
574-
// We have dead blocks now, so remove those.
575-
simplify::remove_dead_blocks(body);
736+
#[derive(Copy, Clone, Eq, PartialEq)]
737+
enum StatementEquality {
738+
/// The two statements are trivially equal; same kind
739+
TrivialEqual,
740+
/// The two statements are considered equal, but may be of different kinds. The BasicBlock field is the basic block to jump to when performing the branch-same optimization.
741+
/// For example, `_0 = _1` and `discriminant(_0) = discriminant(0)` are considered equal if 0 is a fieldless variant of an enum. But we don't want to jump to the basic block with the SetDiscriminant, as that is not legal if _1 is not the 0 variant index
742+
ConsideredEqual(BasicBlock),
743+
/// The two statements are not equal
744+
NotEqual,
745+
}
746+
747+
impl StatementEquality {
748+
fn combine(&self, other: &StatementEquality) -> StatementEquality {
749+
use StatementEquality::*;
750+
match (self, other) {
751+
(TrivialEqual, TrivialEqual) => TrivialEqual,
752+
(TrivialEqual, ConsideredEqual(b)) | (ConsideredEqual(b), TrivialEqual) => {
753+
ConsideredEqual(*b)
754+
}
755+
(ConsideredEqual(b1), ConsideredEqual(b2)) => {
756+
if b1 == b2 {
757+
ConsideredEqual(*b1)
758+
} else {
759+
NotEqual
760+
}
761+
}
762+
(_, NotEqual) | (NotEqual, _) => NotEqual,
576763
}
577764
}
578765
}

src/test/mir-opt/simplify-arm.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// compile-flags: -Z mir-opt-level=1
1+
// compile-flags: -Z mir-opt-level=2
22
// EMIT_MIR simplify_arm.id.SimplifyArmIdentity.diff
33
// EMIT_MIR simplify_arm.id.SimplifyBranchSame.diff
44
// EMIT_MIR simplify_arm.id_result.SimplifyArmIdentity.diff

src/test/mir-opt/simplify_arm.id.SimplifyArmIdentity.diff

+11-9
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
let _3: u8; // in scope 0 at $DIR/simplify-arm.rs:11:14: 11:15
99
let mut _4: u8; // in scope 0 at $DIR/simplify-arm.rs:11:25: 11:26
1010
scope 1 {
11-
debug v => _3; // in scope 1 at $DIR/simplify-arm.rs:11:14: 11:15
11+
- debug v => _3; // in scope 1 at $DIR/simplify-arm.rs:11:14: 11:15
12+
+ debug v => ((_0 as Some).0: u8); // in scope 1 at $DIR/simplify-arm.rs:11:14: 11:15
1213
}
1314

1415
bb0: {
@@ -26,14 +27,15 @@
2627
}
2728

2829
bb3: {
29-
StorageLive(_3); // scope 0 at $DIR/simplify-arm.rs:11:14: 11:15
30-
_3 = ((_1 as Some).0: u8); // scope 0 at $DIR/simplify-arm.rs:11:14: 11:15
31-
StorageLive(_4); // scope 1 at $DIR/simplify-arm.rs:11:25: 11:26
32-
_4 = _3; // scope 1 at $DIR/simplify-arm.rs:11:25: 11:26
33-
((_0 as Some).0: u8) = move _4; // scope 1 at $DIR/simplify-arm.rs:11:20: 11:27
34-
discriminant(_0) = 1; // scope 1 at $DIR/simplify-arm.rs:11:20: 11:27
35-
StorageDead(_4); // scope 1 at $DIR/simplify-arm.rs:11:26: 11:27
36-
StorageDead(_3); // scope 0 at $DIR/simplify-arm.rs:11:26: 11:27
30+
- StorageLive(_3); // scope 0 at $DIR/simplify-arm.rs:11:14: 11:15
31+
- _3 = ((_1 as Some).0: u8); // scope 0 at $DIR/simplify-arm.rs:11:14: 11:15
32+
- StorageLive(_4); // scope 1 at $DIR/simplify-arm.rs:11:25: 11:26
33+
- _4 = _3; // scope 1 at $DIR/simplify-arm.rs:11:25: 11:26
34+
- ((_0 as Some).0: u8) = move _4; // scope 1 at $DIR/simplify-arm.rs:11:20: 11:27
35+
- discriminant(_0) = 1; // scope 1 at $DIR/simplify-arm.rs:11:20: 11:27
36+
- StorageDead(_4); // scope 1 at $DIR/simplify-arm.rs:11:26: 11:27
37+
- StorageDead(_3); // scope 0 at $DIR/simplify-arm.rs:11:26: 11:27
38+
+ _0 = move _1; // scope 1 at $DIR/simplify-arm.rs:11:20: 11:27
3739
goto -> bb4; // scope 0 at $DIR/simplify-arm.rs:10:5: 13:6
3840
}
3941

src/test/mir-opt/simplify_arm.id.SimplifyBranchSame.diff

+17-21
Original file line numberDiff line numberDiff line change
@@ -8,36 +8,32 @@
88
let _3: u8; // in scope 0 at $DIR/simplify-arm.rs:11:14: 11:15
99
let mut _4: u8; // in scope 0 at $DIR/simplify-arm.rs:11:25: 11:26
1010
scope 1 {
11-
debug v => _3; // in scope 1 at $DIR/simplify-arm.rs:11:14: 11:15
11+
debug v => ((_0 as Some).0: u8); // in scope 1 at $DIR/simplify-arm.rs:11:14: 11:15
1212
}
1313

1414
bb0: {
1515
_2 = discriminant(_1); // scope 0 at $DIR/simplify-arm.rs:11:9: 11:16
16-
switchInt(move _2) -> [0_isize: bb1, 1_isize: bb3, otherwise: bb2]; // scope 0 at $DIR/simplify-arm.rs:11:9: 11:16
16+
- switchInt(move _2) -> [0_isize: bb1, 1_isize: bb3, otherwise: bb2]; // scope 0 at $DIR/simplify-arm.rs:11:9: 11:16
17+
+ goto -> bb1; // scope 0 at $DIR/simplify-arm.rs:11:9: 11:16
1718
}
1819

1920
bb1: {
20-
discriminant(_0) = 0; // scope 0 at $DIR/simplify-arm.rs:12:17: 12:21
21-
goto -> bb4; // scope 0 at $DIR/simplify-arm.rs:10:5: 13:6
21+
- discriminant(_0) = 0; // scope 0 at $DIR/simplify-arm.rs:12:17: 12:21
22+
- goto -> bb4; // scope 0 at $DIR/simplify-arm.rs:10:5: 13:6
23+
- }
24+
-
25+
- bb2: {
26+
- unreachable; // scope 0 at $DIR/simplify-arm.rs:10:11: 10:12
27+
- }
28+
-
29+
- bb3: {
30+
_0 = move _1; // scope 1 at $DIR/simplify-arm.rs:11:20: 11:27
31+
- goto -> bb4; // scope 0 at $DIR/simplify-arm.rs:10:5: 13:6
32+
+ goto -> bb2; // scope 0 at $DIR/simplify-arm.rs:10:5: 13:6
2233
}
2334

24-
bb2: {
25-
unreachable; // scope 0 at $DIR/simplify-arm.rs:10:11: 10:12
26-
}
27-
28-
bb3: {
29-
StorageLive(_3); // scope 0 at $DIR/simplify-arm.rs:11:14: 11:15
30-
_3 = ((_1 as Some).0: u8); // scope 0 at $DIR/simplify-arm.rs:11:14: 11:15
31-
StorageLive(_4); // scope 1 at $DIR/simplify-arm.rs:11:25: 11:26
32-
_4 = _3; // scope 1 at $DIR/simplify-arm.rs:11:25: 11:26
33-
((_0 as Some).0: u8) = move _4; // scope 1 at $DIR/simplify-arm.rs:11:20: 11:27
34-
discriminant(_0) = 1; // scope 1 at $DIR/simplify-arm.rs:11:20: 11:27
35-
StorageDead(_4); // scope 1 at $DIR/simplify-arm.rs:11:26: 11:27
36-
StorageDead(_3); // scope 0 at $DIR/simplify-arm.rs:11:26: 11:27
37-
goto -> bb4; // scope 0 at $DIR/simplify-arm.rs:10:5: 13:6
38-
}
39-
40-
bb4: {
35+
- bb4: {
36+
+ bb2: {
4137
return; // scope 0 at $DIR/simplify-arm.rs:14:2: 14:2
4238
}
4339
}

0 commit comments

Comments
 (0)