Skip to content

Commit

Permalink
Transform async ResumeTy in generator transform
Browse files Browse the repository at this point in the history
- Eliminates all the `get_context` calls that async lowering created.
- Replace all `Local` `ResumeTy` types with `&mut Context<'_>`.

The `Local`s that have their types replaced are:
- The `resume` argument itself.
- The argument to `get_context`.
- The yielded value of a `yield`.

The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
`get_context` function is being used to convert that back to a `&mut Context<'_>`.

Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
but rather directly use `&mut Context<'_>`, however that would currently
lead to higher-kinded lifetime errors.
See <#105501>.

The async lowering step and the type / lifetime inference / checking are
still using the `ResumeTy` indirection for the time being, and that indirection
is removed here. After this transform, the generator body only knows about `&mut Context<'_>`.
  • Loading branch information
Swatinem committed Jan 17, 2023
1 parent 85357e3 commit 1db48e2
Show file tree
Hide file tree
Showing 10 changed files with 644 additions and 14 deletions.
1 change: 1 addition & 0 deletions compiler/rustc_hir/src/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ language_item_table! {
IdentityFuture, sym::identity_future, identity_future_fn, Target::Fn, GenericRequirement::None;
GetContext, sym::get_context, get_context_fn, Target::Fn, GenericRequirement::None;

Context, sym::Context, context, Target::Struct, GenericRequirement::None;
FuturePoll, sym::poll, future_poll_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;

FromFrom, sym::from, from_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;
Expand Down
9 changes: 9 additions & 0 deletions compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1952,6 +1952,15 @@ impl<'tcx> TyCtxt<'tcx> {
self.mk_ty(GeneratorWitness(types))
}

/// Creates a `&mut Context<'_>` [`Ty`] with erased lifetimes.
pub fn mk_task_context(self) -> Ty<'tcx> {
let context_did = self.require_lang_item(LangItem::Context, None);
let context_adt_ref = self.adt_def(context_did);
let context_substs = self.intern_substs(&[self.lifetimes.re_erased.into()]);
let context_ty = self.mk_adt(context_adt_ref, context_substs);
self.mk_mut_ref(self.lifetimes.re_erased, context_ty)
}

#[inline]
pub fn mk_ty_var(self, v: TyVid) -> Ty<'tcx> {
self.mk_ty_infer(TyVar(v))
Expand Down
118 changes: 111 additions & 7 deletions compiler/rustc_mir_transform/src/generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,104 @@ fn replace_local<'tcx>(
new_local
}

/// Transforms the `body` of the generator applying the following transforms:
///
/// - Eliminates all the `get_context` calls that async lowering created.
/// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
///
/// The `Local`s that have their types replaced are:
/// - The `resume` argument itself.
/// - The argument to `get_context`.
/// - The yielded value of a `yield`.
///
/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
/// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
///
/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
/// but rather directly use `&mut Context<'_>`, however that would currently
/// lead to higher-kinded lifetime errors.
/// See <https://github.com/rust-lang/rust/issues/105501>.
///
/// The async lowering step and the type / lifetime inference / checking are
/// still using the `ResumeTy` indirection for the time being, and that indirection
/// is removed here. After this transform, the generator body only knows about `&mut Context<'_>`.
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let context_mut_ref = tcx.mk_task_context();

// replace the type of the `resume` argument
replace_resume_ty_local(tcx, body, Local::new(2), context_mut_ref);

let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, None);

for bb in BasicBlock::new(0)..body.basic_blocks.next_index() {
let bb_data = &body[bb];
if bb_data.is_cleanup {
continue;
}

match &bb_data.terminator().kind {
TerminatorKind::Call { func, .. } => {
let func_ty = func.ty(body, tcx);
if let ty::FnDef(def_id, _) = *func_ty.kind() {
if def_id == get_context_def_id {
let local = eliminate_get_context_call(&mut body[bb]);
replace_resume_ty_local(tcx, body, local, context_mut_ref);
}
} else {
continue;
}
}
TerminatorKind::Yield { resume_arg, .. } => {
replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref);
}
_ => {}
}
}
}

fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local {
let terminator = bb_data.terminator.take().unwrap();
if let TerminatorKind::Call { mut args, destination, target, .. } = terminator.kind {
let arg = args.pop().unwrap();
let local = arg.place().unwrap().local;

let arg = Rvalue::Use(arg);
let assign = Statement {
source_info: terminator.source_info,
kind: StatementKind::Assign(Box::new((destination, arg))),
};
bb_data.statements.push(assign);
bb_data.terminator = Some(Terminator {
source_info: terminator.source_info,
kind: TerminatorKind::Goto { target: target.unwrap() },
});
local
} else {
bug!();
}
}

