@@ -424,9 +424,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
424424 if let Some ( trait_def_id) = trait_def_id {
425425 let found_kind = match closure_kind {
426426 hir:: ClosureKind :: Closure => self . tcx . fn_trait_kind_from_def_id ( trait_def_id) ,
427- hir:: ClosureKind :: CoroutineClosure ( hir:: CoroutineDesugaring :: Async ) => {
428- self . tcx . async_fn_trait_kind_from_def_id ( trait_def_id)
429- }
427+ hir:: ClosureKind :: CoroutineClosure ( hir:: CoroutineDesugaring :: Async ) => self
428+ . tcx
429+ . async_fn_trait_kind_from_def_id ( trait_def_id)
430+ . or_else ( || self . tcx . fn_trait_kind_from_def_id ( trait_def_id) ) ,
430431 _ => None ,
431432 } ;
432433
@@ -470,14 +471,37 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
470471 // for closures and async closures, respectively.
471472 match closure_kind {
472473 hir:: ClosureKind :: Closure
473- if self . tcx . fn_trait_kind_from_def_id ( trait_def_id) . is_some ( ) => { }
474+ if self . tcx . fn_trait_kind_from_def_id ( trait_def_id) . is_some ( ) =>
475+ {
476+ self . extract_sig_from_projection ( cause_span, projection)
477+ }
478+ hir:: ClosureKind :: CoroutineClosure ( hir:: CoroutineDesugaring :: Async )
479+ if self . tcx . async_fn_trait_kind_from_def_id ( trait_def_id) . is_some ( ) =>
480+ {
481+ self . extract_sig_from_projection ( cause_span, projection)
482+ }
483+ // It's possible we've passed the closure to a (somewhat out-of-fashion)
484+ // `F: FnOnce() -> Fut, Fut: Future<Output = T>` style bound. Let's still
485+ // guide inference here, since it's beneficial for the user.
474486 hir:: ClosureKind :: CoroutineClosure ( hir:: CoroutineDesugaring :: Async )
475- if self . tcx . async_fn_trait_kind_from_def_id ( trait_def_id) . is_some ( ) => { }
487+ if self . tcx . fn_trait_kind_from_def_id ( trait_def_id) . is_some ( ) =>
488+ {
489+ self . extract_sig_from_projection_and_future_bound ( cause_span, projection)
490+ }
476491 _ => return None ,
477492 }
493+ }
494+
495+ /// Given an `FnOnce::Output` or `AsyncFn::Output` projection, extract the args
496+ /// and return type to infer a [`ty::PolyFnSig`] for the closure.
497+ fn extract_sig_from_projection (
498+ & self ,
499+ cause_span : Option < Span > ,
500+ projection : ty:: PolyProjectionPredicate < ' tcx > ,
501+ ) -> Option < ExpectedSig < ' tcx > > {
502+ let projection = self . resolve_vars_if_possible ( projection) ;
478503
479504 let arg_param_ty = projection. skip_binder ( ) . projection_term . args . type_at ( 1 ) ;
480- let arg_param_ty = self . resolve_vars_if_possible ( arg_param_ty) ;
481505 debug ! ( ?arg_param_ty) ;
482506
483507 let ty:: Tuple ( input_tys) = * arg_param_ty. kind ( ) else {
@@ -486,7 +510,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
486510
487511 // Since this is a return parameter type it is safe to unwrap.
488512 let ret_param_ty = projection. skip_binder ( ) . term . expect_type ( ) ;
489- let ret_param_ty = self . resolve_vars_if_possible ( ret_param_ty) ;
490513 debug ! ( ?ret_param_ty) ;
491514
492515 let sig = projection. rebind ( self . tcx . mk_fn_sig (
@@ -500,6 +523,65 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
500523 Some ( ExpectedSig { cause_span, sig } )
501524 }
502525
526+ /// When an async closure is passed to a function that has a "two-part" `Fn`
527+ /// and `Future` trait bound, like:
528+ ///
529+ /// ```rust
530+ /// use std::future::Future;
531+ ///
532+ /// fn not_exactly_an_async_closure(_f: F)
533+ /// where
534+ /// F: FnOnce(String, u32) -> Fut,
535+ /// Fut: Future<Output = i32>,
536+ /// {}
537+ /// ```
538+ ///
539+ /// The we want to be able to extract the signature to guide inference in the async
540+ /// closure. We will have two projection predicates registered in this case. First,
541+ /// we identify the `FnOnce<Args, Output = ?Fut>` bound, and if the output type is
542+ /// an inference variable `?Fut`, we check if that is bounded by a `Future<Output = Ty>`
543+ /// projection.
544+ fn extract_sig_from_projection_and_future_bound (
545+ & self ,
546+ cause_span : Option < Span > ,
547+ projection : ty:: PolyProjectionPredicate < ' tcx > ,
548+ ) -> Option < ExpectedSig < ' tcx > > {
549+ let projection = self . resolve_vars_if_possible ( projection) ;
550+
551+ let arg_param_ty = projection. skip_binder ( ) . projection_term . args . type_at ( 1 ) ;
552+ debug ! ( ?arg_param_ty) ;
553+
554+ let ty:: Tuple ( input_tys) = * arg_param_ty. kind ( ) else {
555+ return None ;
556+ } ;
557+
558+ // If the return type is a
559+ let ty:: Infer ( ty:: TyVar ( return_vid) ) = * projection. skip_binder ( ) . term . expect_type ( ) . kind ( )
560+ else {
561+ return None ;
562+ } ;
563+
564+ // FIXME: We may want to elaborate here, though I assume this will be exceedingly rare.
565+ for bound in self . obligations_for_self_ty ( return_vid) {
566+ if let Some ( ret_projection) = bound. predicate . as_projection_clause ( )
567+ && let Some ( ret_projection) = ret_projection. no_bound_vars ( )
568+ && self . tcx . is_lang_item ( ret_projection. def_id ( ) , LangItem :: FutureOutput )
569+ {
570+ let sig = projection. rebind ( self . tcx . mk_fn_sig (
571+ input_tys,
572+ ret_projection. term . expect_type ( ) ,
573+ false ,
574+ hir:: Safety :: Safe ,
575+ Abi :: Rust ,
576+ ) ) ;
577+
578+ return Some ( ExpectedSig { cause_span, sig } ) ;
579+ }
580+ }
581+
582+ None
583+ }
584+
503585 fn sig_of_closure (
504586 & self ,
505587 expr_def_id : LocalDefId ,
0 commit comments