Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize MIR for comparison of references #112542

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ mod multiple_return_terminators;
mod normalize_array_len;
mod nrvo;
mod prettify;
mod ref_cmp_simplify;
mod ref_prop;
mod remove_noop_landing_pads;
mod remove_storage_markers;
Expand Down Expand Up @@ -561,6 +562,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&instsimplify::InstSimplify,
&simplify::SimplifyLocals::BeforeConstProp,
&copy_prop::CopyProp,
&ref_cmp_simplify::RefCmpSimplify,
&ref_prop::ReferencePropagation,
// Perform `SeparateConstSwitch` after SSA-based analyses, as cloning blocks may
// destroy the SSA property. It should still happen before const-propagation, so the
Expand Down
93 changes: 93 additions & 0 deletions compiler/rustc_mir_transform/src/ref_cmp_simplify.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
use crate::MirPass;
use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::*;
use rustc_middle::ty::TyCtxt;

pub struct RefCmpSimplify;

impl<'tcx> MirPass<'tcx> for RefCmpSimplify {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
self.simplify_ref_cmp(tcx, body)
}
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum MatchState {
Empty,
Deref { src_statement_idx: usize, dst: Local, src: Local },
CopiedFrom { src_statement_idx: usize, dst: Local, real_src: Local },
Completed { src_statement_idx: usize, dst: Local, real_src: Local },
}

impl RefCmpSimplify {
fn simplify_ref_cmp<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
debug!("body: {:#?}", body);

let n_bbs = body.basic_blocks.len() as u32;
for bb in 0..n_bbs {
let bb = BasicBlock::from_u32(bb);
let mut max = Local::MAX;
'repeat: loop {
let mut state = MatchState::Empty;
let bb_data = &body.basic_blocks[bb];
for (i, stmt) in bb_data.statements.iter().enumerate().rev() {
state = match (state, &stmt.kind) {
(
MatchState::Empty,
StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Copy(rhs)))),
) if rhs.has_deref() && lhs.ty(body, tcx).ty.is_primitive() => {
let Some(dst) = lhs.as_local() else {
continue
};
let Some(src) = rhs.local_or_deref_local() else {
continue;
};
if max <= dst {
continue;
}
max = dst;
MatchState::Deref { dst, src, src_statement_idx: i }
}
(
MatchState::Deref { src, dst, src_statement_idx },
StatementKind::Assign(box (lhs, Rvalue::CopyForDeref(rhs))),
) if lhs.as_local() == Some(src) && rhs.has_deref() => {
let Some(real_src) = rhs.local_or_deref_local() else{
continue;
};
MatchState::CopiedFrom { src_statement_idx, dst, real_src }
}
(
MatchState::CopiedFrom { src_statement_idx, dst, real_src },
StatementKind::Assign(box (
lhs,
Rvalue::Ref(_, BorrowKind::Shared | BorrowKind::Shallow, rhs),
)),
) if lhs.as_local() == Some(real_src) => {
let Some(real_src) = rhs.as_local() else {
continue;
};
MatchState::Completed { dst, real_src, src_statement_idx }
}
_ => continue,
};
if let MatchState::Completed { dst, real_src, src_statement_idx } = state {
let mut patch = MirPatch::new(&body);
let src = Place::from(real_src);
let src = src.project_deeper(&[PlaceElem::Deref], tcx);
let dst = Place::from(dst);
let new_stmt =
StatementKind::Assign(Box::new((dst, Rvalue::Use(Operand::Copy(src)))));
patch.add_statement(
Location { block: bb, statement_index: src_statement_idx + 1 },
new_stmt,
);
patch.apply(body);
continue 'repeat;
}
}
break;
}
}
}
}
50 changes: 50 additions & 0 deletions tests/mir-opt/ref_int_cmp.opt1.RefCmpSimplify.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
- // MIR for `opt1` before RefCmpSimplify
+ // MIR for `opt1` after RefCmpSimplify

