Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Infer async closure signature from (old-style) two-part Fn + Future bounds #127482

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 94 additions & 8 deletions compiler/rustc_hir_typeck/src/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
if let Some(trait_def_id) = trait_def_id {
let found_kind = match closure_kind {
hir::ClosureKind::Closure => self.tcx.fn_trait_kind_from_def_id(trait_def_id),
hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async) => {
self.tcx.async_fn_trait_kind_from_def_id(trait_def_id)
}
hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async) => self
.tcx
.async_fn_trait_kind_from_def_id(trait_def_id)
.or_else(|| self.tcx.fn_trait_kind_from_def_id(trait_def_id)),
_ => None,
};

Expand Down Expand Up @@ -470,14 +471,37 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
// for closures and async closures, respectively.
match closure_kind {
hir::ClosureKind::Closure
if self.tcx.fn_trait_kind_from_def_id(trait_def_id).is_some() => {}
if self.tcx.fn_trait_kind_from_def_id(trait_def_id).is_some() =>
{
self.extract_sig_from_projection(cause_span, projection)
}
hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async)
if self.tcx.async_fn_trait_kind_from_def_id(trait_def_id).is_some() =>
{
self.extract_sig_from_projection(cause_span, projection)
}
// It's possible we've passed the closure to a (somewhat out-of-fashion)
// `F: FnOnce() -> Fut, Fut: Future<Output = T>` style bound. Let's still
// guide inference here, since it's beneficial for the user.
hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async)
if self.tcx.async_fn_trait_kind_from_def_id(trait_def_id).is_some() => {}
_ => return None,
if self.tcx.fn_trait_kind_from_def_id(trait_def_id).is_some() =>
{
self.extract_sig_from_projection_and_future_bound(cause_span, projection)
}
_ => None,
}
}

/// Given an `FnOnce::Output` or `AsyncFn::Output` projection, extract the args
/// and return type to infer a [`ty::PolyFnSig`] for the closure.
fn extract_sig_from_projection(
&self,
cause_span: Option<Span>,
projection: ty::PolyProjectionPredicate<'tcx>,
) -> Option<ExpectedSig<'tcx>> {
let projection = self.resolve_vars_if_possible(projection);

let arg_param_ty = projection.skip_binder().projection_term.args.type_at(1);
let arg_param_ty = self.resolve_vars_if_possible(arg_param_ty);
debug!(?arg_param_ty);

let ty::Tuple(input_tys) = *arg_param_ty.kind() else {
Expand All @@ -486,7 +510,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {

// Since this is a return parameter type it is safe to unwrap.
let ret_param_ty = projection.skip_binder().term.expect_type();
let ret_param_ty = self.resolve_vars_if_possible(ret_param_ty);
debug!(?ret_param_ty);

let sig = projection.rebind(self.tcx.mk_fn_sig(
Expand All @@ -500,6 +523,69 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
Some(ExpectedSig { cause_span, sig })
}

/// When an async closure is passed to a function that has a "two-part" `Fn`
/// and `Future` trait bound, like:
///
/// ```rust
/// use std::future::Future;
///
/// fn not_exactly_an_async_closure<F, Fut>(_f: F)
/// where
/// F: FnOnce(String, u32) -> Fut,
/// Fut: Future<Output = i32>,
/// {}
/// ```
///
/// The we want to be able to extract the signature to guide inference in the async
/// closure. We will have two projection predicates registered in this case. First,
/// we identify the `FnOnce<Args, Output = ?Fut>` bound, and if the output type is
/// an inference variable `?Fut`, we check if that is bounded by a `Future<Output = Ty>`
/// projection.
fn extract_sig_from_projection_and_future_bound(
&self,
cause_span: Option<Span>,
projection: ty::PolyProjectionPredicate<'tcx>,
) -> Option<ExpectedSig<'tcx>> {
let projection = self.resolve_vars_if_possible(projection);

let arg_param_ty = projection.skip_binder().projection_term.args.type_at(1);
debug!(?arg_param_ty);

let ty::Tuple(input_tys) = *arg_param_ty.kind() else {
return None;
};

// If the return type is a type variable, look for bounds on it.
// We could theoretically support other kinds of return types here,
// but none of them would be useful, since async closures return
// concrete anonymous future types, and their futures are not coerced
// into any other type within the body of the async closure.
let ty::Infer(ty::TyVar(return_vid)) = *projection.skip_binder().term.expect_type().kind()
else {
return None;
};

// FIXME: We may want to elaborate here, though I assume this will be exceedingly rare.
for bound in self.obligations_for_self_ty(return_vid) {
if let Some(ret_projection) = bound.predicate.as_projection_clause()
&& let Some(ret_projection) = ret_projection.no_bound_vars()
&& self.tcx.is_lang_item(ret_projection.def_id(), LangItem::FutureOutput)
{
let sig = projection.rebind(self.tcx.mk_fn_sig(
input_tys,
ret_projection.term.expect_type(),
false,
hir::Safety::Safe,
Abi::Rust,
));

return Some(ExpectedSig { cause_span, sig });
}
}

None
}

fn sig_of_closure(
&self,
expr_def_id: LocalDefId,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//@ edition: 2021
//@ check-pass
//@ revisions: current next
//@ ignore-compare-mode-next-solver (explicit revisions)
//@[next] compile-flags: -Znext-solver

#![feature(async_closure)]

use std::future::Future;
use std::any::Any;

struct Struct;
impl Struct {
fn method(&self) {}
}

fn fake_async_closure<F, Fut>(_: F)
where
F: Fn(Struct) -> Fut,
Fut: Future<Output = ()>,
{}

fn main() {
fake_async_closure(async |s| {
s.method();
})
}
Loading