Skip to content

Commit

Permalink
Avoid GenFuture shim when compiling async constructs
Browse files Browse the repository at this point in the history
Previously, async constructs would be lowered to "normal" generators,
with an additional `from_generator` / `GenFuture` shim in between to
convert from `Generator` to `Future`.

The compiler will now special-case these generators internally so that
async constructs will *directly* implement `Future` without the need
to go through the `from_generator` / `GenFuture` shim.

The primary motivation for this change was hiding this implementation
detail in stack traces and debuginfo, but it can in theory also help
the optimizer as there is less abstractions to see through.
  • Loading branch information
Swatinem committed Nov 18, 2022
1 parent 9d46c7a commit 8120660
Show file tree
Hide file tree
Showing 45 changed files with 427 additions and 441 deletions.
42 changes: 18 additions & 24 deletions compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,15 +148,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
self.arena.alloc_from_iter(arms.iter().map(|x| self.lower_arm(x))),
hir::MatchSource::Normal,
),
ExprKind::Async(capture_clause, closure_node_id, ref block) => self
.make_async_expr(
ExprKind::Async(capture_clause, closure_node_id, ref block) => {
return self.make_async_expr(
capture_clause,
closure_node_id,
None,
block.span,
e.span,
hir::AsyncGeneratorKind::Block,
|this| this.with_new_scopes(|this| this.lower_block_expr(block)),
),
);
}
ExprKind::Await(ref expr) => {
let dot_await_span = if expr.span.hi() < e.span.hi() {
let span_with_whitespace = self
Expand Down Expand Up @@ -575,14 +576,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
}
}

/// Lower an `async` construct to a generator that is then wrapped so it implements `Future`.
/// Lower an `async` construct to a generator that implements `Future`.
///
/// This results in:
///
/// ```text
/// std::future::from_generator(static move? |_task_context| -> <ret_ty> {
/// static move? |_task_context| -> <ret_ty> {
/// <body>
/// })
/// }
/// ```
pub(super) fn make_async_expr(
&mut self,
Expand All @@ -592,20 +593,22 @@ impl<'hir> LoweringContext<'_, 'hir> {
span: Span,
async_gen_kind: hir::AsyncGeneratorKind,
body: impl FnOnce(&mut Self) -> hir::Expr<'hir>,
) -> hir::ExprKind<'hir> {
) -> hir::Expr<'hir> {
let output = match ret_ty {
Some(ty) => hir::FnRetTy::Return(
self.lower_ty(&ty, &ImplTraitContext::Disallowed(ImplTraitPosition::AsyncBlock)),
),
None => hir::FnRetTy::DefaultReturn(self.lower_span(span)),
};

// Resume argument type. We let the compiler infer this to simplify the lowering. It is
// fully constrained by `future::from_generator`.
// Resume argument type: `ResumeTy`
let unstable_span =
self.mark_span_with_reason(DesugaringKind::Async, span, self.allow_gen_future.clone());
let resume_ty = hir::QPath::LangItem(hir::LangItem::ResumeTy, unstable_span, None);
let input_ty = hir::Ty {
hir_id: self.next_id(),
kind: hir::TyKind::Infer,
span: self.lower_span(span),
kind: hir::TyKind::Path(resume_ty),
span: unstable_span,
};

// The closure/generator `FnDecl` takes a single (resume) argument of type `input_ty`.
Expand Down Expand Up @@ -688,16 +691,7 @@ impl<'hir> LoweringContext<'_, 'hir> {

let generator = hir::Expr { hir_id, kind: generator_kind, span: self.lower_span(span) };

// `future::from_generator`:
let gen_future = self.expr_lang_item_path(
unstable_span,
hir::LangItem::FromGenerator,
AttrVec::new(),
None,
);

// `future::from_generator(generator)`:
hir::ExprKind::Call(self.arena.alloc(gen_future), arena_vec![self; generator])
generator
}

/// Desugar `<expr>.await` into:
Expand Down Expand Up @@ -1001,7 +995,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
}

// Transform `async |x: u8| -> X { ... }` into
// `|x: u8| future_from_generator(|| -> X { ... })`.
// `|x: u8| || -> X { ... }`.
let body_id = this.lower_fn_body(&outer_decl, |this| {
let async_ret_ty =
if let FnRetTy::Ty(ty) = &decl.output { Some(ty.clone()) } else { None };
Expand All @@ -1013,7 +1007,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
hir::AsyncGeneratorKind::Closure,
|this| this.with_new_scopes(|this| this.lower_expr_mut(body)),
);
this.expr(fn_decl_span, async_body, AttrVec::new())
async_body
});
body_id
});
Expand Down
5 changes: 1 addition & 4 deletions compiler/rustc_ast_lowering/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1251,10 +1251,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
},
);