#[cfg_attr(not(debug_assertions), allow(unused))]
fn replace_resume_ty_local<'tcx>(
tcx: TyCtxt<'tcx>,
body: &mut Body<'tcx>,
local: Local,
context_mut_ref: Ty<'tcx>,
) {
let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref);
// We have to replace the `ResumeTy` that is used for type and borrow checking
// with `&mut Context<'_>` in MIR.
#[cfg(debug_assertions)]
{
if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None));
assert_eq!(*resume_ty_adt, expected_adt);
} else {
panic!("expected `ResumeTy`, found `{:?}`", local_ty);
};
}
}

struct LivenessInfo {
/// Which locals are live across any suspension point.
saved_locals: GeneratorSavedLocals,
Expand Down Expand Up @@ -1283,13 +1381,13 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
}
};

let is_async_kind = body.generator_kind().unwrap() != GeneratorKind::Gen;
let is_async_kind = matches!(body.generator_kind(), Some(GeneratorKind::Async(_)));
let (state_adt_ref, state_substs) = if is_async_kind {
// Compute Poll<return_ty>
let state_did = tcx.require_lang_item(LangItem::Poll, None);
let state_adt_ref = tcx.adt_def(state_did);
let state_substs = tcx.intern_substs(&[body.return_ty().into()]);
(state_adt_ref, state_substs)
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
let poll_adt_ref = tcx.adt_def(poll_did);
let poll_substs = tcx.intern_substs(&[body.return_ty().into()]);
(poll_adt_ref, poll_substs)
} else {
// Compute GeneratorState<yield_ty, return_ty>
let state_did = tcx.require_lang_item(LangItem::GeneratorState, None);
Expand All @@ -1303,13 +1401,19 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
// RETURN_PLACE then is a fresh unused local with type ret_ty.
let new_ret_local = replace_local(RETURN_PLACE, ret_ty, body, tcx);

// Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
if is_async_kind {
transform_async_context(tcx, body);
}

// We also replace the resume argument and insert an `Assign`.
// This is needed because the resume argument `_2` might be live across a `yield`, in which
// case there is no `Assign` to it that the transform can turn into a store to the generator
// state. After the yield the slot in the generator state would then be uninitialized.
let resume_local = Local::new(2);
let new_resume_local =
replace_local(resume_local, body.local_decls[resume_local].ty, body, tcx);
let resume_ty =
if is_async_kind { tcx.mk_task_context() } else { body.local_decls[resume_local].ty };
let new_resume_local = replace_local(resume_local, resume_ty, body, tcx);

// When first entering the generator, move the resume argument into its new local.
let source_info = SourceInfo::outermost(body.span);
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_span/src/symbol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ symbols! {
Capture,
Center,
Clone,
Context,
Continue,
Copy,
Count,
Expand Down
34 changes: 27 additions & 7 deletions compiler/rustc_ty_utils/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,21 +108,41 @@ fn fn_sig_for_fn_abi<'tcx>(
// `Generator::resume(...) -> GeneratorState` function in case we
// have an ordinary generator, or the `Future::poll(...) -> Poll`
// function in case this is a special generator backing an async construct.
let ret_ty = if tcx.generator_is_async(did) {
let state_did = tcx.require_lang_item(LangItem::Poll, None);
let state_adt_ref = tcx.adt_def(state_did);
let state_substs = tcx.intern_substs(&[sig.return_ty.into()]);
tcx.mk_adt(state_adt_ref, state_substs)
let (resume_ty, ret_ty) = if tcx.generator_is_async(did) {
// The signature should be `Future::poll(_, &mut Context<'_>) -> Poll<Output>`
let poll_did = tcx.require_lang_item(LangItem::Poll, None);
let poll_adt_ref = tcx.adt_def(poll_did);
let poll_substs = tcx.intern_substs(&[sig.return_ty.into()]);
let ret_ty = tcx.mk_adt(poll_adt_ref, poll_substs);

// We have to replace the `ResumeTy` that is used for type and borrow checking
// with `&mut Context<'_>` which is used in codegen.
#[cfg(debug_assertions)]
{
if let ty::Adt(resume_ty_adt, _) = sig.resume_ty.kind() {
let expected_adt =
tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None));
assert_eq!(*resume_ty_adt, expected_adt);
} else {
panic!("expected `ResumeTy`, found `{:?}`", sig.resume_ty);
};
}
let context_mut_ref = tcx.mk_task_context();

(context_mut_ref, ret_ty)
} else {
// The signature should be `Generator::resume(_, Resume) -> GeneratorState<Yield, Return>`
let state_did = tcx.require_lang_item(LangItem::GeneratorState, None);
let state_adt_ref = tcx.adt_def(state_did);
let state_substs = tcx.intern_substs(&[sig.yield_ty.into(), sig.return_ty.into()]);
tcx.mk_adt(state_adt_ref, state_substs)
let ret_ty = tcx.mk_adt(state_adt_ref, state_substs);

(sig.resume_ty, ret_ty)
};

ty::Binder::bind_with_vars(
tcx.mk_fn_sig(
[env_ty, sig.resume_ty].iter(),
[env_ty, resume_ty].iter(),
&ret_ty,
false,
hir::Unsafety::Normal,
Expand Down
4 changes: 4 additions & 0 deletions library/core/src/future/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ pub unsafe fn get_context<'a, 'b>(cx: ResumeTy) -> &'a mut Context<'b> {
unsafe { &mut *cx.0.as_ptr().cast() }
}

// FIXME(swatinem): This fn is currently needed to work around shortcomings
// in type and lifetime inference.
// See the comment at the bottom of `LoweringContext::make_async_expr` and
// <https://github.com/rust-lang/rust/issues/104826>.
#[doc(hidden)]
#[unstable(feature = "gen_future", issue = "50547")]
#[inline]
Expand Down
1 change: 1 addition & 0 deletions library/core/src/task/wake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ impl RawWakerVTable {
/// Currently, `Context` only serves to provide access to a [`&Waker`](Waker)
/// which can be used to wake the current task.
#[stable(feature = "futures_api", since = "1.36.0")]
#[cfg_attr(not(bootstrap), lang = "Context")]
pub struct Context<'a> {
waker: &'a Waker,
// Ensure we future-proof against variance changes by forcing
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// MIR for `a::{closure#0}` 0 generator_resume
/* generator_layout = GeneratorLayout {
field_tys: {},
variant_fields: {
Unresumed(0): [],
Returned (1): [],
Panicked (2): [],
},
storage_conflicts: BitMatrix(0x0) {},
} */

fn a::{closure#0}(_1: Pin<&mut [async fn body@$DIR/async_await.rs:10:14: 10:16]>, _2: &mut Context<'_>) -> Poll<()> {
debug _task_context => _4; // in scope 0 at $DIR/async_await.rs:+0:14: +0:16
let mut _0: std::task::Poll<()>; // return place in scope 0 at $DIR/async_await.rs:+0:14: +0:16
let mut _3: (); // in scope 0 at $DIR/async_await.rs:+0:14: +0:16
let mut _4: &mut std::task::Context<'_>; // in scope 0 at $DIR/async_await.rs:+0:14: +0:16
let mut _5: u32; // in scope 0 at $DIR/async_await.rs:+0:14: +0:16

bb0: {
_5 = discriminant((*(_1.0: &mut [async fn body@$DIR/async_await.rs:10:14: 10:16]))); // scope 0 at $DIR/async_await.rs:+0:14: +0:16
switchInt(move _5) -> [0: bb1, 1: bb2, otherwise: bb3]; // scope 0 at $DIR/async_await.rs:+0:14: +0:16
}

bb1: {
_4 = move _2; // scope 0 at $DIR/async_await.rs:+0:14: +0:16
_3 = const (); // scope 0 at $DIR/async_await.rs:+0:14: +0:16
Deinit(_0); // scope 0 at $DIR/async_await.rs:+0:16: +0:16
((_0 as Ready).0: ()) = move _3; // scope 0 at $DIR/async_await.rs:+0:16: +0:16
discriminant(_0) = 0; // scope 0 at $DIR/async_await.rs:+0:16: +0:16
discriminant((*(_1.0: &mut [async fn body@$DIR/async_await.rs:10:14: 10:16]))) = 1; // scope 0 at $DIR/async_await.rs:+0:16: +0:16
return; // scope 0 at $DIR/async_await.rs:+0:16: +0:16
}

bb2: {
assert(const false, "`async fn` resumed after completion") -> bb2; // scope 0 at $DIR/async_await.rs:+0:14: +0:16
}

bb3: {
unreachable; // scope 0 at $DIR/async_await.rs:+0:14: +0:16
}
}
Loading

0 comments on commit 1db48e2

Please sign in to comment.