Skip to content

Do not count inline(always) in inlining depth control #108788

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 86 additions & 61 deletions compiler/rustc_mir_transform/src/inline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ fn inline<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> bool {
changed: false,
};
let blocks = START_BLOCK..body.basic_blocks.next_index();
this.process_blocks(body, blocks);
this.process_blocks(body, blocks, 0);
this.changed
}

Expand All @@ -115,10 +115,15 @@ struct Inliner<'tcx> {
}

impl<'tcx> Inliner<'tcx> {
fn process_blocks(&mut self, caller_body: &mut Body<'tcx>, blocks: Range<BasicBlock>) {
fn process_blocks(
&mut self,
caller_body: &mut Body<'tcx>,
blocks: Range<BasicBlock>,
depth: usize,
) {
// How many callsites in this body are we allowed to inline? We need to limit this in order
// to prevent super-linear growth in MIR size
let inline_limit = match self.history.len() {
let inline_limit = match depth {
0 => usize::MAX,
1..=TOP_DOWN_DEPTH_LIMIT => 1,
_ => return,
Expand All @@ -137,26 +142,35 @@ impl<'tcx> Inliner<'tcx> {
let span = trace_span!("process_blocks", %callsite.callee, ?bb);
let _guard = span.enter();

match self.try_inlining(caller_body, &callsite) {
let callee_attrs = self.tcx.codegen_fn_attrs(callsite.callee.def_id());
let inline_always = matches!(callee_attrs.inline, InlineAttr::Always);

if inlined_count >= inline_limit && !inline_always {
debug!("inline count reached");
continue;
}

let new_blocks = match self.try_inlining(caller_body, &callsite, callee_attrs) {
Ok(new_blocks) => new_blocks,
Err(reason) => {
debug!("not-inlined {} [{}]", callsite.callee, reason);
continue;
}
Ok(new_blocks) => {
debug!("inlined {}", callsite.callee);
self.changed = true;
};

self.history.push(callsite.callee.def_id());
self.process_blocks(caller_body, new_blocks);
self.history.pop();
debug!("inlined {}", callsite.callee);
self.changed = true;

inlined_count += 1;
if inlined_count == inline_limit {
debug!("inline count reached");
return;
}
}
}
let new_depth = if inline_always {
// This call was `inline(always)`. Do not count it in the inline cost.
depth
} else {
inlined_count += 1;
depth + 1
};
self.history.push(callsite.callee.def_id());
self.process_blocks(caller_body, new_blocks, new_depth);
self.history.pop();
}
}

Expand All @@ -167,13 +181,12 @@ impl<'tcx> Inliner<'tcx> {
&self,
caller_body: &mut Body<'tcx>,
callsite: &CallSite<'tcx>,
callee_attrs: &CodegenFnAttrs,
) -> Result<std::ops::Range<BasicBlock>, &'static str> {
let callee_attrs = self.tcx.codegen_fn_attrs(callsite.callee.def_id());
self.check_codegen_attributes(callsite, callee_attrs)?;

let terminator = caller_body[callsite.block].terminator.as_ref().unwrap();
let TerminatorKind::Call { args, destination, .. } = &terminator.kind else { bug!() };
let destination_ty = destination.ty(&caller_body.local_decls, self.tcx).ty;
let TerminatorKind::Call { args, .. } = &terminator.kind else { bug!() };
for arg in args {
if !arg.ty(&caller_body.local_decls, self.tcx).is_sized(self.tcx, self.param_env) {
// We do not allow inlining functions with unsized params. Inlining these functions
Expand All @@ -184,7 +197,7 @@ impl<'tcx> Inliner<'tcx> {

self.check_mir_is_available(caller_body, &callsite.callee)?;
let callee_body = try_instance_mir(self.tcx, callsite.callee.def)?;
self.check_mir_body(callsite, callee_body, callee_attrs)?;
self.check_mir_body(caller_body, callsite, callee_body, callee_attrs)?;

if !self.tcx.consider_optimizing(|| {
format!("Inline {:?} into {:?}", callsite.callee, caller_body.source)
Expand All @@ -200,46 +213,6 @@ impl<'tcx> Inliner<'tcx> {
return Err("failed to normalize callee body");
};

// Check call signature compatibility.
// Normally, this shouldn't be required, but trait normalization failure can create a
// validation ICE.
let output_type = callee_body.return_ty();
if !util::is_subtype(self.tcx, self.param_env, output_type, destination_ty) {
trace!(?output_type, ?destination_ty);
return Err("failed to normalize return type");
}
if callsite.fn_sig.abi() == Abi::RustCall {
let (arg_tuple, skipped_args) = match &args[..] {
[arg_tuple] => (arg_tuple, 0),
[_, arg_tuple] => (arg_tuple, 1),
_ => bug!("Expected `rust-call` to have 1 or 2 args"),
};

let arg_tuple_ty = arg_tuple.ty(&caller_body.local_decls, self.tcx);
let ty::Tuple(arg_tuple_tys) = arg_tuple_ty.kind() else {
bug!("Closure arguments are not passed as a tuple");
};

for (arg_ty, input) in
arg_tuple_tys.iter().zip(callee_body.args_iter().skip(skipped_args))
{
let input_type = callee_body.local_decls[input].ty;
if !util::is_subtype(self.tcx, self.param_env, input_type, arg_ty) {
trace!(?arg_ty, ?input_type);
return Err("failed to normalize tuple argument type");
}
}
} else {
for (arg, input) in args.iter().zip(callee_body.args_iter()) {
let input_type = callee_body.local_decls[input].ty;
let arg_ty = arg.ty(&caller_body.local_decls, self.tcx);
if !util::is_subtype(self.tcx, self.param_env, input_type, arg_ty) {
trace!(?arg_ty, ?input_type);
return Err("failed to normalize argument type");
}
}
}

let old_blocks = caller_body.basic_blocks.next_index();
self.inline_call(caller_body, &callsite, callee_body);
let new_blocks = old_blocks..caller_body.basic_blocks.next_index();
Expand Down Expand Up @@ -409,6 +382,7 @@ impl<'tcx> Inliner<'tcx> {
#[instrument(level = "debug", skip(self, callee_body))]
fn check_mir_body(
&self,
caller_body: &Body<'tcx>,
callsite: &CallSite<'tcx>,
callee_body: &Body<'tcx>,
callee_attrs: &CodegenFnAttrs,
Expand Down Expand Up @@ -479,6 +453,57 @@ impl<'tcx> Inliner<'tcx> {
// Abort if type validation found anything fishy.
checker.validation?;

let substitute = |ty| {
let ty = ty::EarlyBinder::bind(ty);
callsite
.callee
.try_subst_mir_and_normalize_erasing_regions(self.tcx, self.param_env, ty)
.map_err(|_| "failed to normalize callee body")
};

// Check call signature compatibility.
// Normally, this shouldn't be required, but trait normalization failure can create a
// validation ICE.
let terminator = caller_body[callsite.block].terminator.as_ref().unwrap();
let TerminatorKind::Call { args, destination, .. } = &terminator.kind else { bug!() };
let destination_ty = destination.ty(&caller_body.local_decls, self.tcx).ty;
let output_type = substitute(callee_body.return_ty())?;
if !util::is_subtype(self.tcx, self.param_env, output_type, destination_ty) {
trace!(?output_type, ?destination_ty);
return Err("failed to normalize return type");
}
if callsite.fn_sig.abi() == Abi::RustCall {
let (arg_tuple, skipped_args) = match &args[..] {
[arg_tuple] => (arg_tuple, 0),
[_, arg_tuple] => (arg_tuple, 1),
_ => bug!("Expected `rust-call` to have 1 or 2 args"),
};

let arg_tuple_ty = arg_tuple.ty(&caller_body.local_decls, self.tcx);
let ty::Tuple(arg_tuple_tys) = arg_tuple_ty.kind() else {
bug!("Closure arguments are not passed as a tuple");
};

for (arg_ty, input) in
arg_tuple_tys.iter().zip(callee_body.args_iter().skip(skipped_args))
{
let input_type = substitute(callee_body.local_decls[input].ty)?;
if !util::is_subtype(self.tcx, self.param_env, input_type, arg_ty) {
trace!(?arg_ty, ?input_type);
return Err("failed to normalize tuple argument type");
}
}
} else {
for (arg, input) in args.iter().zip(callee_body.args_iter()) {
let input_type = substitute(callee_body.local_decls[input].ty)?;
let arg_ty = arg.ty(&caller_body.local_decls, self.tcx);
if !util::is_subtype(self.tcx, self.param_env, input_type, arg_ty) {
trace!(?arg_ty, ?input_type);
return Err("failed to normalize argument type");
}
}
}

let cost = checker.cost;
if let InlineAttr::Always = callee_attrs.inline {
debug!("INLINING {:?} because inline(always) [cost={}]", callsite, cost);
Expand Down
Loading