Skip to content

Add MIR pass to lower call to core::slice::len into Len operand #86383

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions compiler/rustc_hir/src/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,8 @@ language_item_table! {

Try, sym::Try, try_trait, Target::Trait;

SliceLen, sym::slice_len_fn, slice_len_fn, Target::Method(MethodKind::Inherent);

// Language items from AST lowering
TryTraitFromResidual, sym::from_residual, from_residual_fn, Target::Method(MethodKind::Trait { body: false });
TryTraitFromOutput, sym::from_output, from_output_fn, Target::Method(MethodKind::Trait { body: false });
Expand Down
100 changes: 100 additions & 0 deletions compiler/rustc_mir/src/transform/lower_slice_len.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
//! This pass lowers calls to core::slice::len to just Len op.
//! It should run before inlining!

use crate::transform::MirPass;
use rustc_hir::def_id::DefId;
use rustc_index::vec::IndexVec;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, TyCtxt};

pub struct LowerSliceLenCalls;

impl<'tcx> MirPass<'tcx> for LowerSliceLenCalls {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
lower_slice_len_calls(tcx, body)
}
}

pub fn lower_slice_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let language_items = tcx.lang_items();
let slice_len_fn_item_def_id = if let Some(slice_len_fn_item) = language_items.slice_len_fn() {
slice_len_fn_item
} else {
// there is no language item to compare to :)
return;
};

let (basic_blocks, local_decls) = body.basic_blocks_and_local_decls_mut();

for block in basic_blocks {
// lower `<[_]>::len` calls
lower_slice_len_call(tcx, block, &*local_decls, slice_len_fn_item_def_id);
}
}

struct SliceLenPatchInformation<'tcx> {
add_statement: Statement<'tcx>,
new_terminator_kind: TerminatorKind<'tcx>,
}