fn opt1(_1: &u8, _2: &u8) -> bool {
debug x => _1; // in scope 0 at $DIR/ref_int_cmp.rs:+0:13: +0:14
debug y => _2; // in scope 0 at $DIR/ref_int_cmp.rs:+0:21: +0:22
let mut _0: bool; // return place in scope 0 at $DIR/ref_int_cmp.rs:+0:32: +0:36
let mut _3: &&u8; // in scope 0 at $DIR/ref_int_cmp.rs:+1:3: +1:4
let mut _4: &&u8; // in scope 0 at $DIR/ref_int_cmp.rs:+1:7: +1:8
let _5: &u8; // in scope 0 at $DIR/ref_int_cmp.rs:+1:7: +1:8
let mut _8: &u8; // in scope 0 at $SRC_DIR/core/src/cmp.rs:LL:COL
let mut _9: &u8; // in scope 0 at $SRC_DIR/core/src/cmp.rs:LL:COL
scope 1 (inlined cmp::impls::<impl PartialOrd for &u8>::lt) { // at $DIR/ref_int_cmp.rs:5:3: 5:8
debug self => _3; // in scope 1 at $SRC_DIR/core/src/cmp.rs:LL:COL
debug other => _4; // in scope 1 at $SRC_DIR/core/src/cmp.rs:LL:COL
let mut _6: &u8; // in scope 1 at $SRC_DIR/core/src/cmp.rs:LL:COL
let mut _7: &u8; // in scope 1 at $SRC_DIR/core/src/cmp.rs:LL:COL
scope 2 (inlined cmp::impls::<impl PartialOrd for u8>::lt) { // at $SRC_DIR/core/src/cmp.rs:LL:COL
debug self => _6; // in scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL
debug other => _7; // in scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL
let mut _10: u8; // in scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL
let mut _11: u8; // in scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL
}
}

bb0: {
StorageLive(_3); // scope 0 at $DIR/ref_int_cmp.rs:+1:3: +1:4
_3 = &_1; // scope 0 at $DIR/ref_int_cmp.rs:+1:3: +1:4
StorageLive(_4); // scope 0 at $DIR/ref_int_cmp.rs:+1:7: +1:8
StorageLive(_5); // scope 0 at $DIR/ref_int_cmp.rs:+1:7: +1:8
_5 = _2; // scope 0 at $DIR/ref_int_cmp.rs:+1:7: +1:8
_4 = &_5; // scope 0 at $DIR/ref_int_cmp.rs:+1:7: +1:8
_6 = deref_copy (*_3); // scope 1 at $SRC_DIR/core/src/cmp.rs:LL:COL
_7 = deref_copy (*_4); // scope 1 at $SRC_DIR/core/src/cmp.rs:LL:COL
StorageLive(_10); // scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL
_10 = (*_6); // scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL
+ _10 = (*_1); // scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL
StorageLive(_11); // scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL
_11 = (*_7); // scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL
+ _11 = (*_5); // scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL
_0 = Lt(move _10, move _11); // scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL
StorageDead(_11); // scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL
StorageDead(_10); // scope 2 at $SRC_DIR/core/src/cmp.rs:LL:COL
StorageDead(_4); // scope 0 at $DIR/ref_int_cmp.rs:+1:7: +1:8
StorageDead(_3); // scope 0 at $DIR/ref_int_cmp.rs:+1:7: +1:8
StorageDead(_5); // scope 0 at $DIR/ref_int_cmp.rs:+2:1: +2:2
return; // scope 0 at $DIR/ref_int_cmp.rs:+2:2: +2:2
}
}

10 changes: 10 additions & 0 deletions tests/mir-opt/ref_int_cmp.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// compile-flags: -O -Zmir-opt-level=3

// EMIT_MIR ref_int_cmp.opt1.RefCmpSimplify.diff
pub fn opt1(x: &u8, y: &u8) -> bool {
x < y
}

fn main() {
opt1(&1, &2);
}