From 8f4ada44fcda31fccae5e2a60d746c8e4c957e94 Mon Sep 17 00:00:00 2001 From: John Bobbo Date: Fri, 16 Jun 2023 20:46:42 -0700 Subject: [PATCH] Implement a `SimplifyStaticSwitch` MIR pass which removes unnecessary switches on statically known discriminants. --- compiler/rustc_mir_transform/src/lib.rs | 3 + .../src/simplify_static_switch.rs | 320 ++++++++++++++++++ ...t_switch.identity.SeparateConstSwitch.diff | 27 +- ...witch.too_complex.SeparateConstSwitch.diff | 15 +- ...rrowed_aggregate.SimplifyStaticSwitch.diff | 63 ++++ ...wed_discriminant.SimplifyStaticSwitch.diff | 31 ++ ...tch.custom_discr.SimplifyStaticSwitch.diff | 56 +++ ...itch.loop_header.SimplifyStaticSwitch.diff | 58 ++++ ...utated_aggregate.SimplifyStaticSwitch.diff | 62 ++++ ...ted_discriminant.SimplifyStaticSwitch.diff | 30 ++ ...as_mut_unchecked.SimplifyStaticSwitch.diff | 83 +++++ ...as_ref_unchecked.SimplifyStaticSwitch.diff | 73 ++++ tests/mir-opt/simplify_static_switch.rs | 176 ++++++++++ ...itch.too_complex.SimplifyStaticSwitch.diff | 89 +++++ 14 files changed, 1059 insertions(+), 27 deletions(-) create mode 100644 compiler/rustc_mir_transform/src/simplify_static_switch.rs create mode 100644 tests/mir-opt/simplify_static_switch.borrowed_aggregate.SimplifyStaticSwitch.diff create mode 100644 tests/mir-opt/simplify_static_switch.borrowed_discriminant.SimplifyStaticSwitch.diff create mode 100644 tests/mir-opt/simplify_static_switch.custom_discr.SimplifyStaticSwitch.diff create mode 100644 tests/mir-opt/simplify_static_switch.loop_header.SimplifyStaticSwitch.diff create mode 100644 tests/mir-opt/simplify_static_switch.mutated_aggregate.SimplifyStaticSwitch.diff create mode 100644 tests/mir-opt/simplify_static_switch.mutated_discriminant.SimplifyStaticSwitch.diff create mode 100644 tests/mir-opt/simplify_static_switch.opt_as_mut_unchecked.SimplifyStaticSwitch.diff create mode 100644 tests/mir-opt/simplify_static_switch.opt_as_ref_unchecked.SimplifyStaticSwitch.diff create mode 100644 tests/mir-opt/simplify_static_switch.rs create mode 100644 tests/mir-opt/simplify_static_switch.too_complex.SimplifyStaticSwitch.diff diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 9c8c0ea0be004..d073ffb50a365 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -101,6 +101,7 @@ mod check_alignment; pub mod simplify; mod simplify_branches; mod simplify_comparison_integral; +mod simplify_static_switch; mod sroa; mod uninhabited_enum_branching; mod unreachable_prop; @@ -561,6 +562,8 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { &simplify::SimplifyLocals::BeforeConstProp, ©_prop::CopyProp, &ref_prop::ReferencePropagation, + // Remove switches on a statically known discriminant, which can happen as a result of inlining. + &simplify_static_switch::SimplifyStaticSwitch, // Perform `SeparateConstSwitch` after SSA-based analyses, as cloning blocks may // destroy the SSA property. It should still happen before const-propagation, so the // latter pass will leverage the created opportunities. diff --git a/compiler/rustc_mir_transform/src/simplify_static_switch.rs b/compiler/rustc_mir_transform/src/simplify_static_switch.rs new file mode 100644 index 0000000000000..addcf12280e9e --- /dev/null +++ b/compiler/rustc_mir_transform/src/simplify_static_switch.rs @@ -0,0 +1,320 @@ +use super::MirPass; + +use rustc_data_structures::fx::FxHashMap; +use rustc_middle::mir::visit::{PlaceContext, Visitor}; +use rustc_middle::mir::{ + AggregateKind, BasicBlock, Body, Local, Location, Operand, Place, Rvalue, StatementKind, + TerminatorKind, +}; +use rustc_middle::ty::TyCtxt; +use rustc_mir_dataflow::impls::MaybeBorrowedLocals; +use rustc_mir_dataflow::Analysis; +use rustc_session::Session; + +use super::simplify; +use super::ssa::SsaLocals; + +/// # Overview +/// +/// This pass looks to optimize a pattern in MIR where variants of an aggregate +/// are constructed in one or more blocks with the same successor and then that +/// aggregate/discriminant is switched on in that successor block, in which case +/// we can remove the switch on the discriminant because we statically know +/// what target block will be taken for each variant. +/// +/// Note that an aggregate which is returned from a function call or passed as +/// an argument is not viable for this optimization because we do not statically +/// know the discriminant/variant of the aggregate. +/// +/// For example, the following CFG: +/// ```text +/// x = Foo::A(y); --- Foo::A ---> ... +/// / \ / +/// ... --> switch x +/// \ / \ +/// x = Foo::B(y); --- Foo::B ---> ... +/// ``` +/// would become: +/// ```text +/// x = Foo::A(y); --------- Foo::A ---> ... +/// / +/// ... +/// \ +/// x = Foo::B(y); --------- Foo::B ---> ... +/// ``` +/// +/// # Soundness +/// +/// - If the discriminant being switched on is not SSA, or if the aggregate is +/// mutated before the discriminant is assigned, the optimization cannot be +/// applied because we no longer statically know what variant the aggregate +/// could be, or what discriminant is being switched on. +/// +/// - If the discriminant is borrowed before being switched on, or the aggregate +/// is borrowed before the discriminant is assigned, we also cannot optimize due +/// to the possibilty stated in the first paragraph. +/// +/// - An aggregate being constructed has a known variant, and if it is not borrowed +/// or mutated before being switched on, then it does not actually need a runtime +/// switch on the discriminant (aka variant) of said aggregate. +/// +pub struct SimplifyStaticSwitch; + +impl<'tcx> MirPass<'tcx> for SimplifyStaticSwitch { + fn is_enabled(&self, sess: &Session) -> bool { + sess.mir_opt_level() >= 2 + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + debug!("Running SimplifyStaticSwitch on {:?}", body.source.def_id()); + + let ssa_locals = SsaLocals::new(body); + if simplify_static_switches(tcx, body, &ssa_locals) { + simplify::remove_dead_blocks(tcx, body); + } + } +} + +#[instrument(level = "debug", skip_all, ret)] +fn simplify_static_switches<'tcx>( + tcx: TyCtxt<'tcx>, + body: &mut Body<'tcx>, + ssa_locals: &SsaLocals, +) -> bool { + let dominators = body.basic_blocks.dominators(); + let predecessors = body.basic_blocks.predecessors(); + let mut discriminants = FxHashMap::default(); + let mut static_switches = FxHashMap::default(); + let mut borrowed_locals = + MaybeBorrowedLocals.into_engine(tcx, body).iterate_to_fixpoint().into_results_cursor(body); + for (switched, rvalue, location) in ssa_locals.assignments(body) { + let Rvalue::Discriminant(discr) = rvalue else { + continue + }; + + borrowed_locals.seek_after_primary_effect(location); + // If `discr` was borrowed before its discriminant was assigned to `switched`, + // or if it was borrowed in the assignment, we cannot optimize. + if borrowed_locals.contains(discr.local) { + debug!("The aggregate: {discr:?} was borrowed before its discriminant was read"); + continue; + } + + let Location { block, statement_index } = location; + let mut finder = MutatedLocalFinder { local: discr.local, mutated: false }; + for (statement_index, statement) in body.basic_blocks[block] + .statements + .iter() + .enumerate() + .take_while(|&(index, _)| index != statement_index) + { + finder.visit_statement(statement, Location { block, statement_index }); + } + + if finder.mutated { + debug!("The aggregate: {discr:?} was mutated before its discriminant was read"); + continue; + } + + // If `switched` is borrowed by the time we actually switch on it, we also cannot optimize. + borrowed_locals.seek_to_block_end(block); + if borrowed_locals.contains(switched) { + debug!("The local: {switched:?} was borrowed before being switched on"); + continue; + } + + discriminants.insert( + switched, + Discriminant { + block, + discr: discr.local, + exclude: if ssa_locals.num_direct_uses(switched) == 1 { + // If there is only one direct use of `switched` we do not need to keep + // it around because the only use is in the switch. + Some(statement_index) + } else { + None + }, + }, + ); + } + + if discriminants.is_empty() { + debug!("No SSA locals were assigned a discriminant"); + return false; + } + + for (switched, Discriminant { discr, block, exclude }) in discriminants { + let data = &body.basic_blocks[block]; + if data.is_cleanup { + continue; + } + + let predecessors = &predecessors[block]; + if predecessors.is_empty() { + continue; + } + + if predecessors.iter().any(|&pred| { + // If we find a backedge from: `pred -> block`, this indicates that + // `block` is a loop header. To avoid creating irreducible CFGs we do + // not thread through loop headers. + dominators.dominates(block, pred) + }) { + debug!("Unable to thread through loop header: {block:?}"); + continue; + } + + let terminator = data.terminator(); + let TerminatorKind::SwitchInt { + discr: Operand::Copy(place) | Operand::Move(place), + targets + } = &terminator.kind else { + continue + }; + + if place.local != switched { + continue; + } + + let mut finder = MutatedLocalFinder { local: discr, mutated: false }; + 'preds: for &pred in predecessors { + let data = &body.basic_blocks[pred]; + let terminator = data.terminator(); + let TerminatorKind::Goto { .. } = terminator.kind else { + continue + }; + + for (statement_index, statement) in data.statements.iter().enumerate().rev() { + match statement.kind { + StatementKind::SetDiscriminant { box place, variant_index: variant } + | StatementKind::Assign(box ( + place, + Rvalue::Aggregate(box AggregateKind::Adt(_, variant, ..), ..), + )) if place.local == discr => { + if finder.mutated { + debug!( + "The discriminant: {discr:?} was mutated in predecessor: {pred:?}" + ); + // We can't optimize this predecessor, so try the next one. + finder.mutated = false; + + continue 'preds; + } + + let discr_ty = body.local_decls[discr].ty; + if let Some(discr) = discr_ty.discriminant_for_variant(tcx, variant) { + debug!( + "{pred:?}: {place:?} = {discr_ty:?}::{variant:?}; goto -> {block:?}", + ); + + let target = targets.target_for_value(discr.val); + static_switches + .entry(block) + .and_modify(|static_switches: &mut &mut [StaticSwitch]| { + if static_switches.iter_mut().all(|switch| { + if switch.pred == pred { + switch.target = target; + false + } else { + true + } + }) { + *static_switches = + tcx.arena.alloc_from_iter( + static_switches.iter().copied().chain([ + StaticSwitch { pred, target, exclude }, + ]), + ); + } + }) + .or_insert_with(|| { + tcx.arena.alloc([StaticSwitch { pred, target, exclude }]) + }); + } + + continue 'preds; + } + _ if finder.mutated => { + debug!("The discriminant: {discr:?} was mutated in predecessor: {pred:?}"); + // Note that the discriminant could have been mutated in one predecessor + // but not the others, in which case only the predecessors which did not mutate + // the discriminant can be optimized. + finder.mutated = false; + + continue 'preds; + } + _ => finder.visit_statement(statement, Location { block, statement_index }), + } + } + } + } + + if static_switches.is_empty() { + debug!("No static switches were found in the current body"); + return false; + } + + let basic_blocks = body.basic_blocks.as_mut(); + let num_switches: usize = static_switches.iter().map(|(_, switches)| switches.len()).sum(); + for (block, static_switches) in static_switches { + for switch in static_switches { + debug!("{block:?}: Removing static switch: {switch:?}"); + + // We use the SSA, to destroy the SSA. + let data = { + let (block, pred) = basic_blocks.pick2_mut(block, switch.pred); + match switch.exclude { + Some(exclude) => { + pred.statements.extend(block.statements.iter().enumerate().filter_map( + |(index, statement)| { + if index == exclude { None } else { Some(statement.clone()) } + }, + )); + } + None => pred.statements.extend_from_slice(&block.statements), + } + pred + }; + let terminator = data.terminator_mut(); + + // Make sure that we have not overwritten the terminator and it is still + // a `goto -> block`. + assert_eq!(terminator.kind, TerminatorKind::Goto { target: block }); + // Something to be noted is that, this creates an edge from: `pred -> target`, + // and because we ensure that we do not thread through any loop headers, meaning + // it is not part of a loop, this edge will only ever appear once in the CFG. + terminator.kind = TerminatorKind::Goto { target: switch.target }; + } + } + + debug!("Removed {num_switches} static switches from: {:?}", body.source.def_id()); + true +} + +#[derive(Debug, Copy, Clone)] +struct StaticSwitch { + pred: BasicBlock, + target: BasicBlock, + exclude: Option, +} + +#[derive(Debug, Copy, Clone)] +struct Discriminant { + discr: Local, + block: BasicBlock, + exclude: Option, +} + +struct MutatedLocalFinder { + local: Local, + mutated: bool, +} + +impl<'tcx> Visitor<'tcx> for MutatedLocalFinder { + fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _: Location) { + if self.local == place.local && let PlaceContext::MutatingUse(..) = context { + self.mutated = true; + } + } +} diff --git a/tests/mir-opt/separate_const_switch.identity.SeparateConstSwitch.diff b/tests/mir-opt/separate_const_switch.identity.SeparateConstSwitch.diff index ca1528b6ab1a4..8bd183e9d8fd0 100644 --- a/tests/mir-opt/separate_const_switch.identity.SeparateConstSwitch.diff +++ b/tests/mir-opt/separate_const_switch.identity.SeparateConstSwitch.diff @@ -51,28 +51,17 @@ StorageLive(_10); StorageLive(_11); _9 = discriminant(_1); - switchInt(move _9) -> [0: bb7, 1: bb5, otherwise: bb6]; + switchInt(move _9) -> [0: bb5, 1: bb3, otherwise: bb4]; } bb1: { - StorageDead(_11); - StorageDead(_10); - _5 = discriminant(_3); - switchInt(move _5) -> [0: bb2, 1: bb4, otherwise: bb3]; - } - - bb2: { _8 = ((_3 as Continue).0: i32); _0 = Result::::Ok(_8); StorageDead(_3); return; } - bb3: { - unreachable; - } - - bb4: { + bb2: { _6 = ((_3 as Break).0: std::result::Result); _13 = ((_6 as Err).0: i32); _0 = Result::::Err(move _13); @@ -80,22 +69,26 @@ return; } - bb5: { + bb3: { _11 = ((_1 as Err).0: i32); StorageLive(_12); _12 = Result::::Err(move _11); _3 = ControlFlow::, i32>::Break(move _12); StorageDead(_12); - goto -> bb1; + StorageDead(_11); + StorageDead(_10); + goto -> bb2; } - bb6: { + bb4: { unreachable; } - bb7: { + bb5: { _10 = ((_1 as Ok).0: i32); _3 = ControlFlow::, i32>::Continue(move _10); + StorageDead(_11); + StorageDead(_10); goto -> bb1; } } diff --git a/tests/mir-opt/separate_const_switch.too_complex.SeparateConstSwitch.diff b/tests/mir-opt/separate_const_switch.too_complex.SeparateConstSwitch.diff index e2bf33f7fbcc0..bff8bb758a961 100644 --- a/tests/mir-opt/separate_const_switch.too_complex.SeparateConstSwitch.diff +++ b/tests/mir-opt/separate_const_switch.too_complex.SeparateConstSwitch.diff @@ -46,29 +46,24 @@ bb3: { _4 = ((_1 as Ok).0: i32); _2 = ControlFlow::::Continue(_4); - goto -> bb4; + goto -> bb5; } bb4: { - _8 = discriminant(_2); - switchInt(move _8) -> [0: bb6, 1: bb5, otherwise: bb2]; - } - - bb5: { StorageLive(_11); _11 = ((_2 as Break).0: usize); _0 = Option::::None; StorageDead(_11); - goto -> bb7; + goto -> bb6; } - bb6: { + bb5: { _9 = ((_2 as Continue).0: i32); _0 = Option::::Some(_9); - goto -> bb7; + goto -> bb6; } - bb7: { + bb6: { StorageDead(_2); return; } diff --git a/tests/mir-opt/simplify_static_switch.borrowed_aggregate.SimplifyStaticSwitch.diff b/tests/mir-opt/simplify_static_switch.borrowed_aggregate.SimplifyStaticSwitch.diff new file mode 100644 index 0000000000000..2d0e5cba36823 --- /dev/null +++ b/tests/mir-opt/simplify_static_switch.borrowed_aggregate.SimplifyStaticSwitch.diff @@ -0,0 +1,63 @@ +- // MIR for `borrowed_aggregate` before SimplifyStaticSwitch ++ // MIR for `borrowed_aggregate` after SimplifyStaticSwitch + + fn borrowed_aggregate(_1: bool) -> bool { + debug cond => _1; + let mut _0: bool; + let mut _2: Foo; + let mut _3: bool; + let mut _5: isize; + scope 1 { + debug foo => _2; + let _4: &mut Foo; + scope 2 { + debug bar => _4; + } + } + + bb0: { + StorageLive(_2); + StorageLive(_3); + _3 = _1; + switchInt(move _3) -> [0: bb2, otherwise: bb1]; + } + + bb1: { + _2 = Foo::A; + goto -> bb3; + } + + bb2: { + _2 = Foo::B; + goto -> bb3; + } + + bb3: { + StorageDead(_3); + StorageLive(_4); + _4 = &mut _2; + _5 = discriminant(_2); + switchInt(move _5) -> [0: bb5, 1: bb6, 2: bb6, otherwise: bb4]; + } + + bb4: { + unreachable; + } + + bb5: { + _0 = const true; + goto -> bb7; + } + + bb6: { + _0 = const false; + goto -> bb7; + } + + bb7: { + StorageDead(_4); + StorageDead(_2); + return; + } + } + diff --git a/tests/mir-opt/simplify_static_switch.borrowed_discriminant.SimplifyStaticSwitch.diff b/tests/mir-opt/simplify_static_switch.borrowed_discriminant.SimplifyStaticSwitch.diff new file mode 100644 index 0000000000000..83d9907fa6396 --- /dev/null +++ b/tests/mir-opt/simplify_static_switch.borrowed_discriminant.SimplifyStaticSwitch.diff @@ -0,0 +1,31 @@ +- // MIR for `borrowed_discriminant` before SimplifyStaticSwitch ++ // MIR for `borrowed_discriminant` after SimplifyStaticSwitch + + fn borrowed_discriminant() -> bool { + let mut _0: bool; + let mut _1: Foo; + let mut _2: isize; + let mut _3: &mut isize; + + bb0: { + _1 = Foo::A; + goto -> bb1; + } + + bb1: { + _2 = discriminant(_1); + _3 = &mut _2; + switchInt(_2) -> [0: bb2, otherwise: bb3]; + } + + bb2: { + _0 = const true; + return; + } + + bb3: { + _0 = const false; + return; + } + } + diff --git a/tests/mir-opt/simplify_static_switch.custom_discr.SimplifyStaticSwitch.diff b/tests/mir-opt/simplify_static_switch.custom_discr.SimplifyStaticSwitch.diff new file mode 100644 index 0000000000000..1daa723ed4659 --- /dev/null +++ b/tests/mir-opt/simplify_static_switch.custom_discr.SimplifyStaticSwitch.diff @@ -0,0 +1,56 @@ +- // MIR for `custom_discr` before SimplifyStaticSwitch ++ // MIR for `custom_discr` after SimplifyStaticSwitch + + fn custom_discr(_1: bool) -> u8 { + debug x => _1; + let mut _0: u8; + let mut _2: custom_discr::CustomDiscr; + let mut _3: bool; + let mut _4: u8; + + bb0: { + StorageLive(_2); + StorageLive(_3); + _3 = _1; + switchInt(move _3) -> [0: bb2, otherwise: bb1]; + } + + bb1: { + _2 = CustomDiscr::A; +- goto -> bb3; ++ StorageDead(_3); ++ goto -> bb4; + } + + bb2: { + _2 = CustomDiscr::B; ++ StorageDead(_3); + goto -> bb3; + } + + bb3: { +- StorageDead(_3); +- _4 = discriminant(_2); +- switchInt(move _4) -> [35: bb5, otherwise: bb4]; +- } +- +- bb4: { + _0 = const 13_u8; +- goto -> bb6; ++ goto -> bb5; + } + +- bb5: { ++ bb4: { + _0 = const 5_u8; +- goto -> bb6; ++ goto -> bb5; + } + +- bb6: { ++ bb5: { + StorageDead(_2); + return; + } + } + diff --git a/tests/mir-opt/simplify_static_switch.loop_header.SimplifyStaticSwitch.diff b/tests/mir-opt/simplify_static_switch.loop_header.SimplifyStaticSwitch.diff new file mode 100644 index 0000000000000..48a8162bd7404 --- /dev/null +++ b/tests/mir-opt/simplify_static_switch.loop_header.SimplifyStaticSwitch.diff @@ -0,0 +1,58 @@ +- // MIR for `loop_header` before SimplifyStaticSwitch ++ // MIR for `loop_header` after SimplifyStaticSwitch + + fn loop_header() -> () { + let mut _0: (); + let mut _1: Foo; + let mut _2: !; + let mut _3: (); + let mut _4: isize; + let mut _5: Foo; + let mut _6: Foo; + let mut _7: !; + scope 1 { + debug foo => _1; + } + + bb0: { + StorageLive(_1); + _1 = Foo::A; + StorageLive(_2); + goto -> bb1; + } + + bb1: { + _4 = discriminant(_1); + switchInt(move _4) -> [0: bb4, 1: bb5, 2: bb2, otherwise: bb3]; + } + + bb2: { + _0 = const (); + StorageDead(_2); + StorageDead(_1); + return; + } + + bb3: { + unreachable; + } + + bb4: { + StorageLive(_5); + _5 = Foo::B; + _1 = move _5; + _3 = const (); + StorageDead(_5); + goto -> bb1; + } + + bb5: { + StorageLive(_6); + _6 = Foo::A; + _1 = move _6; + _3 = const (); + StorageDead(_6); + goto -> bb1; + } + } + diff --git a/tests/mir-opt/simplify_static_switch.mutated_aggregate.SimplifyStaticSwitch.diff b/tests/mir-opt/simplify_static_switch.mutated_aggregate.SimplifyStaticSwitch.diff new file mode 100644 index 0000000000000..c6b572152c3b7 --- /dev/null +++ b/tests/mir-opt/simplify_static_switch.mutated_aggregate.SimplifyStaticSwitch.diff @@ -0,0 +1,62 @@ +- // MIR for `mutated_aggregate` before SimplifyStaticSwitch ++ // MIR for `mutated_aggregate` after SimplifyStaticSwitch + + fn mutated_aggregate(_1: bool, _2: Foo) -> bool { + debug cond => _1; + debug bar => _2; + let mut _0: bool; + let mut _3: Foo; + let mut _4: bool; + let mut _5: Foo; + let mut _6: isize; + scope 1 { + debug foo => _3; + } + + bb0: { + StorageLive(_3); + StorageLive(_4); + _4 = _1; + switchInt(move _4) -> [0: bb2, otherwise: bb1]; + } + + bb1: { + _3 = Foo::A; + goto -> bb3; + } + + bb2: { + _3 = Foo::B; + goto -> bb3; + } + + bb3: { + StorageDead(_4); + StorageLive(_5); + _5 = move _2; + _3 = move _5; + StorageDead(_5); + _6 = discriminant(_3); + switchInt(move _6) -> [0: bb5, 1: bb6, 2: bb6, otherwise: bb4]; + } + + bb4: { + unreachable; + } + + bb5: { + _0 = const true; + goto -> bb7; + } + + bb6: { + _0 = const false; + goto -> bb7; + } + + bb7: { + StorageDead(_3); + return; + } + } + diff --git a/tests/mir-opt/simplify_static_switch.mutated_discriminant.SimplifyStaticSwitch.diff b/tests/mir-opt/simplify_static_switch.mutated_discriminant.SimplifyStaticSwitch.diff new file mode 100644 index 0000000000000..a7bb4a7d11f8f --- /dev/null +++ b/tests/mir-opt/simplify_static_switch.mutated_discriminant.SimplifyStaticSwitch.diff @@ -0,0 +1,30 @@ +- // MIR for `mutated_discriminant` before SimplifyStaticSwitch ++ // MIR for `mutated_discriminant` after SimplifyStaticSwitch + + fn mutated_discriminant(_1: isize) -> bool { + let mut _0: bool; + let mut _2: Foo; + let mut _3: isize; + + bb0: { + _2 = Foo::A; + goto -> bb1; + } + + bb1: { + _3 = discriminant(_2); + _3 = _1; + switchInt(_3) -> [0: bb2, otherwise: bb3]; + } + + bb2: { + _0 = const true; + return; + } + + bb3: { + _0 = const false; + return; + } + } + diff --git a/tests/mir-opt/simplify_static_switch.opt_as_mut_unchecked.SimplifyStaticSwitch.diff b/tests/mir-opt/simplify_static_switch.opt_as_mut_unchecked.SimplifyStaticSwitch.diff new file mode 100644 index 0000000000000..31ca6faee33c7 --- /dev/null +++ b/tests/mir-opt/simplify_static_switch.opt_as_mut_unchecked.SimplifyStaticSwitch.diff @@ -0,0 +1,83 @@ +- // MIR for `opt_as_mut_unchecked` before SimplifyStaticSwitch ++ // MIR for `opt_as_mut_unchecked` after SimplifyStaticSwitch + + fn opt_as_mut_unchecked(_1: &mut Option) -> &mut T { + debug opt => _1; + let mut _0: &mut T; + let mut _2: &mut T; + let _3: std::option::Option<&mut T>; + let mut _4: isize; + let _5: &mut T; + let mut _6: &mut T; + let mut _7: &mut T; + let mut _8: isize; + let mut _10: !; + scope 1 { + debug opt => _3; + let _9: &mut T; + scope 3 { + debug val => _9; + } + scope 4 { + } + } + scope 2 { + debug val => _5; + } + + bb0: { + StorageLive(_2); + StorageLive(_3); + _4 = discriminant((*_1)); + switchInt(move _4) -> [0: bb1, 1: bb3, otherwise: bb2]; + } + + bb1: { + _3 = Option::<&mut T>::None; ++ StorageLive(_7); + goto -> bb4; + } + + bb2: { + unreachable; + } + + bb3: { + StorageLive(_5); + _5 = &mut (((*_1) as Some).0: T); + StorageLive(_6); + _6 = move _5; + _3 = Option::<&mut T>::Some(move _6); + StorageDead(_6); + StorageDead(_5); +- goto -> bb4; +- } +- +- bb4: { + StorageLive(_7); +- _8 = discriminant(_3); +- switchInt(move _8) -> [0: bb5, 1: bb6, otherwise: bb2]; ++ goto -> bb5; + } + +- bb5: { ++ bb4: { + StorageLive(_10); + _10 = unreachable_unchecked(); + } + +- bb6: { ++ bb5: { + StorageLive(_9); + _9 = move ((_3 as Some).0: &mut T); + _7 = &mut (*_9); + StorageDead(_9); + _2 = &mut (*_7); + StorageDead(_3); + _0 = &mut (*_2); + StorageDead(_7); + StorageDead(_2); + return; + } + } + diff --git a/tests/mir-opt/simplify_static_switch.opt_as_ref_unchecked.SimplifyStaticSwitch.diff b/tests/mir-opt/simplify_static_switch.opt_as_ref_unchecked.SimplifyStaticSwitch.diff new file mode 100644 index 0000000000000..20dd256bee90f --- /dev/null +++ b/tests/mir-opt/simplify_static_switch.opt_as_ref_unchecked.SimplifyStaticSwitch.diff @@ -0,0 +1,73 @@ +- // MIR for `opt_as_ref_unchecked` before SimplifyStaticSwitch ++ // MIR for `opt_as_ref_unchecked` after SimplifyStaticSwitch + + fn opt_as_ref_unchecked(_1: &Option) -> &T { + debug opt => _1; + let mut _0: &T; + let _2: std::option::Option<&T>; + let mut _3: isize; + let _4: &T; + let mut _5: &T; + let mut _6: isize; + let mut _8: !; + scope 1 { + debug opt => _2; + let _7: &T; + scope 3 { + debug val => _7; + } + scope 4 { + } + } + scope 2 { + debug val => _4; + } + + bb0: { + StorageLive(_2); + _3 = discriminant((*_1)); + switchInt(move _3) -> [0: bb1, 1: bb3, otherwise: bb2]; + } + + bb1: { + _2 = Option::<&T>::None; + goto -> bb4; + } + + bb2: { + unreachable; + } + + bb3: { + StorageLive(_4); + _4 = &(((*_1) as Some).0: T); + StorageLive(_5); + _5 = _4; + _2 = Option::<&T>::Some(move _5); + StorageDead(_5); + StorageDead(_4); +- goto -> bb4; ++ goto -> bb5; + } + + bb4: { +- _6 = discriminant(_2); +- switchInt(move _6) -> [0: bb5, 1: bb6, otherwise: bb2]; +- } +- +- bb5: { + StorageLive(_8); + _8 = unreachable_unchecked(); + } + +- bb6: { ++ bb5: { + StorageLive(_7); + _7 = ((_2 as Some).0: &T); + _0 = &(*_7); + StorageDead(_7); + StorageDead(_2); + return; + } + } + diff --git a/tests/mir-opt/simplify_static_switch.rs b/tests/mir-opt/simplify_static_switch.rs new file mode 100644 index 0000000000000..f5cc766cf2bc9 --- /dev/null +++ b/tests/mir-opt/simplify_static_switch.rs @@ -0,0 +1,176 @@ +// unit-test: SimplifyStaticSwitch + +#![crate_type = "lib"] +#![feature(core_intrinsics, custom_mir)] + +use std::hint; +use std::intrinsics::mir::*; +use std::ops::ControlFlow; + +// EMIT_MIR simplify_static_switch.too_complex.SimplifyStaticSwitch.diff +pub fn too_complex(x: Result) -> Option { + match { + match x { + Ok(v) => ControlFlow::Continue(v), + Err(r) => ControlFlow::Break(r), + } + } { + ControlFlow::Continue(v) => Some(v), + ControlFlow::Break(_) => None, + } +} + +// EMIT_MIR simplify_static_switch.custom_discr.SimplifyStaticSwitch.diff +pub fn custom_discr(x: bool) -> u8 { + #[repr(u8)] + enum CustomDiscr { + A = 35, + B = 73, + C = 99, + } + + match if x { CustomDiscr::A } else { CustomDiscr::B } { + CustomDiscr::A => 5, + _ => 13, + } +} + +pub enum Foo { + A, + B, + C, +} + +// Make sure we do not thread through loop headers, to avoid +// creating irreducible CFGs. +// EMIT_MIR simplify_static_switch.loop_header.SimplifyStaticSwitch.diff +pub fn loop_header() { + let mut foo = Foo::A; + loop { + match foo { + Foo::A => foo = Foo::B, + Foo::B => foo = Foo::A, + Foo::C => return, + } + } +} + +// EMIT_MIR simplify_static_switch.opt_as_ref_unchecked.SimplifyStaticSwitch.diff +pub unsafe fn opt_as_ref_unchecked(opt: &Option) -> &T { + let opt = match opt { + Some(ref val) => Some(val), + None => None, + }; + match opt { + Some(val) => val, + None => unsafe { hint::unreachable_unchecked() }, + } +} + +// EMIT_MIR simplify_static_switch.opt_as_mut_unchecked.SimplifyStaticSwitch.diff +pub unsafe fn opt_as_mut_unchecked(opt: &mut Option) -> &mut T { + let opt = match opt { + Some(ref mut val) => Some(val), + None => None, + }; + match opt { + Some(val) => val, + None => unsafe { hint::unreachable_unchecked() }, + } +} + +// Make sure that we do not apply this opt if the aggregate is borrowed before +// being switched on. +// EMIT_MIR simplify_static_switch.borrowed_aggregate.SimplifyStaticSwitch.diff +pub fn borrowed_aggregate(cond: bool) -> bool { + let mut foo = if cond { + Foo::A + } else { + Foo::B + }; + // `bar` could indirectly mutate `foo` so we cannot optimize. + let bar = &mut foo; + match foo { + Foo::A => true, + Foo::B | Foo::C => false, + } +} + +// Make sure that we do not apply this opt if the aggregate is mutated before +// being switched on. +// EMIT_MIR simplify_static_switch.mutated_aggregate.SimplifyStaticSwitch.diff +pub fn mutated_aggregate(cond: bool, bar: Foo) -> bool { + let mut foo = if cond { + Foo::A + } else { + Foo::B + }; + // We no longer know what variant `foo` is. + foo = bar; + match foo { + Foo::A => true, + Foo::B | Foo::C => false, + } +} + +// Make sure that we do not apply this opt if the discriminant is borrowed before +// being switched on. +// EMIT_MIR simplify_static_switch.borrowed_discriminant.SimplifyStaticSwitch.diff +#[custom_mir(dialect = "runtime", phase = "post-cleanup")] +pub fn borrowed_discriminant() -> bool { + mir!( + let x: Foo; + { + x = Foo::A; + Goto(bb1) + } + bb1 = { + let a = Discriminant(x); + // `a` could be indirectly mutated through `b`. + let b = &mut a; + match a { + 0 => bb2, + _ => bb3, + } + } + bb2 = { + RET = true; + Return() + } + bb3 = { + RET = false; + Return() + } + ) +} + +// Make sure that we do not apply this opt if the discriminant is mutated +// before we switch on it. +// EMIT_MIR simplify_static_switch.mutated_discriminant.SimplifyStaticSwitch.diff +#[custom_mir(dialect = "runtime", phase = "post-cleanup")] +pub fn mutated_discriminant(b: isize) -> bool { + mir!( + let x: Foo; + { + x = Foo::A; + Goto(bb1) + } + bb1 = { + let a = Discriminant(x); + // We no longer know what discriminant `a` is. + a = b; + match a { + 0 => bb2, + _ => bb3, + } + } + bb2 = { + RET = true; + Return() + } + bb3 = { + RET = false; + Return() + } + ) +} diff --git a/tests/mir-opt/simplify_static_switch.too_complex.SimplifyStaticSwitch.diff b/tests/mir-opt/simplify_static_switch.too_complex.SimplifyStaticSwitch.diff new file mode 100644 index 0000000000000..ef802a3aa08c3 --- /dev/null +++ b/tests/mir-opt/simplify_static_switch.too_complex.SimplifyStaticSwitch.diff @@ -0,0 +1,89 @@ +- // MIR for `too_complex` before SimplifyStaticSwitch ++ // MIR for `too_complex` after SimplifyStaticSwitch + + fn too_complex(_1: Result) -> Option { + debug x => _1; + let mut _0: std::option::Option; + let mut _2: std::ops::ControlFlow; + let mut _3: isize; + let _4: i32; + let mut _5: i32; + let _6: usize; + let mut _7: usize; + let mut _8: isize; + let _9: i32; + let mut _10: i32; + scope 1 { + debug v => _4; + } + scope 2 { + debug r => _6; + } + scope 3 { + debug v => _9; + } + + bb0: { + StorageLive(_2); + _3 = discriminant(_1); + switchInt(move _3) -> [0: bb3, 1: bb1, otherwise: bb2]; + } + + bb1: { + StorageLive(_6); + _6 = ((_1 as Err).0: usize); + StorageLive(_7); + _7 = _6; + _2 = ControlFlow::::Break(move _7); + StorageDead(_7); + StorageDead(_6); + goto -> bb4; + } + + bb2: { + unreachable; + } + + bb3: { + StorageLive(_4); + _4 = ((_1 as Ok).0: i32); + StorageLive(_5); + _5 = _4; + _2 = ControlFlow::::Continue(move _5); + StorageDead(_5); + StorageDead(_4); +- goto -> bb4; ++ goto -> bb5; + } + + bb4: { +- _8 = discriminant(_2); +- switchInt(move _8) -> [0: bb6, 1: bb5, otherwise: bb2]; +- } +- +- bb5: { + _0 = Option::::None; +- goto -> bb7; ++ goto -> bb6; + } + +- bb6: { ++ bb5: { + StorageLive(_9); + _9 = ((_2 as Continue).0: i32); + StorageLive(_10); + _10 = _9; + _0 = Option::::Some(move _10); + StorageDead(_10); + StorageDead(_9); +- goto -> bb7; ++ goto -> bb6; + } + +- bb7: { ++ bb6: { + StorageDead(_2); + return; + } + } +