diff --git a/compiler/rustc_mir/src/transform/simplify_try.rs b/compiler/rustc_mir/src/transform/simplify_try.rs index 9a11d92724020..bf6fdebb3d3c7 100644 --- a/compiler/rustc_mir/src/transform/simplify_try.rs +++ b/compiler/rustc_mir/src/transform/simplify_try.rs @@ -16,7 +16,7 @@ use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor}; use rustc_middle::mir::*; use rustc_middle::ty::{self, List, Ty, TyCtxt}; use rustc_target::abi::VariantIdx; -use std::iter::{Enumerate, Peekable}; +use std::iter::{once, Enumerate, Peekable}; use std::slice::Iter; /// Simplifies arms of form `Variant(x) => Variant(x)` to just a move. @@ -551,6 +551,12 @@ struct SimplifyBranchSameOptimization { bb_to_opt_terminator: BasicBlock, } +struct SwitchTargetAndValue { + target: BasicBlock, + // None in case of the `otherwise` case + value: Option, +} + struct SimplifyBranchSameOptimizationFinder<'a, 'tcx> { body: &'a Body<'tcx>, tcx: TyCtxt<'tcx>, @@ -562,8 +568,15 @@ impl<'a, 'tcx> SimplifyBranchSameOptimizationFinder<'a, 'tcx> { .basic_blocks() .iter_enumerated() .filter_map(|(bb_idx, bb)| { - let (discr_switched_on, targets) = match &bb.terminator().kind { - TerminatorKind::SwitchInt { targets, discr, .. } => (discr, targets), + let (discr_switched_on, targets_and_values):(_, Vec<_>) = match &bb.terminator().kind { + TerminatorKind::SwitchInt { targets, discr, values, .. } => { + // if values.len() == targets.len() - 1, we need to include None where no value is present + // such that the zip does not throw away targets. If no `otherwise` case is in targets, the zip will simply throw away the added None + let values_extended = values.iter().map(|x|Some(*x)).chain(once(None)); + let targets_and_values = targets.iter().zip(values_extended) + .map(|(target, value)| SwitchTargetAndValue{target:*target, value:value}) + .collect(); + (discr, targets_and_values)}, _ => return None, }; @@ -587,9 +600,9 @@ impl<'a, 'tcx> SimplifyBranchSameOptimizationFinder<'a, 'tcx> { }, }; - let mut iter_bbs_reachable = targets + let mut iter_bbs_reachable = targets_and_values .iter() - .map(|idx| (*idx, &self.body.basic_blocks()[*idx])) + .map(|target_and_value| (target_and_value, &self.body.basic_blocks()[target_and_value.target])) .filter(|(_, bb)| { // Reaching `unreachable` is UB so assume it doesn't happen. bb.terminator().kind != TerminatorKind::Unreachable @@ -603,16 +616,16 @@ impl<'a, 'tcx> SimplifyBranchSameOptimizationFinder<'a, 'tcx> { }) .peekable(); - let bb_first = iter_bbs_reachable.peek().map(|(idx, _)| *idx).unwrap_or(targets[0]); + let bb_first = iter_bbs_reachable.peek().map(|(idx, _)| *idx).unwrap_or(&targets_and_values[0]); let mut all_successors_equivalent = StatementEquality::TrivialEqual; // All successor basic blocks must be equal or contain statements that are pairwise considered equal. - for ((bb_l_idx,bb_l), (bb_r_idx,bb_r)) in iter_bbs_reachable.tuple_windows() { + for ((target_and_value_l,bb_l), (target_and_value_r,bb_r)) in iter_bbs_reachable.tuple_windows() { let trivial_checks = bb_l.is_cleanup == bb_r.is_cleanup && bb_l.terminator().kind == bb_r.terminator().kind; let statement_check = || { bb_l.statements.iter().zip(&bb_r.statements).try_fold(StatementEquality::TrivialEqual, |acc,(l,r)| { - let stmt_equality = self.statement_equality(*adt_matched_on, &l, bb_l_idx, &r, bb_r_idx, self.tcx.sess.opts.debugging_opts.mir_opt_level); + let stmt_equality = self.statement_equality(*adt_matched_on, &l, target_and_value_l, &r, target_and_value_r); if matches!(stmt_equality, StatementEquality::NotEqual) { // short circuit None @@ -634,7 +647,7 @@ impl<'a, 'tcx> SimplifyBranchSameOptimizationFinder<'a, 'tcx> { // statements are trivially equal, so just take first trace!("Statements are trivially equal"); Some(SimplifyBranchSameOptimization { - bb_to_goto: bb_first, + bb_to_goto: bb_first.target, bb_to_opt_terminator: bb_idx, }) } @@ -669,10 +682,9 @@ impl<'a, 'tcx> SimplifyBranchSameOptimizationFinder<'a, 'tcx> { &self, adt_matched_on: Place<'tcx>, x: &Statement<'tcx>, - x_bb_idx: BasicBlock, + x_target_and_value: &SwitchTargetAndValue, y: &Statement<'tcx>, - y_bb_idx: BasicBlock, - mir_opt_level: usize, + y_target_and_value: &SwitchTargetAndValue, ) -> StatementEquality { let helper = |rhs: &Rvalue<'tcx>, place: &Place<'tcx>, @@ -691,13 +703,7 @@ impl<'a, 'tcx> SimplifyBranchSameOptimizationFinder<'a, 'tcx> { match rhs { Rvalue::Use(operand) if operand.place() == Some(adt_matched_on) => { - // FIXME(76803): This logic is currently broken because it does not take into - // account the current discriminant value. - if mir_opt_level > 2 { - StatementEquality::ConsideredEqual(side_to_choose) - } else { - StatementEquality::NotEqual - } + StatementEquality::ConsideredEqual(side_to_choose) } _ => { trace!( @@ -717,16 +723,20 @@ impl<'a, 'tcx> SimplifyBranchSameOptimizationFinder<'a, 'tcx> { ( StatementKind::Assign(box (_, rhs)), StatementKind::SetDiscriminant { place, variant_index }, - ) => { + ) + // we need to make sure that the switch value that targets the bb with SetDiscriminant (y), is the same as the variant index + if Some(variant_index.index() as u128) == y_target_and_value.value => { // choose basic block of x, as that has the assign - helper(rhs, place, variant_index, x_bb_idx) + helper(rhs, place, variant_index, x_target_and_value.target) } ( StatementKind::SetDiscriminant { place, variant_index }, StatementKind::Assign(box (_, rhs)), - ) => { + ) + // we need to make sure that the switch value that targets the bb with SetDiscriminant (x), is the same as the variant index + if Some(variant_index.index() as u128) == x_target_and_value.value => { // choose basic block of y, as that has the assign - helper(rhs, place, variant_index, y_bb_idx) + helper(rhs, place, variant_index, y_target_and_value.target) } _ => { trace!("NO: statements `{:?}` and `{:?}` not considered equal", x, y); diff --git a/src/test/mir-opt/76803_regression.encode.SimplifyBranchSame.diff b/src/test/mir-opt/76803_regression.encode.SimplifyBranchSame.diff index 63139285209b6..28b8329606c1b 100644 --- a/src/test/mir-opt/76803_regression.encode.SimplifyBranchSame.diff +++ b/src/test/mir-opt/76803_regression.encode.SimplifyBranchSame.diff @@ -8,22 +8,20 @@ bb0: { _2 = discriminant(_1); // scope 0 at $DIR/76803_regression.rs:12:9: 12:16 -- switchInt(move _2) -> [0_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/76803_regression.rs:12:9: 12:16 -+ goto -> bb1; // scope 0 at $DIR/76803_regression.rs:12:9: 12:16 + switchInt(move _2) -> [0_isize: bb2, otherwise: bb1]; // scope 0 at $DIR/76803_regression.rs:12:9: 12:16 } bb1: { _0 = move _1; // scope 0 at $DIR/76803_regression.rs:13:14: 13:15 -- goto -> bb3; // scope 0 at $DIR/76803_regression.rs:11:5: 14:6 -+ goto -> bb2; // scope 0 at $DIR/76803_regression.rs:11:5: 14:6 + goto -> bb3; // scope 0 at $DIR/76803_regression.rs:11:5: 14:6 } bb2: { -- discriminant(_0) = 1; // scope 0 at $DIR/76803_regression.rs:12:20: 12:27 -- goto -> bb3; // scope 0 at $DIR/76803_regression.rs:11:5: 14:6 -- } -- -- bb3: { + discriminant(_0) = 1; // scope 0 at $DIR/76803_regression.rs:12:20: 12:27 + goto -> bb3; // scope 0 at $DIR/76803_regression.rs:11:5: 14:6 + } + + bb3: { return; // scope 0 at $DIR/76803_regression.rs:15:2: 15:2 } } diff --git a/src/test/mir-opt/simplify_arm.id.SimplifyBranchSame.diff b/src/test/mir-opt/simplify_arm.id.SimplifyBranchSame.diff index 06f359da2e70d..81a0e6ba0b4ee 100644 --- a/src/test/mir-opt/simplify_arm.id.SimplifyBranchSame.diff +++ b/src/test/mir-opt/simplify_arm.id.SimplifyBranchSame.diff @@ -13,24 +13,27 @@ bb0: { _2 = discriminant(_1); // scope 0 at $DIR/simplify-arm.rs:11:9: 11:16 - switchInt(move _2) -> [0_isize: bb1, 1_isize: bb3, otherwise: bb2]; // scope 0 at $DIR/simplify-arm.rs:11:9: 11:16 +- switchInt(move _2) -> [0_isize: bb1, 1_isize: bb3, otherwise: bb2]; // scope 0 at $DIR/simplify-arm.rs:11:9: 11:16 ++ goto -> bb1; // scope 0 at $DIR/simplify-arm.rs:11:9: 11:16 } bb1: { - discriminant(_0) = 0; // scope 0 at $DIR/simplify-arm.rs:12:17: 12:21 - goto -> bb4; // scope 0 at $DIR/simplify-arm.rs:10:5: 13:6 - } - - bb2: { - unreachable; // scope 0 at $DIR/simplify-arm.rs:10:11: 10:12 - } - - bb3: { +- discriminant(_0) = 0; // scope 0 at $DIR/simplify-arm.rs:12:17: 12:21 +- goto -> bb4; // scope 0 at $DIR/simplify-arm.rs:10:5: 13:6 +- } +- +- bb2: { +- unreachable; // scope 0 at $DIR/simplify-arm.rs:10:11: 10:12 +- } +- +- bb3: { _0 = move _1; // scope 1 at $DIR/simplify-arm.rs:11:20: 11:27 - goto -> bb4; // scope 0 at $DIR/simplify-arm.rs:10:5: 13:6 +- goto -> bb4; // scope 0 at $DIR/simplify-arm.rs:10:5: 13:6 ++ goto -> bb2; // scope 0 at $DIR/simplify-arm.rs:10:5: 13:6 } - bb4: { +- bb4: { ++ bb2: { return; // scope 0 at $DIR/simplify-arm.rs:14:2: 14:2 } }