Skip to content

Inline drops in MIR #142583

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 145 additions & 90 deletions compiler/rustc_mir_transform/src/inline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::ops::{Range, RangeFrom};

use rustc_abi::{ExternAbi, FieldIdx};
use rustc_attr_data_structures::{InlineAttr, OptimizeAttr};
use rustc_hir::LangItem;
use rustc_hir::def::DefKind;
use rustc_hir::def_id::DefId;
use rustc_index::Idx;
Expand Down Expand Up @@ -553,46 +554,53 @@ fn resolve_callsite<'tcx, I: Inliner<'tcx>>(
let terminator = bb_data.terminator();

// FIXME(explicit_tail_calls): figure out if we can inline tail calls
if let TerminatorKind::Call { ref func, fn_span, .. } = terminator.kind {
let func_ty = func.ty(caller_body, tcx);
if let ty::FnDef(def_id, args) = *func_ty.kind() {
if !inliner.should_inline_for_callee(def_id) {
debug!("not enabled");
return None;
}

// To resolve an instance its args have to be fully normalized.
let args = tcx.try_normalize_erasing_regions(inliner.typing_env(), args).ok()?;
let callee =
Instance::try_resolve(tcx, inliner.typing_env(), def_id, args).ok().flatten()?;
let (def_id, args, fn_span) = match terminator.kind {
TerminatorKind::Call { ref func, fn_span, .. } => {
let func_ty = func.ty(caller_body, tcx);
let ty::FnDef(def_id, args) = *func_ty.kind() else { return None };
(def_id, args, fn_span)
}
TerminatorKind::Drop { place, .. } => {
let ty = place.ty(caller_body, tcx).ty;
let def_id = tcx.require_lang_item(LangItem::DropInPlace, terminator.source_info.span);
let args = tcx.mk_args(&[ty.into()]);
(def_id, args, terminator.source_info.span)
}
_ => return None,
};

if let InstanceKind::Virtual(..) | InstanceKind::Intrinsic(_) = callee.def {
return None;
}
if !inliner.should_inline_for_callee(def_id) {
debug!("not enabled");
return None;
}

if inliner.history().contains(&callee.def_id()) {
return None;
}
// To resolve an instance its args have to be fully normalized.
let args = tcx.try_normalize_erasing_regions(inliner.typing_env(), args).ok()?;
let callee = Instance::try_resolve(tcx, inliner.typing_env(), def_id, args).ok().flatten()?;

let fn_sig = tcx.fn_sig(def_id).instantiate(tcx, args);
if let InstanceKind::Virtual(..) | InstanceKind::Intrinsic(_) = callee.def {
return None;
}

// Additionally, check that the body that we're inlining actually agrees
// with the ABI of the trait that the item comes from.
if let InstanceKind::Item(instance_def_id) = callee.def
&& tcx.def_kind(instance_def_id) == DefKind::AssocFn
&& let instance_fn_sig = tcx.fn_sig(instance_def_id).skip_binder()
&& instance_fn_sig.abi() != fn_sig.abi()
{
return None;
}
if inliner.history().contains(&callee.def_id()) {
return None;
}

let source_info = SourceInfo { span: fn_span, ..terminator.source_info };
let fn_sig = tcx.fn_sig(def_id).instantiate(tcx, args);

return Some(CallSite { callee, fn_sig, block: bb, source_info });
}
// Additionally, check that the body that we're inlining actually agrees
// with the ABI of the trait that the item comes from.
if let InstanceKind::Item(instance_def_id) = callee.def
&& tcx.def_kind(instance_def_id) == DefKind::AssocFn
&& let instance_fn_sig = tcx.fn_sig(instance_def_id).skip_binder()
&& instance_fn_sig.abi() != fn_sig.abi()
{
return None;
}

None
let source_info = SourceInfo { span: fn_span, ..terminator.source_info };

Some(CallSite { callee, fn_sig, block: bb, source_info })
}

/// Attempts to inline a callsite into the caller body. When successful returns basic blocks
Expand All @@ -603,6 +611,23 @@ fn try_inlining<'tcx, I: Inliner<'tcx>>(
caller_body: &mut Body<'tcx>,
callsite: &CallSite<'tcx>,
) -> Result<std::ops::Range<BasicBlock>, &'static str> {
// Fast path to inline trivial drops.
if let InstanceKind::DropGlue(_, None) = callsite.callee.def {
let terminator = caller_body[callsite.block].terminator_mut();
let target = match terminator.kind {
TerminatorKind::Call { target, .. } => target,
TerminatorKind::Drop { target, .. } => Some(target),
_ => bug!("unexpected terminator kind {:?}", terminator.kind),
};
if let Some(target) = target {
terminator.kind = TerminatorKind::Goto { target };
} else {
terminator.kind = TerminatorKind::Unreachable;
}
let next_block = caller_body.basic_blocks.next_index();
return Ok(next_block..next_block);
}

let tcx = inliner.tcx();
check_mir_is_available(inliner, caller_body, callsite.callee)?;

Expand All @@ -611,17 +636,6 @@ fn try_inlining<'tcx, I: Inliner<'tcx>>(
check_codegen_attributes(inliner, callsite, callee_attrs)?;
inliner.check_codegen_attributes_extra(callee_attrs)?;

let terminator = caller_body[callsite.block].terminator.as_ref().unwrap();
let TerminatorKind::Call { args, destination, .. } = &terminator.kind else { bug!() };
let destination_ty = destination.ty(&caller_body.local_decls, tcx).ty;
for arg in args {
if !arg.node.ty(&caller_body.local_decls, tcx).is_sized(tcx, inliner.typing_env()) {
// We do not allow inlining functions with unsized params. Inlining these functions
// could create unsized locals, which are unsound and being phased out.
return Err("call has unsized argument");
}
}

let callee_body = try_instance_mir(tcx, callsite.callee.def)?;
check_inline::is_inline_valid_on_body(tcx, callee_body)?;
inliner.check_callee_mir_body(callsite, callee_body, callee_attrs)?;
Expand All @@ -642,54 +656,73 @@ fn try_inlining<'tcx, I: Inliner<'tcx>>(
return Err("implementation limitation -- callee body failed validation");
}

// Check call signature compatibility.
// Normally, this shouldn't be required, but trait normalization failure can create a
// validation ICE.
let output_type = callee_body.return_ty();
if !util::sub_types(tcx, inliner.typing_env(), output_type, destination_ty) {
trace!(?output_type, ?destination_ty);
return Err("implementation limitation -- return type mismatch");
}
if callsite.fn_sig.abi() == ExternAbi::RustCall {
let (self_arg, arg_tuple) = match &args[..] {
[arg_tuple] => (None, arg_tuple),
[self_arg, arg_tuple] => (Some(self_arg), arg_tuple),
_ => bug!("Expected `rust-call` to have 1 or 2 args"),
};
let terminator = caller_body[callsite.block].terminator.as_ref().unwrap();
match &terminator.kind {
TerminatorKind::Call { args, destination, .. } => {
let destination_ty = destination.ty(&caller_body.local_decls, tcx).ty;
for arg in args {
if !arg.node.ty(&caller_body.local_decls, tcx).is_sized(tcx, inliner.typing_env()) {
// We do not allow inlining functions with unsized params. Inlining these functions
// could create unsized locals, which are unsound and being phased out.
return Err("call has unsized argument");
}
}

let self_arg_ty = self_arg.map(|self_arg| self_arg.node.ty(&caller_body.local_decls, tcx));
// Check call signature compatibility.
// Normally, this shouldn't be required, but trait normalization failure can create a
// validation ICE.
let output_type = callee_body.return_ty();
if !util::sub_types(tcx, inliner.typing_env(), output_type, destination_ty) {
trace!(?output_type, ?destination_ty);
return Err("implementation limitation -- return type mismatch");
}
if callsite.fn_sig.abi() == ExternAbi::RustCall {
let (self_arg, arg_tuple) = match &args[..] {
[arg_tuple] => (None, arg_tuple),
[self_arg, arg_tuple] => (Some(self_arg), arg_tuple),
_ => bug!("Expected `rust-call` to have 1 or 2 args"),
};

let arg_tuple_ty = arg_tuple.node.ty(&caller_body.local_decls, tcx);
let arg_tys = if callee_body.spread_arg.is_some() {
std::slice::from_ref(&arg_tuple_ty)
} else {
let ty::Tuple(arg_tuple_tys) = *arg_tuple_ty.kind() else {
bug!("Closure arguments are not passed as a tuple");
};
arg_tuple_tys.as_slice()
};
let self_arg_ty =
self_arg.map(|self_arg| self_arg.node.ty(&caller_body.local_decls, tcx));

for (arg_ty, input) in
self_arg_ty.into_iter().chain(arg_tys.iter().copied()).zip(callee_body.args_iter())
{
let input_type = callee_body.local_decls[input].ty;
if !util::sub_types(tcx, inliner.typing_env(), input_type, arg_ty) {
trace!(?arg_ty, ?input_type);
debug!("failed to normalize tuple argument type");
return Err("implementation limitation");
}
}
} else {
for (arg, input) in args.iter().zip(callee_body.args_iter()) {
let input_type = callee_body.local_decls[input].ty;
let arg_ty = arg.node.ty(&caller_body.local_decls, tcx);
if !util::sub_types(tcx, inliner.typing_env(), input_type, arg_ty) {
trace!(?arg_ty, ?input_type);
debug!("failed to normalize argument type");
return Err("implementation limitation -- arg mismatch");
let arg_tuple_ty = arg_tuple.node.ty(&caller_body.local_decls, tcx);
let arg_tys = if callee_body.spread_arg.is_some() {
std::slice::from_ref(&arg_tuple_ty)
} else {
let ty::Tuple(arg_tuple_tys) = *arg_tuple_ty.kind() else {
bug!("Closure arguments are not passed as a tuple");
};
arg_tuple_tys.as_slice()
};

for (arg_ty, input) in self_arg_ty
.into_iter()
.chain(arg_tys.iter().copied())
.zip(callee_body.args_iter())
{
let input_type = callee_body.local_decls[input].ty;
if !util::sub_types(tcx, inliner.typing_env(), input_type, arg_ty) {
trace!(?arg_ty, ?input_type);
debug!("failed to normalize tuple argument type");
return Err("implementation limitation");
}
}
} else {
for (arg, input) in args.iter().zip(callee_body.args_iter()) {
let input_type = callee_body.local_decls[input].ty;
let arg_ty = arg.node.ty(&caller_body.local_decls, tcx);
if !util::sub_types(tcx, inliner.typing_env(), input_type, arg_ty) {
trace!(?arg_ty, ?input_type);
debug!("failed to normalize argument type");
return Err("implementation limitation -- arg mismatch");
}
}
}
}
}
TerminatorKind::Drop { .. } => {}
_ => bug!(),
};

let old_blocks = caller_body.basic_blocks.next_index();
inline_call(inliner, caller_body, callsite, callee_body);
Expand Down Expand Up @@ -854,9 +887,31 @@ fn inline_call<'tcx, I: Inliner<'tcx>>(
) {
let tcx = inliner.tcx();
let terminator = caller_body[callsite.block].terminator.take().unwrap();
let TerminatorKind::Call { func, args, destination, unwind, target, .. } = terminator.kind
else {
bug!("unexpected terminator kind {:?}", terminator.kind);
let (args, destination, unwind, target) = match terminator.kind {
TerminatorKind::Call { args, destination, unwind, target, .. } => {
(args, destination, unwind, target)
}
TerminatorKind::Drop { place, unwind, target, .. } => {
// `drop_in_place` takes a `*mut`, so we need to take the address to pass it.
let place_ty = place.ty(caller_body, tcx).ty;
let place_addr_ty = Ty::new_mut_ptr(tcx, place_ty);
let arg_place: Place<'tcx> =
new_call_temp(caller_body, callsite, place_addr_ty, Some(target)).into();
caller_body[callsite.block].statements.push(Statement {
source_info: callsite.source_info,
kind: StatementKind::Assign(Box::new((
arg_place,
Rvalue::RawPtr(RawPtrKind::Mut, place),
))),
});
let arg = Spanned { span: terminator.source_info.span, node: Operand::Move(arg_place) };

// Create a dummy destination place as calls have one.
let destination: Place<'tcx> =
new_call_temp(caller_body, callsite, tcx.types.unit, Some(target)).into();
(vec![arg].into_boxed_slice(), destination, unwind, Some(target))
}
_ => bug!("unexpected terminator kind {:?}", terminator.kind),
};

let return_block = if let Some(block) = target {
Expand Down Expand Up @@ -1014,7 +1069,7 @@ fn inline_call<'tcx, I: Inliner<'tcx>>(
// the actually used items. By doing this we can entirely avoid visiting the callee!
// We need to reconstruct the `required_item` for the callee so that we can find and
// remove it.
let callee_item = MentionedItem::Fn(func.ty(caller_body, tcx));
let callee_item = MentionedItem::Fn(callsite.callee.ty(tcx, inliner.typing_env()));
let caller_mentioned_items = caller_body.mentioned_items.as_mut().unwrap();
if let Some(idx) = caller_mentioned_items.iter().position(|item| item.node == callee_item) {
// We found the callee, so remove it and add its items instead.
Expand Down
9 changes: 7 additions & 2 deletions tests/mir-opt/c_unwind_terminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@ fn panic() {
// EMIT_MIR c_unwind_terminate.test.AbortUnwindingCalls.after.mir
extern "C" fn test() {
// CHECK-LABEL: fn test(
// CHECK: panic
// CHECK-SAME: unwind: [[panic_unwind:bb.*]]]
// CHECK: drop
// CHECK-SAME: unwind: [[unwind:bb.*]]]
// CHECK: [[unwind]] (cleanup)
// CHECK-SAME: unwind: [[drop_unwind:bb.*]]]
// CHECK: [[panic_unwind]] (cleanup)
// CHECK-NEXT: drop
// CHECK-SAME: terminate(cleanup)
// CHECK: [[drop_unwind]] (cleanup)
// CHECK-NEXT: terminate(abi)
let _val = Noise;
panic();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
let mut _0: ();
let _1: A;
let mut _2: std::boxed::Box<[bool]>;
let mut _8: *mut A;
let mut _9: ();
scope 1 {
debug a => _1;
}
Expand Down Expand Up @@ -37,6 +39,8 @@
}
}
}
scope 14 (inlined drop_in_place::<A> - shim(Some(A))) {
}

bb0: {
StorageLive(_1);
Expand All @@ -60,10 +64,15 @@
_1 = const A {{ foo: Box::<[bool]>(Unique::<[bool]> {{ pointer: NonNull::<[bool]> {{ pointer: Indirect { alloc_id: ALLOC2, offset: Size(0 bytes) }: *const [bool] }}, _marker: PhantomData::<[bool]> }}, std::alloc::Global) }};
StorageDead(_2);
_0 = const ();
drop(_1) -> [return: bb1, unwind unreachable];
StorageLive(_8);
_8 = &raw mut _1;
StorageLive(_9);
drop(((*_8).0: std::boxed::Box<[bool]>)) -> [return: bb1, unwind unreachable];
}

bb1: {
StorageDead(_9);
StorageDead(_8);
StorageDead(_1);
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
let mut _0: ();
let _1: A;
let mut _2: std::boxed::Box<[bool]>;
let mut _8: *mut A;
let mut _9: ();
scope 1 {
debug a => _1;
}
Expand Down Expand Up @@ -37,6 +39,8 @@
}
}
}
scope 14 (inlined drop_in_place::<A> - shim(Some(A))) {
}

bb0: {
StorageLive(_1);
Expand All @@ -60,16 +64,21 @@
_1 = const A {{ foo: Box::<[bool]>(Unique::<[bool]> {{ pointer: NonNull::<[bool]> {{ pointer: Indirect { alloc_id: ALLOC2, offset: Size(0 bytes) }: *const [bool] }}, _marker: PhantomData::<[bool]> }}, std::alloc::Global) }};
StorageDead(_2);
_0 = const ();
drop(_1) -> [return: bb1, unwind: bb2];
StorageLive(_8);
_8 = &raw mut _1;
StorageLive(_9);
drop(((*_8).0: std::boxed::Box<[bool]>)) -> [return: bb2, unwind: bb1];
}

bb1: {
StorageDead(_1);
return;
bb1 (cleanup): {
resume;
}

bb2 (cleanup): {
resume;
bb2: {
StorageDead(_9);
StorageDead(_8);
StorageDead(_1);
return;
}
}

Expand Down
Loading
Loading