Skip to content

Commit

Permalink
Factor out some repetitive code.
Browse files Browse the repository at this point in the history
  • Loading branch information
nnethercote committed Aug 30, 2024
1 parent 408481f commit 590a021
Showing 1 changed file with 38 additions and 81 deletions.
119 changes: 38 additions & 81 deletions compiler/rustc_mir_transform/src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,9 @@ use rustc_index::bit_set::{BitMatrix, BitSet, GrowableBitSet};
use rustc_index::{Idx, IndexVec};
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
use rustc_middle::mir::*;
use rustc_middle::ty::{self, CoroutineArgs, CoroutineArgsExt, InstanceKind, Ty, TyCtxt};
use rustc_middle::ty::{
self, CoroutineArgs, CoroutineArgsExt, GenericArgsRef, InstanceKind, Ty, TyCtxt,
};
use rustc_middle::{bug, span_bug};
use rustc_mir_dataflow::impls::{
MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive,
Expand Down Expand Up @@ -210,14 +212,10 @@ impl<'tcx> TransformVisitor<'tcx> {
// `gen` continues return `None`
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
option_def_id,
VariantIdx::ZERO,
self.tcx.mk_args(&[self.old_yield_ty.into()]),
None,
None,
)),
make_aggregate_adt(
option_def_id,
VariantIdx::ZERO,
self.tcx.mk_args(&[self.old_yield_ty.into()]),
IndexVec::new(),
)
}
Expand Down Expand Up @@ -266,64 +264,28 @@ impl<'tcx> TransformVisitor<'tcx> {
is_return: bool,
statements: &mut Vec<Statement<'tcx>>,
) {
const ZERO: VariantIdx = VariantIdx::ZERO;
const ONE: VariantIdx = VariantIdx::from_usize(1);
let rvalue = match self.coroutine_kind {
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, None);
let args = self.tcx.mk_args(&[self.old_ret_ty.into()]);
if is_return {
// Poll::Ready(val)
Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
poll_def_id,
VariantIdx::ZERO,
args,
None,
None,
)),
IndexVec::from_raw(vec![val]),
)
let (variant_idx, operands) = if is_return {
(ZERO, IndexVec::from_raw(vec![val])) // Poll::Ready(val)
} else {
// Poll::Pending
Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
poll_def_id,
VariantIdx::from_usize(1),
args,
None,
None,
)),
IndexVec::new(),
)
}
(ONE, IndexVec::new()) // Poll::Pending
};
make_aggregate_adt(poll_def_id, variant_idx, args, operands)
}
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
let args = self.tcx.mk_args(&[self.old_yield_ty.into()]);
if is_return {
// None
Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
option_def_id,
VariantIdx::ZERO,
args,
None,
None,
)),
IndexVec::new(),
)
let (variant_idx, operands) = if is_return {
(ZERO, IndexVec::new()) // None
} else {
// Some(val)
Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
option_def_id,
VariantIdx::from_usize(1),
args,
None,
None,
)),
IndexVec::from_raw(vec![val]),
)
}
(ONE, IndexVec::from_raw(vec![val])) // Some(val)
};
make_aggregate_adt(option_def_id, variant_idx, args, operands)
}
CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
if is_return {
Expand All @@ -349,31 +311,17 @@ impl<'tcx> TransformVisitor<'tcx> {
let coroutine_state_def_id =
self.tcx.require_lang_item(LangItem::CoroutineState, None);
let args = self.tcx.mk_args(&[self.old_yield_ty.into(), self.old_ret_ty.into()]);
if is_return {
// CoroutineState::Complete(val)
Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
coroutine_state_def_id,
VariantIdx::from_usize(1),
args,
None,
None,
)),
IndexVec::from_raw(vec![val]),
)
let variant_idx = if is_return {
ONE // CoroutineState::Complete(val)
} else {
// CoroutineState::Yielded(val)
Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
coroutine_state_def_id,
VariantIdx::ZERO,
args,
None,
None,
)),
IndexVec::from_raw(vec![val]),
)
}
ZERO // CoroutineState::Yielded(val)
};
make_aggregate_adt(
coroutine_state_def_id,
variant_idx,
args,
IndexVec::from_raw(vec![val]),
)
}
};

Expand Down Expand Up @@ -509,6 +457,15 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
}
}

fn make_aggregate_adt<'tcx>(
def_id: DefId,
variant_idx: VariantIdx,
args: GenericArgsRef<'tcx>,
operands: IndexVec<FieldIdx, Operand<'tcx>>,
) -> Rvalue<'tcx> {
Rvalue::Aggregate(Box::new(AggregateKind::Adt(def_id, variant_idx, args, None, None)), operands)
}

fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let coroutine_ty = body.local_decls.raw[1].ty;

Expand Down

0 comments on commit 590a021

Please sign in to comment.