Skip to content

Use multiple returns in MIR if it saves a block; still have only one in LLVM #138144

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions compiler/rustc_codegen_ssa/src/mir/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
}
}

fn codegen_return_terminator(&mut self, bx: &mut Bx) {
pub(super) fn codegen_return_terminator(&mut self, bx: &mut Bx) {
// Call `va_end` if this is the definition of a C-variadic function.
if self.fn_abi.c_variadic {
// The `VaList` "spoofed" argument is just after all the real arguments.
Expand Down Expand Up @@ -1343,7 +1343,15 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
}

mir::TerminatorKind::Return => {
self.codegen_return_terminator(bx);
match self.return_block {
CachedLlbb::Skip => self.codegen_return_terminator(bx),
CachedLlbb::Some(target) => bx.br(target),
CachedLlbb::None => {
let return_llbb = bx.append_sibling_block("return");
self.return_block = CachedLlbb::Some(return_llbb);
bx.br(return_llbb);
}
}
MergingSucc::False
}

Expand Down
24 changes: 24 additions & 0 deletions compiler/rustc_codegen_ssa/src/mir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ pub struct FunctionCx<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> {
/// Cached terminate upon unwinding block and its reason
terminate_block: Option<(Bx::BasicBlock, UnwindTerminateReason)>,

/// Shared return block, because LLVM would prefer only one `ret`.
///
/// If this is `Skip`, there's only one return in the function (or none at all)
/// so there's no shared return, just the one in the normal BB.
return_block: CachedLlbb<Bx::BasicBlock>,

/// A bool flag for each basic block indicating whether it is a cold block.
/// A cold block is a block that is unlikely to be executed at runtime.
cold_blocks: IndexVec<mir::BasicBlock, bool>,
Expand Down Expand Up @@ -204,6 +210,18 @@ pub fn codegen_mir<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
})
.collect();

let return_block = if mir
.basic_blocks
.iter()
.filter(|bbd| matches!(bbd.terminator().kind, mir::TerminatorKind::Return))
.count()
> 1
{
CachedLlbb::None
} else {
CachedLlbb::Skip
};

let mut fx = FunctionCx {
instance,
mir,
Expand All @@ -214,6 +232,7 @@ pub fn codegen_mir<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
cached_llbbs,
unreachable_block: None,
terminate_block: None,
return_block,
cleanup_kinds,
landing_pads: IndexVec::from_elem(None, &mir.basic_blocks),
funclets: IndexVec::from_fn_n(|_| None, mir.basic_blocks.len()),
Expand Down Expand Up @@ -308,6 +327,11 @@ pub fn codegen_mir<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>>(
for bb in unreached_blocks.iter() {
fx.codegen_block_as_unreachable(bb);
}

if let CachedLlbb::Some(llbb) = fx.return_block {
let bx = &mut Bx::build(fx.cx, llbb);
fx.codegen_return_terminator(bx);
}
}

/// Produces, for each argument, a `Value` pointing at the
Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,6 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&sroa::ScalarReplacementOfAggregates,
&match_branches::MatchBranchSimplification,
// inst combine is after MatchBranchSimplification to clean up Ne(_1, false)
&multiple_return_terminators::MultipleReturnTerminators,
// After simplifycfg, it allows us to discover new opportunities for peephole
// optimizations.
&instsimplify::InstSimplify::AfterSimplifyCfg,
Expand All @@ -711,14 +710,15 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&dest_prop::DestinationPropagation,
&o1(simplify_branches::SimplifyConstCondition::Final),
&o1(remove_noop_landing_pads::RemoveNoopLandingPads),
// Can make blocks unused, so before the last simplify-cfg
&multiple_return_terminators::MultipleReturnTerminators,
&o1(simplify::SimplifyCfg::Final),
// After the last SimplifyCfg, because this wants one-block functions.
&strip_debuginfo::StripDebugInfo,
&copy_prop::CopyProp,
&dead_store_elimination::DeadStoreElimination::Final,
&nrvo::RenameReturnPlace,
&simplify::SimplifyLocals::Final,
&multiple_return_terminators::MultipleReturnTerminators,
&large_enums::EnumSizeOpt { discrepancy: 128 },
// Some cleanup necessary at least for LLVM and potentially other codegen backends.
&add_call_guards::CriticalCallEdges,
Expand Down
47 changes: 29 additions & 18 deletions compiler/rustc_mir_transform/src/multiple_return_terminators.rs
Original file line number Diff line number Diff line change
@@ -1,40 +1,51 @@
//! This pass removes jumps to basic blocks containing only a return, and replaces them with a
//! return instead.

use rustc_index::bit_set::DenseBitSet;
use rustc_middle::mir::*;
use rustc_middle::ty::TyCtxt;

use crate::simplify;
use smallvec::SmallVec;

pub(super) struct MultipleReturnTerminators;

impl<'tcx> crate::MirPass<'tcx> for MultipleReturnTerminators {
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
sess.mir_opt_level() >= 4
sess.mir_opt_level() >= 2
}

fn run_pass(&self, _: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
// find basic blocks with no statement and a return terminator
let mut bbs_simple_returns = DenseBitSet::new_empty(body.basic_blocks.len());
let bbs = body.basic_blocks_mut();
for idx in bbs.indices() {
if bbs[idx].statements.is_empty()
&& bbs[idx].terminator().kind == TerminatorKind::Return
let mut to_handle = <Vec<(BasicBlock, SmallVec<_>)>>::new();
for (bb, bbdata) in body.basic_blocks.iter_enumerated() {
// Look for returns where, if we lift them into the parents, we can save a block.
if let TerminatorKind::Return = bbdata.terminator().kind
&& bbdata
.statements
.iter()
.all(|stmt| matches!(stmt.kind, StatementKind::StorageDead(_)))
&& let predecessors = &body.basic_blocks.predecessors()[bb]
&& predecessors.len() >= 2
&& predecessors.iter().all(|pred| {
matches!(
body.basic_blocks[*pred].terminator().kind,
TerminatorKind::Goto { .. },
)
})
{
bbs_simple_returns.insert(idx);
to_handle.push((bb, predecessors.clone()));
}
}

for bb in bbs {
if let TerminatorKind::Goto { target } = bb.terminator().kind {
if bbs_simple_returns.contains(target) {
bb.terminator_mut().kind = TerminatorKind::Return;
}
}
if to_handle.is_empty() {
return;
}

simplify::remove_dead_blocks(body)
let bbs = body.basic_blocks_mut();
for (succ, predecessors) in to_handle {
for pred in predecessors {
let (pred_block, succ_block) = bbs.pick2_mut(pred, succ);
pred_block.statements.extend(succ_block.statements.iter().cloned());
*pred_block.terminator_mut() = succ_block.terminator().clone();
}
}
}

fn is_required(&self) -> bool {
Expand Down
6 changes: 3 additions & 3 deletions tests/codegen/asm/goto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub unsafe fn asm_goto() {
pub unsafe fn asm_goto_with_outputs() -> u64 {
let out: u64;
// CHECK: [[RES:%[0-9]+]] = callbr i64 asm sideeffect alignstack inteldialect "
// CHECK-NEXT: to label %[[FALLTHROUGHBB:[a-b0-9]+]] [label %[[JUMPBB:[a-b0-9]+]]]
// CHECK-NEXT: to label %[[FALLTHROUGHBB:[a-b0-9]+]] [label %[[JUMPBB:return]]]
asm!("{} /* {} */", out(reg) out, label { return 1; });
// CHECK: [[JUMPBB]]:
// CHECK-NEXT: [[RET:%.+]] = phi i64 [ [[RES]], %[[FALLTHROUGHBB]] ], [ 1, %start ]
Expand All @@ -32,7 +32,7 @@ pub unsafe fn asm_goto_with_outputs() -> u64 {
pub unsafe fn asm_goto_with_outputs_use_in_label() -> u64 {
let out: u64;
// CHECK: [[RES:%[0-9]+]] = callbr i64 asm sideeffect alignstack inteldialect "
// CHECK-NEXT: to label %[[FALLTHROUGHBB:[a-b0-9]+]] [label %[[JUMPBB:[a-b0-9]+]]]
// CHECK-NEXT: to label %[[FALLTHROUGHBB:[a-b0-9]+]] [label %[[JUMPBB:return]]]
asm!("{} /* {} */", out(reg) out, label { return out; });
// CHECK: [[JUMPBB]]:
// CHECK-NEXT: [[RET:%.+]] = phi i64 [ 1, %[[FALLTHROUGHBB]] ], [ [[RES]], %start ]
Expand All @@ -55,7 +55,7 @@ pub unsafe fn asm_goto_noreturn() -> u64 {
pub unsafe fn asm_goto_noreturn_with_outputs() -> u64 {
let out: u64;
// CHECK: [[RES:%[0-9]+]] = callbr i64 asm sideeffect alignstack inteldialect "
// CHECK-NEXT: to label %[[FALLTHROUGHBB:[a-b0-9]+]] [label %[[JUMPBB:[a-b0-9]+]]]
// CHECK-NEXT: to label %[[FALLTHROUGHBB:return]] [label %[[JUMPBB:return]]]
asm!("mov {}, 1", "jmp {}", out(reg) out, label { return out; });
// CHECK: [[JUMPBB]]:
// CHECK-NEXT: ret i64 [[RES]]
Expand Down
2 changes: 1 addition & 1 deletion tests/codegen/issues/issue-112509-slice-get-andthen-get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

// CHECK-LABEL: @write_u8_variant_a
// CHECK-NEXT: {{.*}}:
// CHECK-NEXT: icmp ugt
// CHECK-NEXT: getelementptr
// CHECK-NEXT: icmp ugt
// CHECK-NEXT: select i1 {{.+}} null
// CHECK-NEXT: insertvalue
// CHECK-NEXT: insertvalue
Expand Down
34 changes: 34 additions & 0 deletions tests/codegen/multiple_returns.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
//@ compile-flags: -Copt-level=3 -C no-prepopulate-passes

#![crate_type = "lib"]

// CHECK-LABEL: @simple_is_one_block
#[no_mangle]
pub unsafe fn simple_is_one_block(x: i32) -> i32 {
// CHECK: start:
// CHECK-NEXT: ret i32 %x

// CHECK-NOT: return

x
}

// CHECK-LABEL: @branch_has_shared_block
#[no_mangle]
pub unsafe fn branch_has_shared_block(b: bool) -> i32 {
// CHECK: start:
// CHECK-NEXT: %[[A:.+]] = alloca [4 x i8]
// CHECK-NEXT: br i1 %b

// CHECK: store i32 {{42|2015}}, ptr %[[A]]
// CHECK-NEXT: br label %return

// CHECK: store i32 {{42|2015}}, ptr %[[A]]
// CHECK-NEXT: br label %return

// CHECK: return:
// CHECK-NEXT: %[[R:.+]] = load i32, ptr %[[A]]
// CHECK-NEXT: ret i32 %[[R]]

if b { 42 } else { 2015 }
}
25 changes: 15 additions & 10 deletions tests/mir-opt/inline/issue_106141.outer.Inline.panic-abort.diff
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,35 @@
+ StorageLive(_1);
+ StorageLive(_2);
+ _1 = const inner::promoted[0];
+ _0 = index() -> [return: bb1, unwind unreachable];
+ _0 = index() -> [return: bb2, unwind unreachable];
}

bb1: {
+ StorageDead(_2);
+ StorageDead(_1);
return;
+ }
+
+ bb2: {
+ StorageLive(_3);
+ _2 = Lt(copy _0, const 1_usize);
+ assert(move _2, "index out of bounds: the length is {} but the index is {}", const 1_usize, copy _0) -> [success: bb2, unwind unreachable];
+ assert(move _2, "index out of bounds: the length is {} but the index is {}", const 1_usize, copy _0) -> [success: bb3, unwind unreachable];
+ }
+
+ bb2: {
+ bb3: {
+ _3 = copy (*_1)[_0];
+ switchInt(move _3) -> [0: bb3, otherwise: bb4];
+ switchInt(move _3) -> [0: bb4, otherwise: bb5];
+ }
+
+ bb3: {
+ bb4: {
+ _0 = const 0_usize;
+ goto -> bb4;
+ StorageDead(_3);
+ goto -> bb1;
+ }
+
+ bb4: {
+ bb5: {
+ StorageDead(_3);
+ StorageDead(_2);
+ StorageDead(_1);
return;
+ goto -> bb1;
}
}

25 changes: 15 additions & 10 deletions tests/mir-opt/inline/issue_106141.outer.Inline.panic-unwind.diff
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,35 @@
+ StorageLive(_1);
+ StorageLive(_2);
+ _1 = const inner::promoted[0];
+ _0 = index() -> [return: bb1, unwind continue];
+ _0 = index() -> [return: bb2, unwind continue];
}

bb1: {
+ StorageDead(_2);
+ StorageDead(_1);
return;
+ }
+
+ bb2: {
+ StorageLive(_3);
+ _2 = Lt(copy _0, const 1_usize);
+ assert(move _2, "index out of bounds: the length is {} but the index is {}", const 1_usize, copy _0) -> [success: bb2, unwind continue];
+ assert(move _2, "index out of bounds: the length is {} but the index is {}", const 1_usize, copy _0) -> [success: bb3, unwind continue];
+ }
+
+ bb2: {
+ bb3: {
+ _3 = copy (*_1)[_0];
+ switchInt(move _3) -> [0: bb3, otherwise: bb4];
+ switchInt(move _3) -> [0: bb4, otherwise: bb5];
+ }
+
+ bb3: {
+ bb4: {
+ _0 = const 0_usize;
+ goto -> bb4;
+ StorageDead(_3);
+ goto -> bb1;
+ }
+
+ bb4: {
+ bb5: {
+ StorageDead(_3);
+ StorageDead(_2);
+ StorageDead(_1);
return;
+ goto -> bb1;
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ fn num_to_digit(_1: char) -> u32 {
_0 = move ((_4 as Some).0: u32);
StorageDead(_5);
StorageDead(_4);
goto -> bb8;
return;
}

bb6: {
Expand All @@ -59,10 +59,6 @@ fn num_to_digit(_1: char) -> u32 {
bb7: {
StorageDead(_3);
_0 = const 0_u32;
goto -> bb8;
}

bb8: {
return;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ fn num_to_digit(_1: char) -> u32 {
_0 = move ((_4 as Some).0: u32);
StorageDead(_5);
StorageDead(_4);
goto -> bb8;
return;
}

bb6: {
Expand All @@ -59,10 +59,6 @@ fn num_to_digit(_1: char) -> u32 {
bb7: {
StorageDead(_3);
_0 = const 0_u32;
goto -> bb8;
}

bb8: {
return;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ fn num_to_digit(_1: char) -> u32 {
_0 = move ((_4 as Some).0: u32);
StorageDead(_5);
StorageDead(_4);
goto -> bb8;
return;
}

bb6: {
Expand All @@ -59,10 +59,6 @@ fn num_to_digit(_1: char) -> u32 {
bb7: {
StorageDead(_3);
_0 = const 0_u32;
goto -> bb8;
}

bb8: {
return;
}
}
Loading
Loading