fn lower_slice_len_call<'tcx>(
tcx: TyCtxt<'tcx>,
block: &mut BasicBlockData<'tcx>,
local_decls: &IndexVec<Local, LocalDecl<'tcx>>,
slice_len_fn_item_def_id: DefId,
) {
let mut patch_found: Option<SliceLenPatchInformation<'_>> = None;

let terminator = block.terminator();
match &terminator.kind {
TerminatorKind::Call {
func,
args,
destination: Some((dest, bb)),
cleanup: None,
from_hir_call: true,
..
} => {
// some heuristics for fast rejection
if args.len() != 1 {
return;
}
let arg = match args[0].place() {
Some(arg) => arg,
None => return,
};
let func_ty = func.ty(local_decls, tcx);
match func_ty.kind() {
ty::FnDef(fn_def_id, _) if fn_def_id == &slice_len_fn_item_def_id => {
// perform modifications
// from something like `_5 = core::slice::<impl [u8]>::len(move _6) -> bb1`
// into `_5 = Len(*_6)
// goto bb1

// make new RValue for Len
let deref_arg = tcx.mk_place_deref(arg);
let r_value = Rvalue::Len(deref_arg);
let len_statement_kind = StatementKind::Assign(Box::new((*dest, r_value)));
let add_statement = Statement {
kind: len_statement_kind,
source_info: terminator.source_info.clone(),
};

// modify terminator into simple Goto
let new_terminator_kind = TerminatorKind::Goto { target: bb.clone() };

let patch = SliceLenPatchInformation { add_statement, new_terminator_kind };

patch_found = Some(patch);
}
_ => {}
}
}
_ => {}
}

if let Some(SliceLenPatchInformation { add_statement, new_terminator_kind }) = patch_found {
block.statements.push(add_statement);
block.terminator_mut().kind = new_terminator_kind;
}
}
2 changes: 2 additions & 0 deletions compiler/rustc_mir/src/transform/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub mod generator;
pub mod inline;
pub mod instcombine;
pub mod lower_intrinsics;
pub mod lower_slice_len;
pub mod match_branches;
pub mod multiple_return_terminators;
pub mod no_landing_pads;
Expand Down Expand Up @@ -479,6 +480,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
// to them. We run some optimizations before that, because they may be harder to do on the state
// machine than on MIR with async primitives.
let optimizations_with_generators: &[&dyn MirPass<'tcx>] = &[
&lower_slice_len::LowerSliceLenCalls, // has to be done before inlining, otherwise actual call will be almost always inlined. Also simple, so can just do first
&unreachable_prop::UnreachablePropagation,
&uninhabited_enum_branching::UninhabitedEnumBranching,
&simplify::SimplifyCfg::new("after-uninhabited-enum-branching"),
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,7 @@ symbols! {
lateout,
lazy_normalization_consts,
le,
len,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the addition of the slice_len_fn lang item, this is now unused, right ?

Suggested change
len,

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it was a source of a performance gain in the first performance CI test: I was using a combination of slice_impl + len to find out a function, that is not optimal, but in some tests performance gains were substantial and I believe it may have helped. In any case, it's quite common expression and it may be good to intern in any case

let_chains,
lhs,
lib,
Expand Down Expand Up @@ -1147,6 +1148,7 @@ symbols! {
skip,
slice,
slice_alloc,
slice_len_fn,
slice_patterns,
slice_u8,
slice_u8_alloc,
Expand Down
1 change: 1 addition & 0 deletions library/core/src/slice/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ impl<T> [T] {
/// assert_eq!(a.len(), 3);
/// ```
#[doc(alias = "length")]
#[cfg_attr(not(bootstrap), lang = "slice_len_fn")]
#[stable(feature = "rust1", since = "1.0.0")]
#[rustc_const_stable(feature = "const_slice_len", since = "1.39.0")]
#[inline]
Expand Down
2 changes: 1 addition & 1 deletion library/std/src/thread/local/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ fn join_orders_after_tls_destructors() {
.unwrap();

loop {
match SYNC_STATE.compare_exchange_weak(
match SYNC_STATE.compare_exchange(
THREAD1_WAITING,
MAIN_THREAD_RENDEZVOUS,
Ordering::SeqCst,
Expand Down
63 changes: 63 additions & 0 deletions src/test/mir-opt/lower_slice_len.bound.LowerSliceLenCalls.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
- // MIR for `bound` before LowerSliceLenCalls
+ // MIR for `bound` after LowerSliceLenCalls

fn bound(_1: usize, _2: &[u8]) -> u8 {
debug index => _1; // in scope 0 at $DIR/lower_slice_len.rs:4:14: 4:19
debug slice => _2; // in scope 0 at $DIR/lower_slice_len.rs:4:28: 4:33
let mut _0: u8; // return place in scope 0 at $DIR/lower_slice_len.rs:4:45: 4:47
let mut _3: bool; // in scope 0 at $DIR/lower_slice_len.rs:5:8: 5:27
let mut _4: usize; // in scope 0 at $DIR/lower_slice_len.rs:5:8: 5:13
let mut _5: usize; // in scope 0 at $DIR/lower_slice_len.rs:5:16: 5:27
let mut _6: &[u8]; // in scope 0 at $DIR/lower_slice_len.rs:5:16: 5:21
let _7: usize; // in scope 0 at $DIR/lower_slice_len.rs:6:15: 6:20
let mut _8: usize; // in scope 0 at $DIR/lower_slice_len.rs:6:9: 6:21
let mut _9: bool; // in scope 0 at $DIR/lower_slice_len.rs:6:9: 6:21

bb0: {
StorageLive(_3); // scope 0 at $DIR/lower_slice_len.rs:5:8: 5:27
StorageLive(_4); // scope 0 at $DIR/lower_slice_len.rs:5:8: 5:13
_4 = _1; // scope 0 at $DIR/lower_slice_len.rs:5:8: 5:13
StorageLive(_5); // scope 0 at $DIR/lower_slice_len.rs:5:16: 5:27
StorageLive(_6); // scope 0 at $DIR/lower_slice_len.rs:5:16: 5:21
_6 = &(*_2); // scope 0 at $DIR/lower_slice_len.rs:5:16: 5:21
- _5 = core::slice::<impl [u8]>::len(move _6) -> bb1; // scope 0 at $DIR/lower_slice_len.rs:5:16: 5:27
- // mir::Constant
- // + span: $DIR/lower_slice_len.rs:5:22: 5:25
- // + literal: Const { ty: for<'r> fn(&'r [u8]) -> usize {core::slice::<impl [u8]>::len}, val: Value(Scalar(<ZST>)) }
+ _5 = Len((*_6)); // scope 0 at $DIR/lower_slice_len.rs:5:16: 5:27
+ goto -> bb1; // scope 0 at $DIR/lower_slice_len.rs:5:16: 5:27
}

bb1: {
StorageDead(_6); // scope 0 at $DIR/lower_slice_len.rs:5:26: 5:27
_3 = Lt(move _4, move _5); // scope 0 at $DIR/lower_slice_len.rs:5:8: 5:27
StorageDead(_5); // scope 0 at $DIR/lower_slice_len.rs:5:26: 5:27
StorageDead(_4); // scope 0 at $DIR/lower_slice_len.rs:5:26: 5:27
switchInt(move _3) -> [false: bb3, otherwise: bb2]; // scope 0 at $DIR/lower_slice_len.rs:5:5: 9:6
}

bb2: {
StorageLive(_7); // scope 0 at $DIR/lower_slice_len.rs:6:15: 6:20
_7 = _1; // scope 0 at $DIR/lower_slice_len.rs:6:15: 6:20
_8 = Len((*_2)); // scope 0 at $DIR/lower_slice_len.rs:6:9: 6:21
_9 = Lt(_7, _8); // scope 0 at $DIR/lower_slice_len.rs:6:9: 6:21
assert(move _9, "index out of bounds: the length is {} but the index is {}", move _8, _7) -> bb4; // scope 0 at $DIR/lower_slice_len.rs:6:9: 6:21
}

bb3: {
_0 = const 42_u8; // scope 0 at $DIR/lower_slice_len.rs:8:9: 8:11
goto -> bb5; // scope 0 at $DIR/lower_slice_len.rs:5:5: 9:6
}

bb4: {
_0 = (*_2)[_7]; // scope 0 at $DIR/lower_slice_len.rs:6:9: 6:21
StorageDead(_7); // scope 0 at $DIR/lower_slice_len.rs:7:5: 7:6
goto -> bb5; // scope 0 at $DIR/lower_slice_len.rs:5:5: 9:6
}

bb5: {
StorageDead(_3); // scope 0 at $DIR/lower_slice_len.rs:9:5: 9:6
return; // scope 0 at $DIR/lower_slice_len.rs:10:2: 10:2
}
}

14 changes: 14 additions & 0 deletions src/test/mir-opt/lower_slice_len.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// compile-flags: -Z mir-opt-level=3

// EMIT_MIR lower_slice_len.bound.LowerSliceLenCalls.diff
pub fn bound(index: usize, slice: &[u8]) -> u8 {
if index < slice.len() {
slice[index]
} else {
42
}
}

fn main() {
let _ = bound(1, &[1, 2, 3]);
}