diff --git a/compiler/rustc_mir_transform/src/early_otherwise_branch.rs b/compiler/rustc_mir_transform/src/early_otherwise_branch.rs index aed7f20aaea0b..704ed508b22a8 100644 --- a/compiler/rustc_mir_transform/src/early_otherwise_branch.rs +++ b/compiler/rustc_mir_transform/src/early_otherwise_branch.rs @@ -133,29 +133,18 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch { let mut patch = MirPatch::new(body); - let (second_discriminant_temp, second_operand) = if opt_data.need_hoist_discriminant { - // create temp to store second discriminant in, `_s` in example above - let second_discriminant_temp = - patch.new_temp(opt_data.child_ty, opt_data.child_source.span); + // create temp to store second discriminant in, `_s` in example above + let second_discriminant_temp = + patch.new_temp(opt_data.child_ty, opt_data.child_source.span); - patch.add_statement( - parent_end, - StatementKind::StorageLive(second_discriminant_temp), - ); + patch.add_statement(parent_end, StatementKind::StorageLive(second_discriminant_temp)); - // create assignment of discriminant - patch.add_assign( - parent_end, - Place::from(second_discriminant_temp), - Rvalue::Discriminant(opt_data.child_place), - ); - ( - Some(second_discriminant_temp), - Operand::Move(Place::from(second_discriminant_temp)), - ) - } else { - (None, Operand::Copy(opt_data.child_place)) - }; + // create assignment of discriminant + patch.add_assign( + parent_end, + Place::from(second_discriminant_temp), + Rvalue::Discriminant(opt_data.child_place), + ); // create temp to store inequality comparison between the two discriminants, `_t` in // example above @@ -164,9 +153,11 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch { let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span); patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp)); - // create inequality comparison - let comp_rvalue = - Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand))); + // create inequality comparison between the two discriminants + let comp_rvalue = Rvalue::BinaryOp( + nequal, + Box::new((parent_op.clone(), Operand::Move(Place::from(second_discriminant_temp)))), + ); patch.add_statement( parent_end, StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))), @@ -202,13 +193,8 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch { TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case), ); - if let Some(second_discriminant_temp) = second_discriminant_temp { - // generate StorageDead for the second_discriminant_temp not in use anymore - patch.add_statement( - parent_end, - StatementKind::StorageDead(second_discriminant_temp), - ); - } + // generate StorageDead for the second_discriminant_temp not in use anymore + patch.add_statement(parent_end, StatementKind::StorageDead(second_discriminant_temp)); // Generate a StorageDead for comp_temp in each of the targets, since we moved it into // the switch @@ -236,7 +222,6 @@ struct OptimizationData<'tcx> { child_place: Place<'tcx>, child_ty: Ty<'tcx>, child_source: SourceInfo, - need_hoist_discriminant: bool, } fn evaluate_candidate<'tcx>( @@ -250,12 +235,44 @@ fn evaluate_candidate<'tcx>( return None; }; let parent_ty = parent_discr.ty(body.local_decls(), tcx); + if !bbs[targets.otherwise()].is_empty_unreachable() { + // Someone could write code like this: + // ```rust + // let Q = val; + // if discriminant(P) == otherwise { + // let ptr = &mut Q as *mut _ as *mut u8; + // // It may be difficult for us to effectively determine whether values are valid. + // // Invalid values can come from all sorts of corners. + // unsafe { *ptr = 10; } + // } + // + // match P { + // A => match Q { + // A => { + // // code + // } + // _ => { + // // don't use Q + // } + // } + // _ => { + // // don't use Q + // } + // }; + // ``` + // + // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant + // of an invalid value, which is UB. + // In order to fix this, **we would either need to show that the discriminant computation of + // `place` is computed in all branches**. + // FIXME(#95162) For the moment, we adopt a conservative approach and + // consider only the `otherwise` branch has no statements and an unreachable terminator. + return None; + } let (_, child) = targets.iter().next()?; - - let Terminator { - kind: TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr }, - source_info, - } = bbs[child].terminator() + let child_terminator = &bbs[child].terminator(); + let TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr } = + &child_terminator.kind else { return None; }; @@ -263,115 +280,25 @@ fn evaluate_candidate<'tcx>( if child_ty != parent_ty { return None; } - - // We only handle: - // ``` - // bb4: { - // _8 = discriminant((_3.1: Enum1)); - // switchInt(move _8) -> [2: bb7, otherwise: bb1]; - // } - // ``` - // and - // ``` - // bb2: { - // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1]; - // } - // ``` - if bbs[child].statements.len() > 1 { + let Some(StatementKind::Assign(boxed)) = &bbs[child].statements.first().map(|x| &x.kind) else { return None; - } - - // When thie BB has exactly one statement, this statement should be discriminant. - let need_hoist_discriminant = bbs[child].statements.len() == 1; - let child_place = if need_hoist_discriminant { - if !bbs[targets.otherwise()].is_empty_unreachable() { - // Someone could write code like this: - // ```rust - // let Q = val; - // if discriminant(P) == otherwise { - // let ptr = &mut Q as *mut _ as *mut u8; - // // It may be difficult for us to effectively determine whether values are valid. - // // Invalid values can come from all sorts of corners. - // unsafe { *ptr = 10; } - // } - // - // match P { - // A => match Q { - // A => { - // // code - // } - // _ => { - // // don't use Q - // } - // } - // _ => { - // // don't use Q - // } - // }; - // ``` - // - // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an - // invalid value, which is UB. - // In order to fix this, **we would either need to show that the discriminant computation of - // `place` is computed in all branches**. - // FIXME(#95162) For the moment, we adopt a conservative approach and - // consider only the `otherwise` branch has no statements and an unreachable terminator. - return None; - } - // Handle: - // ``` - // bb4: { - // _8 = discriminant((_3.1: Enum1)); - // switchInt(move _8) -> [2: bb7, otherwise: bb1]; - // } - // ``` - let [ - Statement { - kind: StatementKind::Assign(box (_, Rvalue::Discriminant(child_place))), - .. - }, - ] = bbs[child].statements.as_slice() - else { - return None; - }; - *child_place - } else { - // Handle: - // ``` - // bb2: { - // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1]; - // } - // ``` - let Operand::Copy(child_place) = child_discr else { - return None; - }; - *child_place }; - let destination = if need_hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable() - { - child_targets.otherwise() - } else { - targets.otherwise() + let (_, Rvalue::Discriminant(child_place)) = &**boxed else { + return None; }; + let destination = child_targets.otherwise(); // Verify that the optimization is legal for each branch for (value, child) in targets.iter() { - if !verify_candidate_branch( - &bbs[child], - value, - child_place, - destination, - need_hoist_discriminant, - ) { + if !verify_candidate_branch(&bbs[child], value, *child_place, destination) { return None; } } Some(OptimizationData { destination, - child_place, + child_place: *child_place, child_ty, - child_source: *source_info, - need_hoist_discriminant, + child_source: child_terminator.source_info, }) } @@ -380,48 +307,31 @@ fn verify_candidate_branch<'tcx>( value: u128, place: Place<'tcx>, destination: BasicBlock, - need_hoist_discriminant: bool, ) -> bool { - // In order for the optimization to be correct, the terminator must be a `SwitchInt`. - let TerminatorKind::SwitchInt { discr: switch_op, targets } = &branch.terminator().kind else { - return false; - }; - if need_hoist_discriminant { - // If we need hoist discriminant, the branch must have exactly one statement. - let [statement] = branch.statements.as_slice() else { - return false; - }; - // The statement must assign the discriminant of `place`. - let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(from_place))) = - statement.kind - else { - return false; - }; - if from_place != place { - return false; - } - // The assignment must invalidate a local that terminate on a `SwitchInt`. - if !discr_place.projection.is_empty() || *switch_op != Operand::Move(discr_place) { - return false; - } + // In order for the optimization to be correct, the branch must... + // ...have exactly one statement + if let [statement] = branch.statements.as_slice() + // ...assign the discriminant of `place` in that statement + && let StatementKind::Assign(boxed) = &statement.kind + && let (discr_place, Rvalue::Discriminant(from_place)) = &**boxed + && *from_place == place + // ...make that assignment to a local + && discr_place.projection.is_empty() + // ...terminate on a `SwitchInt` that invalidates that local + && let TerminatorKind::SwitchInt { discr: switch_op, targets, .. } = + &branch.terminator().kind + && *switch_op == Operand::Move(*discr_place) + // ...fall through to `destination` if the switch misses + && destination == targets.otherwise() + // ...have a branch for value `value` + && let mut iter = targets.iter() + && let Some((target_value, _)) = iter.next() + && target_value == value + // ...and have no more branches + && iter.next().is_none() + { + true } else { - // If we don't need hoist discriminant, the branch must not have any statements. - if !branch.statements.is_empty() { - return false; - } - // The place on `SwitchInt` must be the same. - if *switch_op != Operand::Copy(place) { - return false; - } + false } - // It must fall through to `destination` if the switch misses. - if destination != targets.otherwise() { - return false; - } - // It must have exactly one branch for value `value` and have no more branches. - let mut iter = targets.iter(); - let (Some((target_value, _)), None) = (iter.next(), iter.next()) else { - return false; - }; - target_value == value } diff --git a/tests/mir-opt/early_otherwise_branch.opt5.EarlyOtherwiseBranch.diff b/tests/mir-opt/early_otherwise_branch.opt5.EarlyOtherwiseBranch.diff deleted file mode 100644 index b7447ef0c4699..0000000000000 --- a/tests/mir-opt/early_otherwise_branch.opt5.EarlyOtherwiseBranch.diff +++ /dev/null @@ -1,77 +0,0 @@ -- // MIR for `opt5` before EarlyOtherwiseBranch -+ // MIR for `opt5` after EarlyOtherwiseBranch - - fn opt5(_1: u32, _2: u32) -> u32 { - debug x => _1; - debug y => _2; - let mut _0: u32; - let mut _3: (u32, u32); - let mut _4: u32; - let mut _5: u32; -+ let mut _6: bool; - - bb0: { - StorageLive(_3); - StorageLive(_4); - _4 = copy _1; - StorageLive(_5); - _5 = copy _2; - _3 = (move _4, move _5); - StorageDead(_5); - StorageDead(_4); -- switchInt(copy (_3.0: u32)) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1]; -+ StorageLive(_6); -+ _6 = Ne(copy (_3.0: u32), copy (_3.1: u32)); -+ switchInt(move _6) -> [0: bb6, otherwise: bb1]; - } - - bb1: { -+ StorageDead(_6); - _0 = const 0_u32; -- goto -> bb8; -+ goto -> bb5; - } - - bb2: { -- switchInt(copy (_3.1: u32)) -> [1: bb7, otherwise: bb1]; -+ _0 = const 6_u32; -+ goto -> bb5; - } - - bb3: { -- switchInt(copy (_3.1: u32)) -> [2: bb6, otherwise: bb1]; -+ _0 = const 5_u32; -+ goto -> bb5; - } - - bb4: { -- switchInt(copy (_3.1: u32)) -> [3: bb5, otherwise: bb1]; -+ _0 = const 4_u32; -+ goto -> bb5; - } - - bb5: { -- _0 = const 6_u32; -- goto -> bb8; -+ StorageDead(_3); -+ return; - } - - bb6: { -- _0 = const 5_u32; -- goto -> bb8; -- } -- -- bb7: { -- _0 = const 4_u32; -- goto -> bb8; -- } -- -- bb8: { -- StorageDead(_3); -- return; -+ StorageDead(_6); -+ switchInt(copy (_3.0: u32)) -> [1: bb4, 2: bb3, 3: bb2, otherwise: bb1]; - } - } - diff --git a/tests/mir-opt/early_otherwise_branch.opt5_failed.EarlyOtherwiseBranch.diff b/tests/mir-opt/early_otherwise_branch.opt5_failed.EarlyOtherwiseBranch.diff deleted file mode 100644 index af16271f8b1a8..0000000000000 --- a/tests/mir-opt/early_otherwise_branch.opt5_failed.EarlyOtherwiseBranch.diff +++ /dev/null @@ -1,61 +0,0 @@ -- // MIR for `opt5_failed` before EarlyOtherwiseBranch -+ // MIR for `opt5_failed` after EarlyOtherwiseBranch - - fn opt5_failed(_1: u32, _2: u32) -> u32 { - debug x => _1; - debug y => _2; - let mut _0: u32; - let mut _3: (u32, u32); - let mut _4: u32; - let mut _5: u32; - - bb0: { - StorageLive(_3); - StorageLive(_4); - _4 = copy _1; - StorageLive(_5); - _5 = copy _2; - _3 = (move _4, move _5); - StorageDead(_5); - StorageDead(_4); - switchInt(copy (_3.0: u32)) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1]; - } - - bb1: { - _0 = const 0_u32; - goto -> bb8; - } - - bb2: { - switchInt(copy (_3.1: u32)) -> [1: bb7, otherwise: bb1]; - } - - bb3: { - switchInt(copy (_3.1: u32)) -> [2: bb6, otherwise: bb1]; - } - - bb4: { - switchInt(copy (_3.1: u32)) -> [2: bb5, otherwise: bb1]; - } - - bb5: { - _0 = const 6_u32; - goto -> bb8; - } - - bb6: { - _0 = const 5_u32; - goto -> bb8; - } - - bb7: { - _0 = const 4_u32; - goto -> bb8; - } - - bb8: { - StorageDead(_3); - return; - } - } - diff --git a/tests/mir-opt/early_otherwise_branch.opt5_failed_type.EarlyOtherwiseBranch.diff b/tests/mir-opt/early_otherwise_branch.opt5_failed_type.EarlyOtherwiseBranch.diff deleted file mode 100644 index 23f14b843b37c..0000000000000 --- a/tests/mir-opt/early_otherwise_branch.opt5_failed_type.EarlyOtherwiseBranch.diff +++ /dev/null @@ -1,61 +0,0 @@ -- // MIR for `opt5_failed_type` before EarlyOtherwiseBranch -+ // MIR for `opt5_failed_type` after EarlyOtherwiseBranch - - fn opt5_failed_type(_1: u32, _2: u64) -> u32 { - debug x => _1; - debug y => _2; - let mut _0: u32; - let mut _3: (u32, u64); - let mut _4: u32; - let mut _5: u64; - - bb0: { - StorageLive(_3); - StorageLive(_4); - _4 = copy _1; - StorageLive(_5); - _5 = copy _2; - _3 = (move _4, move _5); - StorageDead(_5); - StorageDead(_4); - switchInt(copy (_3.0: u32)) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1]; - } - - bb1: { - _0 = const 0_u32; - goto -> bb8; - } - - bb2: { - switchInt(copy (_3.1: u64)) -> [1: bb7, otherwise: bb1]; - } - - bb3: { - switchInt(copy (_3.1: u64)) -> [2: bb6, otherwise: bb1]; - } - - bb4: { - switchInt(copy (_3.1: u64)) -> [3: bb5, otherwise: bb1]; - } - - bb5: { - _0 = const 6_u32; - goto -> bb8; - } - - bb6: { - _0 = const 5_u32; - goto -> bb8; - } - - bb7: { - _0 = const 4_u32; - goto -> bb8; - } - - bb8: { - StorageDead(_3); - return; - } - } - diff --git a/tests/mir-opt/early_otherwise_branch.rs b/tests/mir-opt/early_otherwise_branch.rs index 382c38ceb3abb..47bd4be295b04 100644 --- a/tests/mir-opt/early_otherwise_branch.rs +++ b/tests/mir-opt/early_otherwise_branch.rs @@ -78,57 +78,9 @@ fn opt4(x: Option2, y: Option2) -> u32 { } } -// EMIT_MIR early_otherwise_branch.opt5.EarlyOtherwiseBranch.diff -fn opt5(x: u32, y: u32) -> u32 { - // CHECK-LABEL: fn opt5( - // CHECK: let mut [[CMP_LOCAL:_.*]]: bool; - // CHECK: bb0: { - // CHECK: [[CMP_LOCAL]] = Ne( - // CHECK: switchInt(move [[CMP_LOCAL]]) -> [ - // CHECK-NEXT: } - match (x, y) { - (1, 1) => 4, - (2, 2) => 5, - (3, 3) => 6, - _ => 0, - } -} - -// EMIT_MIR early_otherwise_branch.opt5_failed.EarlyOtherwiseBranch.diff -fn opt5_failed(x: u32, y: u32) -> u32 { - // CHECK-LABEL: fn opt5_failed( - // CHECK: bb0: { - // CHECK-NOT: Ne( - // CHECK: switchInt( - // CHECK-NEXT: } - match (x, y) { - (1, 1) => 4, - (2, 2) => 5, - (3, 2) => 6, - _ => 0, - } -} - -// EMIT_MIR early_otherwise_branch.opt5_failed_type.EarlyOtherwiseBranch.diff -fn opt5_failed_type(x: u32, y: u64) -> u32 { - // CHECK-LABEL: fn opt5_failed_type( - // CHECK: bb0: { - // CHECK-NOT: Ne( - // CHECK: switchInt( - // CHECK-NEXT: } - match (x, y) { - (1, 1) => 4, - (2, 2) => 5, - (3, 3) => 6, - _ => 0, - } -} - fn main() { opt1(None, Some(0)); opt2(None, Some(0)); opt3(Option2::None, Option2::Some(false)); opt4(Option2::None, Option2::Some(0)); - opt5(0, 0); - opt5_failed(0, 0); } diff --git a/tests/ui/mir/early-otherwise-branch-ice.rs b/tests/ui/mir/early-otherwise-branch-ice.rs new file mode 100644 index 0000000000000..c1938eb75077b --- /dev/null +++ b/tests/ui/mir/early-otherwise-branch-ice.rs @@ -0,0 +1,18 @@ +// Changes in https://github.com/rust-lang/rust/pull/129047 lead to several mir-opt ICE regressions, +// this test is added to make sure this does not regress. + +//@ compile-flags: -C opt-level=3 +//@ check-pass + +#![crate_type = "lib"] + +use std::task::Poll; + +pub fn poll(val: Poll>, u8>>) { + match val { + Poll::Ready(Ok(Some(_trailers))) => {} + Poll::Ready(Err(_err)) => {} + Poll::Ready(Ok(None)) => {} + Poll::Pending => {} + } +}