-
Notifications
You must be signed in to change notification settings - Fork 13.3k
New MIR opt pass simplify_pow_of_two
#114254
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7008cce
cb2b3ea
c4b8c57
9b29907
f944eba
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -95,6 +95,7 @@ mod required_consts; | |
mod reveal_all; | ||
mod separate_const_switch; | ||
mod shim; | ||
mod simplify_pow_of_two; | ||
mod ssa; | ||
// This pass is public to allow external drivers to perform MIR cleanup | ||
mod check_alignment; | ||
|
@@ -546,6 +547,7 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { | |
&lower_slice_len::LowerSliceLenCalls, // has to be done before inlining, otherwise actual call will be almost always inlined. Also simple, so can just do first | ||
&unreachable_prop::UnreachablePropagation, | ||
&uninhabited_enum_branching::UninhabitedEnumBranching, | ||
&simplify_pow_of_two::SimplifyPowOfTwo, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should probably go after const prop, to increase the likelihood to have constants. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's after inlining, with the current way it works there's a chance the inliner will inline the call, and thus we won't detect it. |
||
&o1(simplify::SimplifyCfg::AfterUninhabitedEnumBranching), | ||
&inline::Inline, | ||
&remove_storage_markers::RemoveStorageMarkers, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
//! A pass that checks for and simplifies calls to `pow` where the receiver is a power of | ||
//! two. This can be done with `<<` instead. | ||
|
||
use crate::MirPass; | ||
use rustc_const_eval::interpret::{ConstValue, Scalar}; | ||
use rustc_middle::mir::patch::MirPatch; | ||
use rustc_middle::mir::*; | ||
use rustc_middle::ty::{self, Ty, TyCtxt, UintTy}; | ||
use rustc_span::sym; | ||
use rustc_target::abi::FieldIdx; | ||
|
||
pub struct SimplifyPowOfTwo; | ||
|
||
impl<'tcx> MirPass<'tcx> for SimplifyPowOfTwo { | ||
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { | ||
let mut patch = MirPatch::new(body); | ||
|
||
for (i, bb) in body.basic_blocks.iter_enumerated() { | ||
let term = bb.terminator(); | ||
let source_info = term.source_info; | ||
let span = source_info.span; | ||
|
||
if let TerminatorKind::Call { | ||
func, | ||
args, | ||
destination, | ||
target: Some(target), | ||
call_source: CallSource::Normal, | ||
.. | ||
} = &term.kind | ||
&& let Some(def_id) = func.const_fn_def().map(|def| def.0) | ||
&& let def_path = tcx.def_path(def_id) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do you need the Def path ? The crate is already 'def_id.krate'. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Artifact from how it worked before |
||
&& tcx.crate_name(def_path.krate) == sym::core | ||
&& let [recv, exp] = args.as_slice() | ||
&& let Some(recv_const) = recv.constant() | ||
&& let ConstantKind::Val( | ||
ConstValue::Scalar(Scalar::Int(recv_int)), | ||
recv_ty, | ||
) = recv_const.literal | ||
&& recv_ty.is_integral() | ||
&& tcx.item_name(def_id) == sym::pow | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This detection mecanism is brittle. The compiler should be as independent as possible from the exact paths in core. Can diagnostic items be used for this ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm ok with it, but if so we should use lang items instead as pointed out by #114254 (comment). But this feels like it'd be very verbose and slow, as we'd need 12 lang items for them all and then check whether it's any (rather than just if the item name is |
||
&& let Ok(recv_val) = match recv_ty.kind() { | ||
ty::Int(_) => { | ||
let result = recv_int.try_to_int(recv_int.size()).unwrap_or(-1).max(0); | ||
if result > 0 { | ||
Ok(result as u128) | ||
} else { | ||
continue; | ||
} | ||
}, | ||
ty::Uint(_) => recv_int.try_to_uint(recv_int.size()), | ||
_ => continue, | ||
} | ||
&& let power_used = f32::log2(recv_val as f32) | ||
// Precision loss means it's not a power of two | ||
&& power_used == (power_used as u32) as f32 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why use floats instead of ilog2? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Oh nvm, this can almost certainly be followed by a call to |
||
// `0` would be `1.pow()`, which we shouldn't try to optimize as it's | ||
// already entirely optimized away | ||
&& power_used != 0.0 | ||
// `-inf` would be `0.pow()` | ||
&& power_used.is_finite() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you split this huge chain into 'let ... else { continue}' ? |
||
{ | ||
let power_used = power_used as u32; | ||
let loc = Location { block: i, statement_index: bb.statements.len() }; | ||
let exp_ty = Ty::new(tcx, ty::Uint(UintTy::U32)); | ||
let checked_mul = | ||
patch.new_temp(Ty::new_tup(tcx, &[exp_ty, Ty::new_bool(tcx)]), span); | ||
|
||
// If this is not `2.pow(...)`, we need to multiply the number of times we | ||
// shift the bits left by the receiver's power of two used, e.g.: | ||
// | ||
// > 2 -> 1 | ||
// > 4 -> 2 | ||
// > 16 -> 4 | ||
// > 256 -> 8 | ||
// | ||
// If this is `1`, then we *could* remove this entirely but it'll be | ||
// optimized out anyway by later passes (or perhaps LLVM) so it's entirely | ||
// unnecessary to do so. | ||
patch.add_assign( | ||
loc, | ||
checked_mul.into(), | ||
Rvalue::CheckedBinaryOp( | ||
BinOp::Mul, | ||
Box::new(( | ||
exp.clone(), | ||
Operand::Constant(Box::new(Constant { | ||
span, | ||
user_ty: None, | ||
literal: ConstantKind::Val( | ||
ConstValue::from_u32(power_used), | ||
exp_ty, | ||
), | ||
})), | ||
)), | ||
), | ||
); | ||
|
||
let num_shl = tcx.mk_place_field(checked_mul.into(), FieldIdx::from_u32(0), exp_ty); | ||
let mul_result = tcx.mk_place_field( | ||
checked_mul.into(), | ||
FieldIdx::from_u32(1), | ||
Ty::new_bool(tcx), | ||
); | ||
let shl_result = patch.new_temp(Ty::new_bool(tcx), span); | ||
|
||
// Whether the shl will overflow, if so we return 0. We can do this rather | ||
// than doing a shr because only one bit is set on any power of two | ||
patch.add_assign( | ||
loc, | ||
shl_result.into(), | ||
Rvalue::BinaryOp( | ||
BinOp::Lt, | ||
Box::new(( | ||
Operand::Copy(num_shl), | ||
Operand::Constant(Box::new(Constant { | ||
span, | ||
user_ty: None, | ||
literal: ConstantKind::Val( | ||
ConstValue::from_u32(recv_int.size().bits() as u32), | ||
exp_ty, | ||
), | ||
})), | ||
)), | ||
), | ||
); | ||
|
||
let fine_bool = patch.new_temp(Ty::new_bool(tcx), span); | ||
let fine = patch.new_temp(recv_ty, span); | ||
|
||
patch.add_assign( | ||
loc, | ||
fine_bool.into(), | ||
Rvalue::BinaryOp( | ||
BinOp::BitOr, | ||
Box::new(( | ||
Operand::Copy(mul_result.into()), | ||
Operand::Copy(shl_result.into()), | ||
)), | ||
), | ||
); | ||
|
||
patch.add_assign( | ||
loc, | ||
fine.into(), | ||
Rvalue::Cast(CastKind::IntToInt, Operand::Copy(fine_bool.into()), recv_ty), | ||
); | ||
|
||
let shl = patch.new_temp(recv_ty, span); | ||
|
||
patch.add_assign( | ||
loc, | ||
shl.into(), | ||
Rvalue::BinaryOp( | ||
BinOp::Shl, | ||
Box::new(( | ||
Operand::Constant(Box::new(Constant { | ||
span, | ||
user_ty: None, | ||
literal: ConstantKind::Val( | ||
ConstValue::Scalar(Scalar::from_uint(1u128, recv_int.size())), | ||
recv_ty, | ||
), | ||
})), | ||
Operand::Copy(num_shl.into()), | ||
)), | ||
), | ||
); | ||
|
||
patch.add_assign( | ||
loc, | ||
*destination, | ||
Rvalue::BinaryOp( | ||
BinOp::MulUnchecked, | ||
Box::new((Operand::Copy(shl.into()), Operand::Copy(fine.into()))), | ||
), | ||
); | ||
|
||
// FIXME(Centri3): Do we use `debug_assertions` or `overflow_checks` here? | ||
if tcx.sess.opts.debug_assertions { | ||
patch.patch_terminator( | ||
i, | ||
TerminatorKind::Assert { | ||
cond: Operand::Copy(fine_bool.into()), | ||
expected: true, | ||
msg: Box::new(AssertMessage::Overflow( | ||
// For consistency with the previous error message, though | ||
// it's technically incorrect | ||
BinOp::Mul, | ||
Operand::Constant(Box::new(Constant { | ||
span, | ||
user_ty: None, | ||
literal: ConstantKind::Val( | ||
ConstValue::Scalar(Scalar::from_u32(1)), | ||
exp_ty, | ||
), | ||
})), | ||
Operand::Copy(num_shl.into()), | ||
)), | ||
target: *target, | ||
unwind: UnwindAction::Continue, | ||
}, | ||
); | ||
} else { | ||
patch.patch_terminator(i, TerminatorKind::Goto { target: *target }); | ||
} | ||
} | ||
} | ||
|
||
patch.apply(body); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
// compile-flags: -Copt-level=3 -Cdebug-assertions=true | ||
|
||
// CHECK-LABEL: @slow_2_u( | ||
#[no_mangle] | ||
fn slow_2_u(a: u32) -> u32 { | ||
// CHECK: %_3 = icmp ult i32 %a, 32 | ||
// CHECK-NEXT: br i1 %_3, label %bb1, label %panic, !prof !{{[0-9]+}} | ||
// CHECK-EMPTY: | ||
// CHECK-NEXT: bb1: | ||
// CHECK-NEXT: %_01 = shl nuw i32 1, %a | ||
// CHECK-NEXT: ret i32 %_0 | ||
// CHECK-EMPTY: | ||
// CHECK-NEXT: panic: | ||
2u32.pow(a) | ||
} | ||
|
||
fn main() { | ||
slow_2_u(2); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
// compile-flags: -Copt-level=3 | ||
|
||
// CHECK-LABEL: @slow_2_u( | ||
#[no_mangle] | ||
fn slow_2_u(a: u32) -> u32 { | ||
// CHECK: %_3 = icmp ult i32 %a, 32 | ||
// CHECK-NEXT: %_5 = zext i1 %_3 to i32 | ||
// CHECK-NEXT: %0 = and i32 %a, 31 | ||
// CHECK-NEXT: %_01 = shl nuw i32 %_5, %0 | ||
// CHECK-NEXT: ret i32 %_01 | ||
2u32.pow(a) | ||
} | ||
|
||
fn main() { | ||
slow_2_u(2); | ||
} |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,51 @@ | ||||||
// unit-test: SimplifyPowOfTwo | ||||||
// compile-flags: -Cdebug-assertions=false | ||||||
|
||||||
// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_2_u.SimplifyPowOfTwo.after.mir | ||||||
fn slow_2_u(a: u32) -> u32 { | ||||||
2u32.pow(a) | ||||||
} | ||||||
|
||||||
// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_2_i.SimplifyPowOfTwo.after.mir | ||||||
fn slow_2_i(a: u32) -> i32 { | ||||||
2i32.pow(a) | ||||||
} | ||||||
|
||||||
// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_4_u.SimplifyPowOfTwo.after.mir | ||||||
fn slow_4_u(a: u32) -> u32 { | ||||||
4u32.pow(a) | ||||||
} | ||||||
|
||||||
// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_4_i.SimplifyPowOfTwo.after.mir | ||||||
fn slow_4_i(a: u32) -> i32 { | ||||||
4i32.pow(a) | ||||||
} | ||||||
|
||||||
// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_256_u.SimplifyPowOfTwo.after.mir | ||||||
fn slow_256_u(a: u32) -> u32 { | ||||||
256u32.pow(a) | ||||||
} | ||||||
|
||||||
// EMIT_MIR simplify_pow_of_two_no_overflow_checks.slow_256_i.SimplifyPowOfTwo.after.mir | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
makes it easier for us to see what the optimization does There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. edit: ah, but it doesn't really optimize away anything, just replace a function... oh well 🤷 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I actually wanted to use diff as well, just for clarity's sake though it seems from stage1 to stage2 it changes an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's probably not a stage1/stage2 difference, but a problem with the fact that we are running mir-opt tests for panic=unwind and panic=abort and are expecting the results to match. You can use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ohh, that makes sense, thanks! |
||||||
fn slow_256_i(a: u32) -> i32 { | ||||||
256i32.pow(a) | ||||||
} | ||||||
|
||||||
fn main() { | ||||||
slow_2_u(0); | ||||||
slow_2_i(0); | ||||||
slow_2_u(1); | ||||||
slow_2_i(1); | ||||||
slow_2_u(2); | ||||||
slow_2_i(2); | ||||||
slow_4_u(4); | ||||||
slow_4_i(4); | ||||||
slow_4_u(15); | ||||||
slow_4_i(15); | ||||||
slow_4_u(16); | ||||||
slow_4_i(16); | ||||||
slow_4_u(17); | ||||||
slow_4_i(17); | ||||||
slow_256_u(2); | ||||||
slow_256_i(2); | ||||||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
// MIR for `slow_256_i` after SimplifyPowOfTwo | ||
|
||
fn slow_256_i(_1: u32) -> i32 { | ||
debug a => _1; | ||
let mut _0: i32; | ||
let mut _2: u32; | ||
let mut _3: (u32, bool); | ||
let mut _4: bool; | ||
let mut _5: bool; | ||
let mut _6: i32; | ||
let mut _7: i32; | ||
|
||
bb0: { | ||
StorageLive(_2); | ||
_2 = _1; | ||
_3 = CheckedMul(move _2, const 8_u32); | ||
_4 = Lt((_3.0: u32), const 32_u32); | ||
_5 = BitOr((_3.1: bool), _4); | ||
_6 = _5 as i32 (IntToInt); | ||
_7 = Shl(const 1_i32, (_3.0: u32)); | ||
_0 = MulUnchecked(_7, _6); | ||
goto -> bb1; | ||
} | ||
|
||
bb1: { | ||
StorageDead(_2); | ||
return; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is pre-interned as tcx.types.bool