Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework MIR inlining costs #123179

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 128 additions & 19 deletions compiler/rustc_mir_transform/src/cost_checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,30 @@ use rustc_middle::mir::visit::*;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt};

const INSTR_COST: usize = 5;
const CALL_PENALTY: usize = 25;
const LANDINGPAD_PENALTY: usize = 50;
const RESUME_PENALTY: usize = 45;
// Even if it's zero-cost at runtime, everything has *some* cost to inline
// in terms of copying it into the MIR caller, processing it in codegen, etc.
// These baseline costs give a simple, usually-too-low estimate of the cost,
// which will be updated afterwards to account for the "real" costs.
const STMT_BASELINE_COST: usize = 1;
const BLOCK_BASELINE_COST: usize = 3;
const DEBUG_BASELINE_COST: usize = 1;
const LOCAL_BASELINE_COST: usize = 1;

// These penalties represent the cost above baseline for those things which
// have substantially more cost than is typical for their kind.
const CALL_PENALTY: usize = 22;
const LANDINGPAD_PENALTY: usize = 47;
const RESUME_PENALTY: usize = 42;
const DEREF_PENALTY: usize = 4;
const CHECKED_OP_PENALTY: usize = 2;
const THREAD_LOCAL_PENALTY: usize = 20;
const SMALL_SWITCH_PENALTY: usize = 3;
const LARGE_SWITCH_PENALTY: usize = 20;

// Passing arguments isn't free, so give a bonus to functions with lots of them:
// if the body is small despite lots of arguments, some are probably unused.
const EXTRA_ARG_BONUS: usize = 4;
const MAX_ARG_BONUS: usize = CALL_PENALTY;

/// Verify that the callee body is compatible with the caller.
#[derive(Clone)]
Expand All @@ -27,6 +47,20 @@ impl<'b, 'tcx> CostChecker<'b, 'tcx> {
CostChecker { tcx, param_env, callee_body, instance, cost: 0 }
}

/// Adds the baseline cost of `body` to the running total.
///
/// Every basic block, debug-info entry, local declaration and statement
/// contributes its per-item baseline cost. Afterwards a capped bonus is
/// subtracted for each argument beyond the second.
///
/// `Inline` doesn't call `visit_body`, so this is separate from the visitor.
pub fn before_body(&mut self, body: &Body<'tcx>) {
    let statement_count: usize =
        body.basic_blocks.iter().map(|block| block.statements.len()).sum();

    let baseline = BLOCK_BASELINE_COST * body.basic_blocks.len()
        + DEBUG_BASELINE_COST * body.var_debug_info.len()
        + LOCAL_BASELINE_COST * body.local_decls.len()
        + STMT_BASELINE_COST * statement_count;
    self.cost += baseline;

    // With two or fewer arguments there are no "extra" ones, so the bonus
    // is zero and the subtraction below leaves the cost untouched.
    let extra_args = body.arg_count.saturating_sub(2);
    let bonus = (EXTRA_ARG_BONUS * extra_args).min(MAX_ARG_BONUS);
    self.cost = self.cost.saturating_sub(bonus);
}

/// Returns the cost accumulated by this checker so far.
pub fn cost(&self) -> usize {
self.cost
}
Expand All @@ -41,14 +75,70 @@ impl<'b, 'tcx> CostChecker<'b, 'tcx> {
}

impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) {
// Don't count StorageLive/StorageDead in the inlining cost.
match statement.kind {
StatementKind::StorageLive(_)
fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
match &statement.kind {
StatementKind::Assign(place_and_rvalue) => {
if place_and_rvalue.0.is_indirect_first_projection() {
self.cost += DEREF_PENALTY;
}
self.visit_rvalue(&place_and_rvalue.1, location);
}
StatementKind::Intrinsic(intr) => match &**intr {
NonDivergingIntrinsic::Assume(..) => {}
NonDivergingIntrinsic::CopyNonOverlapping(_cno) => {
self.cost += CALL_PENALTY;
}
},
StatementKind::FakeRead(..)
| StatementKind::SetDiscriminant { .. }
| StatementKind::StorageLive(_)
| StatementKind::StorageDead(_)
| StatementKind::Retag(..)
| StatementKind::PlaceMention(..)
| StatementKind::AscribeUserType(..)
| StatementKind::Coverage(..)
| StatementKind::Deinit(_)
| StatementKind::Nop => {}
_ => self.cost += INSTR_COST,
| StatementKind::ConstEvalCounter
| StatementKind::Nop => {
// No extra cost for these
}
}
}

fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, _location: Location) {
match rvalue {
Rvalue::Use(operand) => {
if let Some(place) = operand.place()
&& place.is_indirect_first_projection()
{
self.cost += DEREF_PENALTY;
}
}
Rvalue::Repeat(_item, count) => {
let count = count.try_to_target_usize(self.tcx).unwrap_or(u64::MAX);
self.cost += (STMT_BASELINE_COST * count as usize).min(CALL_PENALTY);
}
Rvalue::Aggregate(_kind, fields) => {
self.cost += STMT_BASELINE_COST * fields.len();
}
Rvalue::CheckedBinaryOp(..) => {
self.cost += CHECKED_OP_PENALTY;
}
Rvalue::ThreadLocalRef(..) => {
self.cost += THREAD_LOCAL_PENALTY;
}
Rvalue::Ref(..)
| Rvalue::AddressOf(..)
| Rvalue::Len(..)
| Rvalue::Cast(..)
| Rvalue::BinaryOp(..)
| Rvalue::NullaryOp(..)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For example, `_1 = sizeof(T)` is now just cost 1 (the statement baseline) instead of the previous 5 (`INSTR_COST`).

| Rvalue::UnaryOp(..)
| Rvalue::Discriminant(..)
| Rvalue::ShallowInitBox(..)
| Rvalue::CopyForDeref(..) => {
// No extra cost for these
}
}
}

Expand All @@ -63,24 +153,35 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
if let UnwindAction::Cleanup(_) = unwind {
self.cost += LANDINGPAD_PENALTY;
}
} else {
self.cost += INSTR_COST;
}
}
TerminatorKind::Call { func: Operand::Constant(ref f), unwind, .. } => {
Copy link
Member Author

@scottmcm scottmcm Mar 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As an interesting example, calls to non-constants weren't given the `CALL_PENALTY` before, because they were hidden down in the `_ =>` arm.

let fn_ty = self.instantiate_ty(f.const_.ty());
self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind()
TerminatorKind::Call { ref func, unwind, .. } => {
if let Some(f) = func.constant()
&& let fn_ty = self.instantiate_ty(f.ty())
&& let ty::FnDef(def_id, _) = *fn_ty.kind()
&& tcx.intrinsic(def_id).is_some()
{
// Don't give intrinsics the extra penalty for calls
INSTR_COST
} else {
CALL_PENALTY
self.cost += CALL_PENALTY;
};
if let UnwindAction::Cleanup(_) = unwind {
self.cost += LANDINGPAD_PENALTY;
}
}
TerminatorKind::SwitchInt { ref discr, ref targets } => {
if let Operand::Constant(..) = discr {
// This'll be a goto once we're monomorphizing
} else {
// 0/1/unreachable is extremely common (bool, Option, Result, ...)
// but once there's more this can be a fair bit of work.
self.cost += if targets.all_targets().len() <= 3 {
SMALL_SWITCH_PENALTY
} else {
LARGE_SWITCH_PENALTY
};
}
}
TerminatorKind::Assert { unwind, .. } => {
self.cost += CALL_PENALTY;
if let UnwindAction::Cleanup(_) = unwind {
Expand All @@ -89,12 +190,20 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
}
TerminatorKind::UnwindResume => self.cost += RESUME_PENALTY,
TerminatorKind::InlineAsm { unwind, .. } => {
self.cost += INSTR_COST;
if let UnwindAction::Cleanup(_) = unwind {
self.cost += LANDINGPAD_PENALTY;
}
}
_ => self.cost += INSTR_COST,
TerminatorKind::Goto { .. }
| TerminatorKind::UnwindTerminate(..)
| TerminatorKind::Return
| TerminatorKind::Yield { .. }
| TerminatorKind::CoroutineDrop
| TerminatorKind::FalseEdge { .. }
| TerminatorKind::FalseUnwind { .. }
| TerminatorKind::Unreachable => {
// No extra cost for these
}
}
}
}
11 changes: 11 additions & 0 deletions compiler/rustc_mir_transform/src/inline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,17 @@ impl<'tcx> Inliner<'tcx> {
let mut checker =
CostChecker::new(self.tcx, self.param_env, Some(callsite.callee), callee_body);

checker.before_body(callee_body);

let baseline_cost = checker.cost();
if baseline_cost > threshold {
debug!(
"NOT inlining {:?} [baseline_cost={} > threshold={}]",
callsite, baseline_cost, threshold
);
return Err("baseline_cost above threshold");
}

// Traverse the MIR manually so we can account for the effects of inlining on the CFG.
let mut work_list = vec![START_BLOCK];
let mut visited = BitSet::new_empty(callee_body.basic_blocks.len());
Expand Down
Loading