@@ -14,7 +14,7 @@ use itertools::Itertools as _;
14
14
use rustc_index:: { bit_set:: BitSet , vec:: IndexVec } ;
15
15
use rustc_middle:: mir:: visit:: { NonUseContext , PlaceContext , Visitor } ;
16
16
use rustc_middle:: mir:: * ;
17
- use rustc_middle:: ty:: { List , Ty , TyCtxt } ;
17
+ use rustc_middle:: ty:: { self , List , Ty , TyCtxt } ;
18
18
use rustc_target:: abi:: VariantIdx ;
19
19
use std:: iter:: { Enumerate , Peekable } ;
20
20
use std:: slice:: Iter ;
@@ -527,52 +527,239 @@ fn match_variant_field_place<'tcx>(place: Place<'tcx>) -> Option<(Local, VarFiel
527
527
pub struct SimplifyBranchSame ;
528
528
529
529
impl < ' tcx > MirPass < ' tcx > for SimplifyBranchSame {
530
- fn run_pass ( & self , _: TyCtxt < ' tcx > , _: MirSource < ' tcx > , body : & mut Body < ' tcx > ) {
531
- let mut did_remove_blocks = false ;
532
- let bbs = body. basic_blocks_mut ( ) ;
533
- for bb_idx in bbs. indices ( ) {
534
- let targets = match & bbs[ bb_idx] . terminator ( ) . kind {
535
- TerminatorKind :: SwitchInt { targets, .. } => targets,
536
- _ => continue ,
537
- } ;
530
+ fn run_pass ( & self , tcx : TyCtxt < ' tcx > , source : MirSource < ' tcx > , body : & mut Body < ' tcx > ) {
531
+ trace ! ( "Running SimplifyBranchSame on {:?}" , source) ;
532
+ let finder = SimplifyBranchSameOptimizationFinder { body, tcx } ;
533
+ let opts = finder. find ( ) ;
534
+
535
+ let did_remove_blocks = opts. len ( ) > 0 ;
536
+ for opt in opts. iter ( ) {
537
+ trace ! ( "SUCCESS: Applying optimization {:?}" , opt) ;
538
+ // Replace `SwitchInt(..) -> [bb_first, ..];` with a `goto -> bb_first;`.
539
+ body. basic_blocks_mut ( ) [ opt. bb_to_opt_terminator ] . terminator_mut ( ) . kind =
540
+ TerminatorKind :: Goto { target : opt. bb_to_goto } ;
541
+ }
542
+
543
+ if did_remove_blocks {
544
+ // We have dead blocks now, so remove those.
545
+ simplify:: remove_dead_blocks ( body) ;
546
+ }
547
+ }
548
+ }
549
+
550
+ #[ derive( Debug ) ]
551
+ struct SimplifyBranchSameOptimization {
552
+ /// All basic blocks are equal so go to this one
553
+ bb_to_goto : BasicBlock ,
554
+ /// Basic block where the terminator can be simplified to a goto
555
+ bb_to_opt_terminator : BasicBlock ,
556
+ }
557
+
558
+ struct SimplifyBranchSameOptimizationFinder < ' a , ' tcx > {
559
+ body : & ' a Body < ' tcx > ,
560
+ tcx : TyCtxt < ' tcx > ,
561
+ }
538
562
539
- let mut iter_bbs_reachable = targets
540
- . iter ( )
541
- . map ( |idx| ( * idx, & bbs[ * idx] ) )
542
- . filter ( |( _, bb) | {
543
- // Reaching `unreachable` is UB so assume it doesn't happen.
544
- bb. terminator ( ) . kind != TerminatorKind :: Unreachable
563
+ impl < ' a , ' tcx > SimplifyBranchSameOptimizationFinder < ' a , ' tcx > {
564
+ fn find ( & self ) -> Vec < SimplifyBranchSameOptimization > {
565
+ self . body
566
+ . basic_blocks ( )
567
+ . iter_enumerated ( )
568
+ . filter_map ( |( bb_idx, bb) | {
569
+ let ( discr_switched_on, targets) = match & bb. terminator ( ) . kind {
570
+ TerminatorKind :: SwitchInt { targets, discr, .. } => ( discr, targets) ,
571
+ _ => return None ,
572
+ } ;
573
+
574
+ // find the adt that has its discriminant read
575
+ // assuming this must be the last statement of the block
576
+ let adt_matched_on = match & bb. statements . last ( ) ?. kind {
577
+ StatementKind :: Assign ( box ( place, rhs) )
578
+ if Some ( * place) == discr_switched_on. place ( ) =>
579
+ {
580
+ match rhs {
581
+ Rvalue :: Discriminant ( adt_place) if adt_place. ty ( self . body , self . tcx ) . ty . is_enum ( ) => adt_place,
582
+ _ => {
583
+ trace ! ( "NO: expected a discriminant read of an enum instead of: {:?}" , rhs) ;
584
+ return None ;
585
+ }
586
+ }
587
+ }
588
+ other => {
589
+ trace ! ( "NO: expected an assignment of a discriminant read to a place. Found: {:?}" , other) ;
590
+ return None
591
+ } ,
592
+ } ;
593
+
594
+ let mut iter_bbs_reachable = targets
595
+ . iter ( )
596
+ . map ( |idx| ( * idx, & self . body . basic_blocks ( ) [ * idx] ) )
597
+ . filter ( |( _, bb) | {
598
+ // Reaching `unreachable` is UB so assume it doesn't happen.
599
+ bb. terminator ( ) . kind != TerminatorKind :: Unreachable
545
600
// But `asm!(...)` could abort the program,
546
601
// so we cannot assume that the `unreachable` terminator itself is reachable.
547
602
// FIXME(Centril): use a normalization pass instead of a check.
548
603
|| bb. statements . iter ( ) . any ( |stmt| match stmt. kind {
549
604
StatementKind :: LlvmInlineAsm ( ..) => true ,
550
605
_ => false ,
551
606
} )
552
- } )
553
- . peekable ( ) ;
554
-
555
- // We want to `goto -> bb_first`.
556
- let bb_first = iter_bbs_reachable. peek ( ) . map ( |( idx, _) | * idx) . unwrap_or ( targets[ 0 ] ) ;
557
-
558
- // All successor basic blocks should have the exact same form.
559
- let all_successors_equivalent =
560
- iter_bbs_reachable. map ( |( _, bb) | bb) . tuple_windows ( ) . all ( |( bb_l, bb_r) | {
561
- bb_l. is_cleanup == bb_r. is_cleanup
562
- && bb_l. terminator ( ) . kind == bb_r. terminator ( ) . kind
563
- && bb_l. statements . iter ( ) . eq_by ( & bb_r. statements , |x, y| x. kind == y. kind )
564
- } ) ;
565
-
566
- if all_successors_equivalent {
567
- // Replace `SwitchInt(..) -> [bb_first, ..];` with a `goto -> bb_first;`.
568
- bbs[ bb_idx] . terminator_mut ( ) . kind = TerminatorKind :: Goto { target : bb_first } ;
569
- did_remove_blocks = true ;
607
+ } )
608
+ . peekable ( ) ;
609
+
610
+ let bb_first = iter_bbs_reachable. peek ( ) . map ( |( idx, _) | * idx) . unwrap_or ( targets[ 0 ] ) ;
611
+ let mut all_successors_equivalent = StatementEquality :: TrivialEqual ;
612
+
613
+ // All successor basic blocks must be equal or contain statements that are pairwise considered equal.
614
+ for ( ( bb_l_idx, bb_l) , ( bb_r_idx, bb_r) ) in iter_bbs_reachable. tuple_windows ( ) {
615
+ let trivial_checks = bb_l. is_cleanup == bb_r. is_cleanup
616
+ && bb_l. terminator ( ) . kind == bb_r. terminator ( ) . kind ;
617
+ let statement_check = || {
618
+ bb_l. statements . iter ( ) . zip ( & bb_r. statements ) . try_fold ( StatementEquality :: TrivialEqual , |acc, ( l, r) | {
619
+ let stmt_equality = self . statement_equality ( * adt_matched_on, & l, bb_l_idx, & r, bb_r_idx) ;
620
+ if matches ! ( stmt_equality, StatementEquality :: NotEqual ) {
621
+ // short circuit
622
+ None
623
+ } else {
624
+ Some ( acc. combine ( & stmt_equality) )
625
+ }
626
+ } )
627
+ . unwrap_or ( StatementEquality :: NotEqual )
628
+ } ;
629
+ if !trivial_checks {
630
+ all_successors_equivalent = StatementEquality :: NotEqual ;
631
+ break ;
632
+ }
633
+ all_successors_equivalent = all_successors_equivalent. combine ( & statement_check ( ) ) ;
634
+ } ;
635
+
636
+ match all_successors_equivalent{
637
+ StatementEquality :: TrivialEqual => {
638
+ // statements are trivially equal, so just take first
639
+ trace ! ( "Statements are trivially equal" ) ;
640
+ Some ( SimplifyBranchSameOptimization {
641
+ bb_to_goto : bb_first,
642
+ bb_to_opt_terminator : bb_idx,
643
+ } )
644
+ }
645
+ StatementEquality :: ConsideredEqual ( bb_to_choose) => {
646
+ trace ! ( "Statements are considered equal" ) ;
647
+ Some ( SimplifyBranchSameOptimization {
648
+ bb_to_goto : bb_to_choose,
649
+ bb_to_opt_terminator : bb_idx,
650
+ } )
651
+ }
652
+ StatementEquality :: NotEqual => {
653
+ trace ! ( "NO: not all successors of basic block {:?} were equivalent" , bb_idx) ;
654
+ None
655
+ }
656
+ }
657
+ } )
658
+ . collect ( )
659
+ }
660
+
661
+ /// Tests if two statements can be considered equal
662
+ ///
663
+ /// Statements can be trivially equal if the kinds match.
664
+ /// But they can also be considered equal in the following case A:
665
+ /// ```
666
+ /// discriminant(_0) = 0; // bb1
667
+ /// _0 = move _1; // bb2
668
+ /// ```
669
+ /// In this case the two statements are equal iff
670
+ /// 1: _0 is an enum where the variant index 0 is fieldless, and
671
+ /// 2: bb1 was targeted by a switch where the discriminant of _1 was switched on
672
+ fn statement_equality (
673
+ & self ,
674
+ adt_matched_on : Place < ' tcx > ,
675
+ x : & Statement < ' tcx > ,
676
+ x_bb_idx : BasicBlock ,
677
+ y : & Statement < ' tcx > ,
678
+ y_bb_idx : BasicBlock ,
679
+ ) -> StatementEquality {
680
+ let helper = |rhs : & Rvalue < ' tcx > ,
681
+ place : & Box < Place < ' tcx > > ,
682
+ variant_index : & VariantIdx ,
683
+ side_to_choose| {
684
+ let place_type = place. ty ( self . body , self . tcx ) . ty ;
685
+ let adt = match place_type. kind {
686
+ ty:: Adt ( adt, _) if adt. is_enum ( ) => adt,
687
+ _ => return StatementEquality :: NotEqual ,
688
+ } ;
689
+ let variant_is_fieldless = adt. variants [ * variant_index] . fields . is_empty ( ) ;
690
+ if !variant_is_fieldless {
691
+ trace ! ( "NO: variant {:?} was not fieldless" , variant_index) ;
692
+ return StatementEquality :: NotEqual ;
693
+ }
694
+
695
+ match rhs {
696
+ Rvalue :: Use ( operand) if operand. place ( ) == Some ( adt_matched_on) => {
697
+ StatementEquality :: ConsideredEqual ( side_to_choose)
698
+ }
699
+ _ => {
700
+ trace ! (
701
+ "NO: RHS of assignment was {:?}, but expected it to match the adt being matched on in the switch, which is {:?}" ,
702
+ rhs,
703
+ adt_matched_on
704
+ ) ;
705
+ StatementEquality :: NotEqual
706
+ }
707
+ }
708
+ } ;
709
+ match ( & x. kind , & y. kind ) {
710
+ // trivial case
711
+ ( x, y) if x == y => StatementEquality :: TrivialEqual ,
712
+
713
+ // check for case A
714
+ (
715
+ StatementKind :: Assign ( box ( _, rhs) ) ,
716
+ StatementKind :: SetDiscriminant { place, variant_index } ,
717
+ ) => {
718
+ // choose basic block of x, as that has the assign
719
+ helper ( rhs, place, variant_index, x_bb_idx)
720
+ }
721
+ (
722
+ StatementKind :: SetDiscriminant { place, variant_index } ,
723
+ StatementKind :: Assign ( box ( _, rhs) ) ,
724
+ ) => {
725
+ // choose basic block of y, as that has the assign
726
+ helper ( rhs, place, variant_index, y_bb_idx)
727
+ }
728
+ _ => {
729
+ trace ! ( "NO: statements `{:?}` and `{:?}` not considered equal" , x, y) ;
730
+ StatementEquality :: NotEqual
570
731
}
571
732
}
733
+ }
734
+ }
572
735
573
- if did_remove_blocks {
574
- // We have dead blocks now, so remove those.
575
- simplify:: remove_dead_blocks ( body) ;
736
+ #[ derive( Copy , Clone , Eq , PartialEq ) ]
737
+ enum StatementEquality {
738
+ /// The two statements are trivially equal; same kind
739
+ TrivialEqual ,
740
+ /// The two statements are considered equal, but may be of different kinds. The BasicBlock field is the basic block to jump to when performing the branch-same optimization.
741
+ /// For example, `_0 = _1` and `discriminant(_0) = discriminant(0)` are considered equal if 0 is a fieldless variant of an enum. But we don't want to jump to the basic block with the SetDiscriminant, as that is not legal if _1 is not the 0 variant index
742
+ ConsideredEqual ( BasicBlock ) ,
743
+ /// The two statements are not equal
744
+ NotEqual ,
745
+ }
746
+
747
+ impl StatementEquality {
748
+ fn combine ( & self , other : & StatementEquality ) -> StatementEquality {
749
+ use StatementEquality :: * ;
750
+ match ( self , other) {
751
+ ( TrivialEqual , TrivialEqual ) => TrivialEqual ,
752
+ ( TrivialEqual , ConsideredEqual ( b) ) | ( ConsideredEqual ( b) , TrivialEqual ) => {
753
+ ConsideredEqual ( * b)
754
+ }
755
+ ( ConsideredEqual ( b1) , ConsideredEqual ( b2) ) => {
756
+ if b1 == b2 {
757
+ ConsideredEqual ( * b1)
758
+ } else {
759
+ NotEqual
760
+ }
761
+ }
762
+ ( _, NotEqual ) | ( NotEqual , _) => NotEqual ,
576
763
}
577
764
}
578
765
}
0 commit comments