Auto merge of #123179 - scottmcm:inlining-baseline-costs, r=<try>
Rework MIR inlining costs

A bunch of the current costs are surprising, probably by accident from not writing the matches out in full.  For example, a runtime-length `memcpy` was treated as having the same cost as an `Unreachable`.

This reworks things around two main ideas:
- Give everything a baseline cost, because even "free" things take *some* effort in the compiler (CPU & RAM) to MIR-inline, and baseline costs are easy to calculate
- Then penalize only the things that are materially more expensive than the baseline, the way a `[foo; 123]` `Rvalue` is far more work than a `BinOp::AddUnchecked` one

Including costs for locals and var-debug-info makes some things more expensive overall, but because the cost of simple things like adding two local variables drops sharply, other things become cheaper overall.
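
To illustrate the scheme (a toy sketch only — `ToyStatement` and `toy_cost` are invented for this example and are not the compiler's types), every statement pays the same small baseline, and only kinds with real runtime or codegen weight pay a penalty on top:

```rust
// Toy model of the two-part scheme: a flat baseline for everything,
// plus a penalty only for genuinely expensive statement kinds.
enum ToyStatement {
    UncheckedAdd,             // cheap: baseline only
    RepeatArray { len: u64 }, // work grows with the length, capped
    Memcpy,                   // runtime-length copy, call-like cost
}

const STMT_BASELINE_COST: u64 = 1;
const CALL_PENALTY: u64 = 22;

fn toy_cost(stmt: &ToyStatement) -> u64 {
    let penalty = match stmt {
        ToyStatement::UncheckedAdd => 0,
        ToyStatement::RepeatArray { len } => (STMT_BASELINE_COST * len).min(CALL_PENALTY),
        ToyStatement::Memcpy => CALL_PENALTY,
    };
    STMT_BASELINE_COST + penalty
}

fn main() {
    assert_eq!(toy_cost(&ToyStatement::UncheckedAdd), 1); // baseline only
    assert_eq!(toy_cost(&ToyStatement::RepeatArray { len: 123 }), 23); // capped
    assert_eq!(toy_cost(&ToyStatement::Memcpy), 23); // no longer "free"
}
```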

r? ghost
bors committed Mar 29, 2024
2 parents ba52720 + 47a5a7f commit 8bbcded
Showing 2 changed files with 139 additions and 19 deletions.
147 changes: 128 additions & 19 deletions compiler/rustc_mir_transform/src/cost_checker.rs
@@ -2,10 +2,30 @@
 use rustc_middle::mir::visit::*;
 use rustc_middle::mir::*;
 use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt};
 
-const INSTR_COST: usize = 5;
-const CALL_PENALTY: usize = 25;
-const LANDINGPAD_PENALTY: usize = 50;
-const RESUME_PENALTY: usize = 45;
+// Even if they're zero-cost at runtime, everything has *some* cost to inline
+// in terms of copying them into the MIR caller, processing them in codegen, etc.
+// These baseline costs give a simple usually-too-low estimate of the cost,
+// which will be updated afterwards to account for the "real" costs.
+const STMT_BASELINE_COST: usize = 1;
+const BLOCK_BASELINE_COST: usize = 3;
+const DEBUG_BASELINE_COST: usize = 1;
+const LOCAL_BASELINE_COST: usize = 1;
+
+// These penalties represent the cost above baseline for those things which
+// have substantially more cost than is typical for their kind.
+const CALL_PENALTY: usize = 22;
+const LANDINGPAD_PENALTY: usize = 47;
+const RESUME_PENALTY: usize = 42;
+const DEREF_PENALTY: usize = 4;
+const CHECKED_OP_PENALTY: usize = 2;
+const THREAD_LOCAL_PENALTY: usize = 20;
+const SMALL_SWITCH_PENALTY: usize = 3;
+const LARGE_SWITCH_PENALTY: usize = 20;
+
+// Passing arguments isn't free, so give a bonus to functions with lots of them:
+// if the body is small despite lots of arguments, some are probably unused.
+const EXTRA_ARG_BONUS: usize = 4;
+const MAX_ARG_BONUS: usize = CALL_PENALTY;
 
 /// Verify that the callee body is compatible with the caller.
 #[derive(Clone)]
@@ -27,6 +47,20 @@ impl<'b, 'tcx> CostChecker<'b, 'tcx> {
         CostChecker { tcx, param_env, callee_body, instance, cost: 0 }
     }
 
+    // `Inline` doesn't call `visit_body`, so this is separate from the visitor.
+    pub fn before_body(&mut self, body: &Body<'tcx>) {
+        self.cost += BLOCK_BASELINE_COST * body.basic_blocks.len();
+        self.cost += DEBUG_BASELINE_COST * body.var_debug_info.len();
+        self.cost += LOCAL_BASELINE_COST * body.local_decls.len();
+
+        let total_statements = body.basic_blocks.iter().map(|x| x.statements.len()).sum::<usize>();
+        self.cost += STMT_BASELINE_COST * total_statements;
+
+        if let Some(extra_args) = body.arg_count.checked_sub(2) {
+            self.cost = self.cost.saturating_sub((EXTRA_ARG_BONUS * extra_args).min(MAX_ARG_BONUS));
+        }
+    }
+
     pub fn cost(&self) -> usize {
         self.cost
     }
@@ -41,14 +75,70 @@ impl<'b, 'tcx> CostChecker<'b, 'tcx> {
 }
 
 impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
-    fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) {
-        // Don't count StorageLive/StorageDead in the inlining cost.
-        match statement.kind {
-            StatementKind::StorageLive(_)
+    fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
+        match &statement.kind {
+            StatementKind::Assign(place_and_rvalue) => {
+                if place_and_rvalue.0.is_indirect_first_projection() {
+                    self.cost += DEREF_PENALTY;
+                }
+                self.visit_rvalue(&place_and_rvalue.1, location);
+            }
+            StatementKind::Intrinsic(intr) => match &**intr {
+                NonDivergingIntrinsic::Assume(..) => {}
+                NonDivergingIntrinsic::CopyNonOverlapping(_cno) => {
+                    self.cost += CALL_PENALTY;
+                }
+            },
+            StatementKind::FakeRead(..)
+            | StatementKind::SetDiscriminant { .. }
+            | StatementKind::StorageLive(_)
             | StatementKind::StorageDead(_)
+            | StatementKind::Retag(..)
+            | StatementKind::PlaceMention(..)
+            | StatementKind::AscribeUserType(..)
+            | StatementKind::Coverage(..)
             | StatementKind::Deinit(_)
-            | StatementKind::Nop => {}
-            _ => self.cost += INSTR_COST,
+            | StatementKind::ConstEvalCounter
+            | StatementKind::Nop => {
+                // No extra cost for these
+            }
+        }
+    }
+
+    fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, _location: Location) {
+        match rvalue {
+            Rvalue::Use(operand) => {
+                if let Some(place) = operand.place()
+                    && place.is_indirect_first_projection()
+                {
+                    self.cost += DEREF_PENALTY;
+                }
+            }
+            Rvalue::Repeat(_item, count) => {
+                let count = count.try_to_target_usize(self.tcx).unwrap_or(u64::MAX);
+                self.cost += (STMT_BASELINE_COST * count as usize).min(CALL_PENALTY);
+            }
+            Rvalue::Aggregate(_kind, fields) => {
+                self.cost += STMT_BASELINE_COST * fields.len();
+            }
+            Rvalue::CheckedBinaryOp(..) => {
+                self.cost += CHECKED_OP_PENALTY;
+            }
+            Rvalue::ThreadLocalRef(..) => {
+                self.cost += THREAD_LOCAL_PENALTY;
+            }
+            Rvalue::Ref(..)
+            | Rvalue::AddressOf(..)
+            | Rvalue::Len(..)
+            | Rvalue::Cast(..)
+            | Rvalue::BinaryOp(..)
+            | Rvalue::NullaryOp(..)
+            | Rvalue::UnaryOp(..)
+            | Rvalue::Discriminant(..)
+            | Rvalue::ShallowInitBox(..)
+            | Rvalue::CopyForDeref(..) => {
+                // No extra cost for these
+            }
         }
     }
 
@@ -63,24 +153,35 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
                     if let UnwindAction::Cleanup(_) = unwind {
                         self.cost += LANDINGPAD_PENALTY;
                     }
-                } else {
-                    self.cost += INSTR_COST;
                 }
             }
-            TerminatorKind::Call { func: Operand::Constant(ref f), unwind, .. } => {
-                let fn_ty = self.instantiate_ty(f.const_.ty());
-                self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind()
+            TerminatorKind::Call { ref func, unwind, .. } => {
+                if let Some(f) = func.constant()
+                    && let fn_ty = self.instantiate_ty(f.ty())
+                    && let ty::FnDef(def_id, _) = *fn_ty.kind()
                     && tcx.intrinsic(def_id).is_some()
                 {
                     // Don't give intrinsics the extra penalty for calls
-                    INSTR_COST
                 } else {
-                    CALL_PENALTY
+                    self.cost += CALL_PENALTY;
                 };
                 if let UnwindAction::Cleanup(_) = unwind {
                     self.cost += LANDINGPAD_PENALTY;
                 }
             }
+            TerminatorKind::SwitchInt { ref discr, ref targets } => {
+                if let Operand::Constant(..) = discr {
+                    // This'll be a goto once we're monomorphizing
+                } else {
+                    // 0/1/unreachable is extremely common (bool, Option, Result, ...)
+                    // but once there's more this can be a fair bit of work.
+                    self.cost += if targets.all_targets().len() <= 3 {
+                        SMALL_SWITCH_PENALTY
+                    } else {
+                        LARGE_SWITCH_PENALTY
+                    };
+                }
+            }
             TerminatorKind::Assert { unwind, .. } => {
                 self.cost += CALL_PENALTY;
                 if let UnwindAction::Cleanup(_) = unwind {
@@ -89,12 +190,20 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
             }
             TerminatorKind::UnwindResume => self.cost += RESUME_PENALTY,
             TerminatorKind::InlineAsm { unwind, .. } => {
-                self.cost += INSTR_COST;
                 if let UnwindAction::Cleanup(_) = unwind {
                     self.cost += LANDINGPAD_PENALTY;
                 }
             }
-            _ => self.cost += INSTR_COST,
+            TerminatorKind::Goto { .. }
+            | TerminatorKind::UnwindTerminate(..)
+            | TerminatorKind::Return
+            | TerminatorKind::Yield { .. }
+            | TerminatorKind::CoroutineDrop
+            | TerminatorKind::FalseEdge { .. }
+            | TerminatorKind::FalseUnwind { .. }
+            | TerminatorKind::Unreachable => {
+                // No extra cost for these
+            }
         }
     }
 }
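
For a feel for the numbers, here is a self-contained sketch that mirrors the arithmetic of the new `before_body` with plain integers; the `baseline_cost` helper and the sample sizes are hypothetical, for illustration only:

```rust
// Standalone sketch of the baseline arithmetic in `before_body`;
// not part of this commit.
const STMT_BASELINE_COST: usize = 1;
const BLOCK_BASELINE_COST: usize = 3;
const DEBUG_BASELINE_COST: usize = 1;
const LOCAL_BASELINE_COST: usize = 1;
const EXTRA_ARG_BONUS: usize = 4;
const MAX_ARG_BONUS: usize = 22; // = CALL_PENALTY

fn baseline_cost(blocks: usize, stmts: usize, debug: usize, locals: usize, args: usize) -> usize {
    let mut cost = BLOCK_BASELINE_COST * blocks
        + STMT_BASELINE_COST * stmts
        + DEBUG_BASELINE_COST * debug
        + LOCAL_BASELINE_COST * locals;
    // Functions with more than two arguments get a bonus, capped at CALL_PENALTY.
    if let Some(extra_args) = args.checked_sub(2) {
        cost = cost.saturating_sub((EXTRA_ARG_BONUS * extra_args).min(MAX_ARG_BONUS));
    }
    cost
}

fn main() {
    // A small callee: 2 blocks, 5 statements, 1 debug entry, 4 locals, 3 args.
    // Baseline: 2*3 + 5 + 1 + 4 = 16; bonus: min(4 * (3 - 2), 22) = 4.
    assert_eq!(baseline_cost(2, 5, 1, 4, 3), 12);
}
```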
11 changes: 11 additions & 0 deletions compiler/rustc_mir_transform/src/inline.rs
@@ -506,6 +506,17 @@ impl<'tcx> Inliner<'tcx> {
         let mut checker =
             CostChecker::new(self.tcx, self.param_env, Some(callsite.callee), callee_body);
 
+        checker.before_body(callee_body);
+
+        let baseline_cost = checker.cost();
+        if baseline_cost > threshold {
+            debug!(
+                "NOT inlining {:?} [baseline_cost={} > threshold={}]",
+                callsite, baseline_cost, threshold
+            );
+            return Err("baseline_cost above threshold");
+        }
+
         // Traverse the MIR manually so we can account for the effects of inlining on the CFG.
         let mut work_list = vec![START_BLOCK];
         let mut visited = BitSet::new_empty(callee_body.basic_blocks.len());
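With this change, a callee can now be rejected from its size counts alone, before the per-block cost traversal runs. A minimal sketch of that early-bailout control flow, assuming a hypothetical `should_walk_body` helper that is not part of this commit:

```rust
// Hypothetical helper illustrating the early bailout added above: if the
// cheap size-based baseline already exceeds the threshold, skip the more
// expensive CFG walk entirely.
fn should_walk_body(baseline_cost: usize, threshold: usize) -> Result<(), &'static str> {
    if baseline_cost > threshold {
        return Err("baseline_cost above threshold");
    }
    Ok(()) // proceed to the per-block cost traversal
}

fn main() {
    // A body whose block/statement/local counts alone exceed the threshold
    // is rejected without visiting a single statement.
    assert!(should_walk_body(60, 50).is_err());
    assert!(should_walk_body(12, 50).is_ok());
}
```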
