Skip to content

New MIR opt pass simplify_pow_of_two #114254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions compiler/rustc_middle/src/mir/interpret/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ impl<'tcx> ConstValue<'tcx> {
ConstValue::Scalar(Scalar::from_bool(b))
}

pub fn from_u32(i: u32) -> Self {
ConstValue::Scalar(Scalar::from_u32(i))
}

pub fn from_u64(i: u64) -> Self {
ConstValue::Scalar(Scalar::from_u64(i))
}
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_middle/src/ty/sty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1872,6 +1872,10 @@ impl<'tcx> Region<'tcx> {

/// Constructors for `Ty`
impl<'tcx> Ty<'tcx> {
pub fn new_bool(tcx: TyCtxt<'tcx>) -> Ty<'tcx> {
Ty::new(tcx, TyKind::Bool)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is pre-interned as tcx.types.bool

}

// Avoid this in favour of more specific `new_*` methods, where possible.
#[allow(rustc::usage_of_ty_tykind)]
#[inline]
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ mod required_consts;
mod reveal_all;
mod separate_const_switch;
mod shim;
mod simplify_pow_of_two;
mod ssa;
// This pass is public to allow external drivers to perform MIR cleanup
mod check_alignment;
Expand Down Expand Up @@ -546,6 +547,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&lower_slice_len::LowerSliceLenCalls, // has to be done before inlining, otherwise actual call will be almost always inlined. Also simple, so can just do first
&unreachable_prop::UnreachablePropagation,
&uninhabited_enum_branching::UninhabitedEnumBranching,
&simplify_pow_of_two::SimplifyPowOfTwo,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably go after const prop, to increase the likelihood to have constants.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's after inlining, with the current way it works there's a chance the inliner will inline the call, and thus we won't detect it.

&o1(simplify::SimplifyCfg::AfterUninhabitedEnumBranching),
&inline::Inline,
&remove_storage_markers::RemoveStorageMarkers,
Expand Down
212 changes: 212 additions & 0 deletions compiler/rustc_mir_transform/src/simplify_pow_of_two.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
//! A pass that checks for and simplifies calls to `pow` where the receiver is a power of
//! two. This can be done with `<<` instead.

use crate::MirPass;
use rustc_const_eval::interpret::{ConstValue, Scalar};
use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, Ty, TyCtxt, UintTy};
use rustc_span::sym;
use rustc_target::abi::FieldIdx;

pub struct SimplifyPowOfTwo;

impl<'tcx> MirPass<'tcx> for SimplifyPowOfTwo {
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let mut patch = MirPatch::new(body);

for (i, bb) in body.basic_blocks.iter_enumerated() {
let term = bb.terminator();
let source_info = term.source_info;
let span = source_info.span;

if let TerminatorKind::Call {
func,
args,
destination,
target: Some(target),
call_source: CallSource::Normal,
..
} = &term.kind
&& let Some(def_id) = func.const_fn_def().map(|def| def.0)
&& let def_path = tcx.def_path(def_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need the Def path ? The crate is already 'def_id.krate'.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Artifact from how it worked before

&& tcx.crate_name(def_path.krate) == sym::core
&& let [recv, exp] = args.as_slice()
&& let Some(recv_const) = recv.constant()
&& let ConstantKind::Val(
ConstValue::Scalar(Scalar::Int(recv_int)),
recv_ty,
) = recv_const.literal
&& recv_ty.is_integral()
&& tcx.item_name(def_id) == sym::pow
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This detection mecanism is brittle. The compiler should be as independent as possible from the exact paths in core. Can diagnostic items be used for this ?

Copy link
Member Author

@Centri3 Centri3 Jul 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm ok with it, but if so we should use lang items instead as pointed out by #114254 (comment). But this feels like it'd be very verbose and slow, as we'd need 12 lang items for them all and then check whether it's any (rather than just if the item name is pow).

&& let Ok(recv_val) = match recv_ty.kind() {
ty::Int(_) => {
let result = recv_int.try_to_int(recv_int.size()).unwrap_or(-1).max(0);
if result > 0 {
Ok(result as u128)
} else {
continue;
}
},
ty::Uint(_) => recv_int.try_to_uint(recv_int.size()),
_ => continue,
}
&& let power_used = f32::log2(recv_val as f32)
// Precision loss means it's not a power of two
&& power_used == (power_used as u32) as f32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why use floats instead of ilog2?

Copy link
Member Author

@Centri3 Centri3 Jul 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's unfortunately no distinction between whether it's power of two there. i32::ilog2(4) and i32::log2(5) both return 2.

Oh nvm, this can almost certainly be followed by a call to 2.pow and see if that matches the original value

// `0` would be `1.pow()`, which we shouldn't try to optimize as it's
// already entirely optimized away
&& power_used != 0.0
// `-inf` would be `0.pow()`
&& power_used.is_finite()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you split this huge chain into 'let ... else { continue}' ?

{
let power_used = power_used as u32;
let loc = Location { block: i, statement_index: bb.statements.len() };
let exp_ty = Ty::new(tcx, ty::Uint(UintTy::U32));
let checked_mul =
patch.new_temp(Ty::new_tup(tcx, &[exp_ty, Ty::new_bool(tcx)]), span);

// If this is not `2.pow(...)`, we need to multiply the number of times we
// shift the bits left by the receiver's power of two used, e.g.:
//
// > 2 -> 1
// > 4 -> 2
// > 16 -> 4
// > 256 -> 8
//
// If this is `1`, then we *could* remove this entirely but it'll be
// optimized out anyway by later passes (or perhaps LLVM) so it's entirely
// unnecessary to do so.
patch.add_assign(
loc,
checked_mul.into(),
Rvalue::CheckedBinaryOp(
BinOp::Mul,
Box::new((
exp.clone(),
Operand::Constant(Box::new(Constant {
span,
user_ty: None,
literal: ConstantKind::Val(
ConstValue::from_u32(power_used),
exp_ty,
),
})),
)),
),
);

let num_shl = tcx.mk_place_field(checked_mul.into(), FieldIdx::from_u32(0), exp_ty);
let mul_result = tcx.mk_place_field(
checked_mul.into(),
FieldIdx::from_u32(1),
Ty::new_bool(tcx),
);
let shl_result = patch.new_temp(Ty::new_bool(tcx), span);

// Whether the shl will overflow, if so we return 0. We can do this rather
// than doing a shr because only one bit is set on any power of two
patch.add_assign(
loc,
shl_result.into(),
Rvalue::BinaryOp(
BinOp::Lt,
Box::new((
Operand::Copy(num_shl),
Operand::Constant(Box::new(Constant {
span,
user_ty: None,
literal: ConstantKind::Val(
ConstValue::from_u32(recv_int.size().bits() as u32),
exp_ty,
),
})),
)),
),
);

let fine_bool = patch.new_temp(Ty::new_bool(tcx), span);
let fine = patch.new_temp(recv_ty, span);

patch.add_assign(
loc,
fine_bool.into(),
Rvalue::BinaryOp(
BinOp::BitOr,
Box::new((
Operand::Copy(mul_result.into()),
Operand::Copy(shl_result.into()),
)),
),
);

patch.add_assign(
loc,
fine.into(),
Rvalue::Cast(CastKind::IntToInt, Operand::Copy(fine_bool.into()), recv_ty),
);

let shl = patch.new_temp(recv_ty, span);

patch.add_assign(
loc,
shl.into(),
Rvalue::BinaryOp(
BinOp::Shl,
Box::new((
Operand::Constant(Box::new(Constant {
span,
user_ty: None,
literal: ConstantKind::Val(
ConstValue::Scalar(Scalar::from_uint(1u128, recv_int.size())),
recv_ty,
),
})),
Operand::Copy(num_shl.into()),
)),
),
);

patch.add_assign(
loc,
*destination,
Rvalue::BinaryOp(
BinOp::MulUnchecked,
Box::new((Operand::Copy(shl.into()), Operand::Copy(fine.into()))),
),
);

// FIXME(Centri3): Do we use `debug_assertions` or `overflow_checks` here?
if tcx.sess.opts.debug_assertions {
patch.patch_terminator(
i,
TerminatorKind::Assert {
cond: Operand::Copy(fine_bool.into()),
expected: true,
msg: Box::new(AssertMessage::Overflow(
// For consistency with the previous error message, though
// it's technically incorrect
BinOp::Mul,
Operand::Constant(Box::new(Constant {
span,
user_ty: None,
literal: ConstantKind::Val(
ConstValue::Scalar(Scalar::from_u32(1)),
exp_ty,
),
})),
Operand::Copy(num_shl.into()),
)),
target: *target,
unwind: UnwindAction::Continue,
},
);
} else {
patch.patch_terminator(i, TerminatorKind::Goto { target: *target });
}
}
}

patch.apply(body);
}
}
2 changes: 2 additions & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,7 @@ symbols! {
not,
notable_trait,
note,
num,
object_safe_for_dispatch,
of,
offset,
Expand Down Expand Up @@ -1121,6 +1122,7 @@ symbols! {
poll,
position,
post_dash_lto: "post-lto",
pow,
powerpc_target_feature,
powf32,
powf64,
Expand Down
19 changes: 19 additions & 0 deletions tests/codegen/simplify-pow-of-two-debug-assertions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// compile-flags: -Copt-level=3 -Cdebug-assertions=true

// CHECK-LABEL: @slow_2_u(
#[no_mangle]
fn slow_2_u(a: u32) -> u32 {
// CHECK: %_3 = icmp ult i32 %a, 32
// CHECK-NEXT: br i1 %_3, label %bb1, label %panic, !prof !{{[0-9]+}}
// CHECK-EMPTY:
// CHECK-NEXT: bb1:
// CHECK-NEXT: %_01 = shl nuw i32 1, %a
// CHECK-NEXT: ret i32 %_0
// CHECK-EMPTY:
// CHECK-NEXT: panic:
2u32.pow(a)
}

fn main() {
slow_2_u(2);
}
16 changes: 16 additions & 0 deletions tests/codegen/simplify-pow-of-two.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// compile-flags: -Copt-level=3

// CHECK-LABEL: @slow_2_u(
#[no_mangle]
fn slow_2_u(a: u32) -> u32 {
// CHECK: %_3 = icmp ult i32 %a, 32
// CHECK-NEXT: %_5 = zext i1 %_3 to i32
// CHECK-NEXT: %0 = and i32 %a, 31
// CHECK-NEXT: %_01 = shl nuw i32 %_5, %0
// CHECK-NEXT: ret i32 %_01
2u32.pow(a)
}

fn main() {
slow_2_u(2);
}
51 changes: 51 additions & 0 deletions tests/mir-opt/simplify_pow_of_two_no_overflow_checks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// unit-test: SimplifyPowOfTwo
// compile-flags: -Cdebug-assertions=false

// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_2_u.SimplifyPowOfTwo.after.mir
fn slow_2_u(a: u32) -> u32 {
2u32.pow(a)
}

// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_2_i.SimplifyPowOfTwo.after.mir
fn slow_2_i(a: u32) -> i32 {
2i32.pow(a)
}

// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_4_u.SimplifyPowOfTwo.after.mir
fn slow_4_u(a: u32) -> u32 {
4u32.pow(a)
}

// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_4_i.SimplifyPowOfTwo.after.mir
fn slow_4_i(a: u32) -> i32 {
4i32.pow(a)
}

// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_256_u.SimplifyPowOfTwo.after.mir
fn slow_256_u(a: u32) -> u32 {
256u32.pow(a)
}

// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_256_i.SimplifyPowOfTwo.after.mir
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_256_i.SimplifyPowOfTwo.after.mir
// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_256_i.SimplifyPowOfTwo.diff

makes it easier for us to see what the optimization does

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

edit: ah, but it doesn't really optimize away anything, just replace a function... oh well 🤷

Copy link
Member Author

@Centri3 Centri3 Jul 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually wanted to use diff as well, just for clarity's sake though it seems from stage1 to stage2 it changes an unwind unreachable to unwind continue. It's the reason why CI was failing before. Any idea why that happens/how it can be prevented? Because I'd definitely prefer diff, even just for the highlighting

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's probably not a stage1/stage2 difference, but a problem with the fact that we are running mir-opt tests for panic=unwind and panic=abort and are expecting the results to match. You can use // EMIT_MIR_FOR_EACH_PANIC_STRATEGY to generate separate output files

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ohh, that makes sense, thanks!

fn slow_256_i(a: u32) -> i32 {
256i32.pow(a)
}

fn main() {
slow_2_u(0);
slow_2_i(0);
slow_2_u(1);
slow_2_i(1);
slow_2_u(2);
slow_2_i(2);
slow_4_u(4);
slow_4_i(4);
slow_4_u(15);
slow_4_i(15);
slow_4_u(16);
slow_4_i(16);
slow_4_u(17);
slow_4_i(17);
slow_256_u(2);
slow_256_i(2);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// MIR for `slow_256_i` after SimplifyPowOfTwo

fn slow_256_i(_1: u32) -> i32 {
debug a => _1;
let mut _0: i32;
let mut _2: u32;
let mut _3: (u32, bool);
let mut _4: bool;
let mut _5: bool;
let mut _6: i32;
let mut _7: i32;

bb0: {
StorageLive(_2);
_2 = _1;
_3 = CheckedMul(move _2, const 8_u32);
_4 = Lt((_3.0: u32), const 32_u32);
_5 = BitOr((_3.1: bool), _4);
_6 = _5 as i32 (IntToInt);
_7 = Shl(const 1_i32, (_3.0: u32));
_0 = MulUnchecked(_7, _6);
goto -> bb1;
}

bb1: {
StorageDead(_2);
return;
}
}
Loading