Skip to content

Commit

Permalink
Rollup merge of #127482 - compiler-errors:closure-two-par-sig-inferen…
Browse files Browse the repository at this point in the history
…ce, r=oli-obk

Infer async closure signature from (old-style) two-part `Fn` + `Future` bounds

When an async closure is passed to a function that has a "two-part" `Fn` and `Future` trait bound, like:

```rust
use std::future::Future;

fn not_exactly_an_async_closure(_f: F)
where
    F: FnOnce(String) -> Fut,
    Fut: Future<Output = ()>,
{}
```

The we want to be able to extract the signature to guide inference in the async closure, like:

```rust
not_exactly_an_async_closure(async |string| {
    for x in string.split('\n') { ... }
    //~^ We need to know that the type of `string` is `String` to call methods on it.
})
```

Closure signature inference will see two bounds: `<?F as FnOnce<Args>>::Output = ?Fut`, `<?Fut as Future>::Output = String`. We need to extract the signature by looking through both projections.

### Why?

I expect the ecosystem's move onto `async Fn` trait bounds (which are not affected by this PR, and already do signature inference fine) to be slow. In the mean time, I don't see major overhead to supporting this "old–style" of trait bounds that were used to model async closures.

r? oli-obk
Fixes #127468
Fixes #127425
  • Loading branch information
GuillaumeGomez authored Jul 8, 2024
2 parents 081cca1 + f4f678f commit 72199b2
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 8 deletions.
102 changes: 94 additions & 8 deletions compiler/rustc_hir_typeck/src/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
if let Some(trait_def_id) = trait_def_id {
let found_kind = match closure_kind {
hir::ClosureKind::Closure => self.tcx.fn_trait_kind_from_def_id(trait_def_id),
hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async) => {
self.tcx.async_fn_trait_kind_from_def_id(trait_def_id)
}
hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async) => self
.tcx
.async_fn_trait_kind_from_def_id(trait_def_id)
.or_else(|| self.tcx.fn_trait_kind_from_def_id(trait_def_id)),
_ => None,
};

Expand Down Expand Up @@ -470,14 +471,37 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
// for closures and async closures, respectively.
match closure_kind {
hir::ClosureKind::Closure
if self.tcx.fn_trait_kind_from_def_id(trait_def_id).is_some() => {}
if self.tcx.fn_trait_kind_from_def_id(trait_def_id).is_some() =>
{
self.extract_sig_from_projection(cause_span, projection)
}
hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async)
if self.tcx.async_fn_trait_kind_from_def_id(trait_def_id).is_some() =>
{
self.extract_sig_from_projection(cause_span, projection)
}
// It's possible we've passed the closure to a (somewhat out-of-fashion)
// `F: FnOnce() -> Fut, Fut: Future<Output = T>` style bound. Let's still
// guide inference here, since it's beneficial for the user.
hir::ClosureKind::CoroutineClosure(hir::CoroutineDesugaring::Async)
if self.tcx.async_fn_trait_kind_from_def_id(trait_def_id).is_some() => {}
_ => return None,
if self.tcx.fn_trait_kind_from_def_id(trait_def_id).is_some() =>
{
self.extract_sig_from_projection_and_future_bound(cause_span, projection)
}
_ => None,
}
}

/// Given an `FnOnce::Output` or `AsyncFn::Output` projection, extract the args
/// and return type to infer a [`ty::PolyFnSig`] for the closure.
fn extract_sig_from_projection(
&self,
cause_span: Option<Span>,
projection: ty::PolyProjectionPredicate<'tcx>,
) -> Option<ExpectedSig<'tcx>> {
let projection = self.resolve_vars_if_possible(projection);

let arg_param_ty = projection.skip_binder().projection_term.args.type_at(1);
let arg_param_ty = self.resolve_vars_if_possible(arg_param_ty);
debug!(?arg_param_ty);

let ty::Tuple(input_tys) = *arg_param_ty.kind() else {
Expand All @@ -486,7 +510,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {

// Since this is a return parameter type it is safe to unwrap.
let ret_param_ty = projection.skip_binder().term.expect_type();
let ret_param_ty = self.resolve_vars_if_possible(ret_param_ty);
debug!(?ret_param_ty);

let sig = projection.rebind(self.tcx.mk_fn_sig(
Expand All @@ -500,6 +523,69 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
Some(ExpectedSig { cause_span, sig })
}

/// When an async closure is passed to a function that has a "two-part" `Fn`
/// and `Future` trait bound, like:
///
/// ```rust
/// use std::future::Future;
///
/// fn not_exactly_an_async_closure<F, Fut>(_f: F)
/// where
/// F: FnOnce(String, u32) -> Fut,
/// Fut: Future<Output = i32>,
/// {}
/// ```
///
/// The we want to be able to extract the signature to guide inference in the async
/// closure. We will have two projection predicates registered in this case. First,
/// we identify the `FnOnce<Args, Output = ?Fut>` bound, and if the output type is
/// an inference variable `?Fut`, we check if that is bounded by a `Future<Output = Ty>`
/// projection.
fn extract_sig_from_projection_and_future_bound(
&self,
cause_span: Option<Span>,
projection: ty::PolyProjectionPredicate<'tcx>,
) -> Option<ExpectedSig<'tcx>> {
let projection = self.resolve_vars_if_possible(projection);

let arg_param_ty = projection.skip_binder().projection_term.args.type_at(1);
debug!(?arg_param_ty);

let ty::Tuple(input_tys) = *arg_param_ty.kind() else {
return None;
};

// If the return type is a type variable, look for bounds on it.
// We could theoretically support other kinds of return types here,
// but none of them would be useful, since async closures return
// concrete anonymous future types, and their futures are not coerced
// into any other type within the body of the async closure.
let ty::Infer(ty::TyVar(return_vid)) = *projection.skip_binder().term.expect_type().kind()
else {
return None;
};

// FIXME: We may want to elaborate here, though I assume this will be exceedingly rare.
for bound in self.obligations_for_self_ty(return_vid) {
if let Some(ret_projection) = bound.predicate.as_projection_clause()
&& let Some(ret_projection) = ret_projection.no_bound_vars()
&& self.tcx.is_lang_item(ret_projection.def_id(), LangItem::FutureOutput)
{
let sig = projection.rebind(self.tcx.mk_fn_sig(
input_tys,
ret_projection.term.expect_type(),
false,
hir::Safety::Safe,
Abi::Rust,
));

return Some(ExpectedSig { cause_span, sig });
}
}

None
}

fn sig_of_closure(
&self,
expr_def_id: LocalDefId,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//@ edition: 2021
//@ check-pass
//@ revisions: current next
//@ ignore-compare-mode-next-solver (explicit revisions)
//@[next] compile-flags: -Znext-solver

#![feature(async_closure)]

use std::future::Future;
use std::any::Any;

struct Struct;
impl Struct {
fn method(&self) {}
}

fn fake_async_closure<F, Fut>(_: F)
where
F: Fn(Struct) -> Fut,
Fut: Future<Output = ()>,
{}

fn main() {
fake_async_closure(async |s| {
s.method();
})
}

0 comments on commit 72199b2

Please sign in to comment.