Skip to content

Commit

Permalink
Auto merge of rust-lang#113970 - cjgillot:assume-all-the-things, r=nikic
Browse files Browse the repository at this point in the history
Replace switch to unreachable by assume statements

`UnreachablePropagation` currently keeps some switch terminators alive in order to ensure codegen can infer the inequalities on the discriminants.

This PR proposes to encode those inequalities as `Assume` statements.

This allows to simplify MIR further by removing some useless terminators.
  • Loading branch information
bors committed Nov 1, 2023
2 parents 09ac6e4 + ae2e211 commit 98f5ebb
Show file tree
Hide file tree
Showing 26 changed files with 573 additions and 470 deletions.
5 changes: 3 additions & 2 deletions compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -568,10 +568,11 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&[
&check_alignment::CheckAlignment,
&lower_slice_len::LowerSliceLenCalls, // has to be done before inlining, otherwise actual call will be almost always inlined. Also simple, so can just do first
&unreachable_prop::UnreachablePropagation,
&inline::Inline,
// Substitutions during inlining may introduce switch on enums with uninhabited branches.
&uninhabited_enum_branching::UninhabitedEnumBranching,
&unreachable_prop::UnreachablePropagation,
&o1(simplify::SimplifyCfg::AfterUninhabitedEnumBranching),
&inline::Inline,
&remove_storage_markers::RemoveStorageMarkers,
&remove_zsts::RemoveZsts,
&normalize_array_len::NormalizeArrayLen, // has to run after `slice::len` lowering
Expand Down
19 changes: 18 additions & 1 deletion compiler/rustc_mir_transform/src/simplify_branches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,25 @@ impl<'tcx> MirPass<'tcx> for SimplifyConstCondition {
}

fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
trace!("Running SimplifyConstCondition on {:?}", body.source);
let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
for block in body.basic_blocks_mut() {
'blocks: for block in body.basic_blocks_mut() {
for stmt in block.statements.iter_mut() {
if let StatementKind::Intrinsic(box ref intrinsic) = stmt.kind
&& let NonDivergingIntrinsic::Assume(discr) = intrinsic
&& let Operand::Constant(ref c) = discr
&& let Some(constant) = c.const_.try_eval_bool(tcx, param_env)
{
if constant {
stmt.make_nop();
} else {
block.statements.clear();
block.terminator_mut().kind = TerminatorKind::Unreachable;
continue 'blocks;
}
}
}

let terminator = block.terminator_mut();
terminator.kind = match terminator.kind {
TerminatorKind::SwitchInt {
Expand Down
104 changes: 47 additions & 57 deletions compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
use crate::MirPass;
use rustc_data_structures::fx::FxHashSet;
use rustc_middle::mir::{
BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, SwitchTargets, Terminator,
TerminatorKind,
BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, Terminator, TerminatorKind,
};
use rustc_middle::ty::layout::TyAndLayout;
use rustc_middle::ty::{Ty, TyCtxt};
Expand All @@ -30,17 +29,20 @@ fn get_switched_on_type<'tcx>(
let terminator = block_data.terminator();

// Only bother checking blocks which terminate by switching on a local.
if let Some(local) = get_discriminant_local(&terminator.kind)
&& let [.., stmt_before_term] = &block_data.statements[..]
&& let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind
let local = get_discriminant_local(&terminator.kind)?;

let stmt_before_term = block_data.statements.last()?;

if let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind
&& l.as_local() == Some(local)
&& let ty = place.ty(body, tcx).ty
&& ty.is_enum()
{
Some(ty)
} else {
None
let ty = place.ty(body, tcx).ty;
if ty.is_enum() {
return Some(ty);
}
}

None
}

fn variant_discriminants<'tcx>(
Expand All @@ -67,28 +69,6 @@ fn variant_discriminants<'tcx>(
}
}

/// Ensures that the `otherwise` branch leads to an unreachable bb, returning `None` if so and a new
/// bb to use as the new target if not.
fn ensure_otherwise_unreachable<'tcx>(
body: &Body<'tcx>,
targets: &SwitchTargets,
) -> Option<BasicBlockData<'tcx>> {
let otherwise = targets.otherwise();
let bb = &body.basic_blocks[otherwise];
if bb.terminator().kind == TerminatorKind::Unreachable
&& bb.statements.iter().all(|s| matches!(&s.kind, StatementKind::StorageDead(_)))
{
return None;
}

let mut new_block = BasicBlockData::new(Some(Terminator {
source_info: bb.terminator().source_info,
kind: TerminatorKind::Unreachable,
}));
new_block.is_cleanup = bb.is_cleanup;
Some(new_block)
}

impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
sess.mir_opt_level() > 0
Expand All @@ -97,13 +77,16 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
trace!("UninhabitedEnumBranching starting for {:?}", body.source);

for bb in body.basic_blocks.indices() {
let mut removable_switchs = Vec::new();

for (bb, bb_data) in body.basic_blocks.iter_enumerated() {
trace!("processing block {:?}", bb);

let Some(discriminant_ty) = get_switched_on_type(&body.basic_blocks[bb], tcx, body)
else {
if bb_data.is_cleanup {
continue;
};
}

let Some(discriminant_ty) = get_switched_on_type(&bb_data, tcx, body) else { continue };

let layout = tcx.layout_of(
tcx.param_env_reveal_all_normalized(body.source.def_id()).and(discriminant_ty),
Expand All @@ -117,31 +100,38 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {

trace!("allowed_variants = {:?}", allowed_variants);

if let TerminatorKind::SwitchInt { targets, .. } =
&mut body.basic_blocks_mut()[bb].terminator_mut().kind
{
let mut new_targets = SwitchTargets::new(
targets.iter().filter(|(val, _)| allowed_variants.contains(val)),
targets.otherwise(),
);

if new_targets.iter().count() == allowed_variants.len() {
if let Some(updated) = ensure_otherwise_unreachable(body, &new_targets) {
let new_otherwise = body.basic_blocks_mut().push(updated);
*new_targets.all_targets_mut().last_mut().unwrap() = new_otherwise;
}
}
let terminator = bb_data.terminator();
let TerminatorKind::SwitchInt { targets, .. } = &terminator.kind else { bug!() };

if let TerminatorKind::SwitchInt { targets, .. } =
&mut body.basic_blocks_mut()[bb].terminator_mut().kind
{
*targets = new_targets;
let mut reachable_count = 0;
for (index, (val, _)) in targets.iter().enumerate() {
if allowed_variants.contains(&val) {
reachable_count += 1;
} else {
unreachable!()
removable_switchs.push((bb, index));
}
} else {
unreachable!()
}

if reachable_count == allowed_variants.len() {
removable_switchs.push((bb, targets.iter().count()));
}
}

if removable_switchs.is_empty() {
return;
}

let new_block = BasicBlockData::new(Some(Terminator {
source_info: body.basic_blocks[removable_switchs[0].0].terminator().source_info,
kind: TerminatorKind::Unreachable,
}));
let unreachable_block = body.basic_blocks.as_mut().push(new_block);

for (bb, index) in removable_switchs {
let bb = &mut body.basic_blocks.as_mut()[bb];
let terminator = bb.terminator_mut();
let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind else { bug!() };
targets.all_targets_mut()[index] = unreachable_block;
}
}
}
Loading

0 comments on commit 98f5ebb

Please sign in to comment.