(
this.arena.alloc_from_iter(parameters),
this.expr(body.span, async_expr, AttrVec::new()),
)
(this.arena.alloc_from_iter(parameters), async_expr)
})
}

Expand Down
15 changes: 11 additions & 4 deletions compiler/rustc_borrowck/src/diagnostics/region_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use rustc_middle::ty::subst::InternalSubsts;
use rustc_middle::ty::Region;
use rustc_middle::ty::TypeVisitor;
use rustc_middle::ty::{self, RegionVid, Ty};
use rustc_span::symbol::{kw, sym, Ident};
use rustc_span::symbol::{kw, Ident};
use rustc_span::Span;

use crate::borrowck_errors;
Expand Down Expand Up @@ -514,8 +514,11 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, 'tcx> {
span: *span,
ty_err: match output_ty.kind() {
ty::Closure(_, _) => FnMutReturnTypeErr::ReturnClosure { span: *span },
ty::Adt(def, _)
if self.infcx.tcx.is_diagnostic_item(sym::gen_future, def.did()) =>
ty::Generator(def, ..)
if matches!(
self.infcx.tcx.generator_kind(def),
Some(hir::GeneratorKind::Async(_))
) =>
{
FnMutReturnTypeErr::ReturnAsyncBlock { span: *span }
}
Expand Down Expand Up @@ -927,10 +930,14 @@ impl<'a, 'tcx> MirBorrowckCtxt<'a, 'tcx> {
// only when the block is a closure
if let hir::ExprKind::Closure(hir::Closure {
capture_clause: hir::CaptureBy::Ref,
body,
..
}) = expr.kind
{
closure_span = Some(expr.span.shrink_to_lo());
let body = map.body(*body);
if !matches!(body.generator_kind, Some(hir::GeneratorKind::Async(..))) {
closure_span = Some(expr.span.shrink_to_lo());
}
}
}
}
Expand Down
21 changes: 11 additions & 10 deletions compiler/rustc_const_eval/src/transform/check_consts/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,17 @@ impl<'tcx> Visitor<'tcx> for Checker<'_, 'tcx> {
| Rvalue::CopyForDeref(..)
| Rvalue::Repeat(..)
| Rvalue::Discriminant(..)
| Rvalue::Len(_)
| Rvalue::Aggregate(..) => {}
| Rvalue::Len(_) => {}

Rvalue::Aggregate(ref kind, ..) => {
if let AggregateKind::Generator(def_id, ..) = kind.as_ref() {
if let Some(generator_kind) = self.tcx.generator_kind(def_id.to_def_id()) {
if matches!(generator_kind, hir::GeneratorKind::Async(..)) {
self.check_op(ops::Generator(generator_kind));
}
}
}
}

Rvalue::Ref(_, kind @ BorrowKind::Mut { .. }, ref place)
| Rvalue::Ref(_, kind @ BorrowKind::Unique, ref place) => {
Expand Down Expand Up @@ -889,14 +898,6 @@ impl<'tcx> Visitor<'tcx> for Checker<'_, 'tcx> {
return;
}

// `async` blocks get lowered to `std::future::from_generator(/* a closure */)`.
let is_async_block = Some(callee) == tcx.lang_items().from_generator_fn();
if is_async_block {
let kind = hir::GeneratorKind::Async(hir::AsyncGeneratorKind::Block);
self.check_op(ops::Generator(kind));
return;
}

if !tcx.is_const_fn_raw(callee) {
if !tcx.is_const_default_method(callee) {
// To get to here we must have already found a const impl for the
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_hir/src/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,11 @@ language_item_table! {
TryTraitBranch, sym::branch, branch_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;
TryTraitFromYeet, sym::from_yeet, from_yeet_fn, Target::Fn, GenericRequirement::None;

Poll, sym::Poll, poll, Target::Enum, GenericRequirement::None;
PollReady, sym::Ready, poll_ready_variant, Target::Variant, GenericRequirement::None;
PollPending, sym::Pending, poll_pending_variant, Target::Variant, GenericRequirement::None;

ResumeTy, sym::ResumeTy, resume_ty, Target::Struct, GenericRequirement::None;
FromGenerator, sym::from_generator, from_generator_fn, Target::Fn, GenericRequirement::None;
GetContext, sym::get_context, get_context_fn, Target::Fn, GenericRequirement::None;

Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_hir_typeck/src/callee.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
let fn_decl_span = if hir.body(body).generator_kind
== Some(hir::GeneratorKind::Async(hir::AsyncGeneratorKind::Closure))
{
// Actually need to unwrap a few more layers of HIR to get to
// Actually need to unwrap one more layer of HIR to get to
// the _real_ closure...
let async_closure = hir.get_parent_node(hir.get_parent_node(parent_hir_id));
let async_closure = hir.get_parent_node(parent_hir_id);
if let hir::Node::Expr(hir::Expr {
kind: hir::ExprKind::Closure(&hir::Closure { fn_decl_span, .. }),
..
Expand Down
13 changes: 9 additions & 4 deletions compiler/rustc_hir_typeck/src/check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,15 @@ pub(super) fn check_fn<'a, 'tcx>(

fn_maybe_err(tcx, span, fn_sig.abi);

if body.generator_kind.is_some() && can_be_generator.is_some() {
let yield_ty = fcx
.next_ty_var(TypeVariableOrigin { kind: TypeVariableOriginKind::TypeInference, span });
fcx.require_type_is_sized(yield_ty, span, traits::SizedYieldType);
if let Some(kind) = body.generator_kind && can_be_generator.is_some() {
let yield_ty = if kind == hir::GeneratorKind::Gen {
let yield_ty = fcx
.next_ty_var(TypeVariableOrigin { kind: TypeVariableOriginKind::TypeInference, span });
fcx.require_type_is_sized(yield_ty, span, traits::SizedYieldType);
yield_ty
} else {
tcx.mk_unit()
};

// Resume type defaults to `()` if the generator has no argument.
let resume_ty = fn_sig.inputs().get(0).copied().unwrap_or_else(|| tcx.mk_unit());
Expand Down
8 changes: 0 additions & 8 deletions compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1729,14 +1729,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
let hir = self.tcx.hir();
let hir::Node::Expr(expr) = hir.get(hir_id) else { return false; };

// Skip over mentioning async lang item
if Some(def_id) == self.tcx.lang_items().from_generator_fn()
&& error.obligation.cause.span.desugaring_kind()
== Some(rustc_span::DesugaringKind::Async)
{
return false;
}

let Some(unsubstituted_pred) =
self.tcx.predicates_of(def_id).instantiate_identity(self.tcx).predicates.into_iter().nth(idx)
else { return false; };
Expand Down
11 changes: 10 additions & 1 deletion compiler/rustc_lint/src/unused.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,16 @@ impl<'tcx> LateLintPass<'tcx> for UnusedResults {
);
true
}
ty::Generator(..) => {
ty::Generator(def_id, ..) => {
// async fn should be treated as "implementor of `Future`"
if matches!(cx.tcx.generator_kind(def_id), Some(hir::GeneratorKind::Async(..)))
{
let def_id = cx.tcx.lang_items().future_trait().unwrap();
let descr_pre = &format!("{}implementer{} of ", descr_pre, plural_suffix,);
if check_must_use_def(cx, def_id, span, descr_pre, descr_post) {
return true;
}
}
cx.struct_span_lint(
UNUSED_MUST_USE,
span,
Expand Down
20 changes: 20 additions & 0 deletions compiler/rustc_middle/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,9 @@ pub enum ImplSource<'tcx, N> {
/// ImplSource automatically generated for a generator.
Generator(ImplSourceGeneratorData<'tcx, N>),

/// ImplSource automatically generated for a generator backing an async future.
Future(ImplSourceFutureData<'tcx, N>),

/// ImplSource for a trait alias.
TraitAlias(ImplSourceTraitAliasData<'tcx, N>),

Expand All @@ -676,6 +679,7 @@ impl<'tcx, N> ImplSource<'tcx, N> {
ImplSource::AutoImpl(d) => d.nested,
ImplSource::Closure(c) => c.nested,
ImplSource::Generator(c) => c.nested,
ImplSource::Future(c) => c.nested,
ImplSource::Object(d) => d.nested,
ImplSource::FnPointer(d) => d.nested,
ImplSource::DiscriminantKind(ImplSourceDiscriminantKindData)
Expand All @@ -694,6 +698,7 @@ impl<'tcx, N> ImplSource<'tcx, N> {
ImplSource::AutoImpl(d) => &d.nested,
ImplSource::Closure(c) => &c.nested,
ImplSource::Generator(c) => &c.nested,
ImplSource::Future(c) => &c.nested,
ImplSource::Object(d) => &d.nested,
ImplSource::FnPointer(d) => &d.nested,
ImplSource::DiscriminantKind(ImplSourceDiscriminantKindData)
Expand Down Expand Up @@ -737,6 +742,11 @@ impl<'tcx, N> ImplSource<'tcx, N> {
substs: c.substs,
nested: c.nested.into_iter().map(f).collect(),
}),
ImplSource::Future(c) => ImplSource::Future(ImplSourceFutureData {
generator_def_id: c.generator_def_id,
substs: c.substs,
nested: c.nested.into_iter().map(f).collect(),
}),
ImplSource::FnPointer(p) => ImplSource::FnPointer(ImplSourceFnPointerData {
fn_ty: p.fn_ty,
nested: p.nested.into_iter().map(f).collect(),
Expand Down Expand Up @@ -796,6 +806,16 @@ pub struct ImplSourceGeneratorData<'tcx, N> {
pub nested: Vec<N>,
}

#[derive(Clone, PartialEq, Eq, TyEncodable, TyDecodable, HashStable, Lift)]
#[derive(TypeFoldable, TypeVisitable)]
pub struct ImplSourceFutureData<'tcx, N> {
pub generator_def_id: DefId,
pub substs: SubstsRef<'tcx>,
/// Nested obligations. This can be non-empty if the generator
/// signature contains associated types.
pub nested: Vec<N>,
}

#[derive(Clone, PartialEq, Eq, TyEncodable, TyDecodable, HashStable, Lift)]
#[derive(TypeFoldable, TypeVisitable)]
pub struct ImplSourceClosureData<'tcx, N> {
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_middle/src/traits/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,10 @@ pub enum SelectionCandidate<'tcx> {
/// generated for a generator.
GeneratorCandidate,

/// Implementation of a `Future` trait by one of the generator types
/// generated for an async construct.
FutureCandidate,

/// Implementation of a `Fn`-family trait by one of the anonymous
/// types generated for a fn pointer type (e.g., `fn(int) -> int`)
FnPointerCandidate {
Expand Down
12 changes: 12 additions & 0 deletions compiler/rustc_middle/src/traits/structural_impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ impl<'tcx, N: fmt::Debug> fmt::Debug for traits::ImplSource<'tcx, N> {

super::ImplSource::Generator(ref d) => write!(f, "{:?}", d),

super::ImplSource::Future(ref d) => write!(f, "{:?}", d),

super::ImplSource::FnPointer(ref d) => write!(f, "({:?})", d),

super::ImplSource::DiscriminantKind(ref d) => write!(f, "{:?}", d),
Expand Down Expand Up @@ -58,6 +60,16 @@ impl<'tcx, N: fmt::Debug> fmt::Debug for traits::ImplSourceGeneratorData<'tcx, N
}
}

impl<'tcx, N: fmt::Debug> fmt::Debug for traits::ImplSourceFutureData<'tcx, N> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"ImplSourceFutureData(generator_def_id={:?}, substs={:?}, nested={:?})",
self.generator_def_id, self.substs, self.nested
)
}
}

impl<'tcx, N: fmt::Debug> fmt::Debug for traits::ImplSourceClosureData<'tcx, N> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
Expand Down
11 changes: 11 additions & 0 deletions compiler/rustc_middle/src/ty/print/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,17 @@ pub trait PrettyPrinter<'tcx>:
}
ty::Str => p!("str"),
ty::Generator(did, substs, movability) => {
// FIXME(swatinem): async constructs used to be pretty printed
// as `impl Future` previously due to the `from_generator` wrapping.
// lets special case this here for now to avoid churn in diagnostics.
let generator_kind = self.tcx().generator_kind(did);
if matches!(generator_kind, Some(hir::GeneratorKind::Async(..))) {
let return_ty = substs.as_generator().return_ty();
p!(write("impl Future<Output = {}>", return_ty));

return Ok(self);
}

p!(write("["));
match movability {
hir::Movability::Movable => {}
Expand Down
Loading

0 comments on commit 8120660

Please sign in to comment.