@@ -5,7 +5,7 @@ use std::{cmp, convert::Infallible, mem};
55use chalk_ir:: {
66 cast:: Cast ,
77 fold:: { FallibleTypeFolder , TypeFoldable } ,
8- AliasEq , AliasTy , BoundVar , DebruijnIndex , FnSubst , Mutability , TyKind , WhereClause ,
8+ BoundVar , DebruijnIndex , FnSubst , Mutability , TyKind ,
99} ;
1010use either:: Either ;
1111use hir_def:: {
@@ -22,13 +22,14 @@ use stdx::never;
2222
2323use crate :: {
2424 db:: { HirDatabase , InternedClosure } ,
25- from_placeholder_idx, make_binders,
25+ from_chalk_trait_id , from_placeholder_idx, make_binders,
2626 mir:: { BorrowKind , MirSpan , MutBorrowKind , ProjectionElem } ,
2727 static_lifetime, to_chalk_trait_id,
2828 traits:: FnTrait ,
29- utils:: { self , generics, Generics } ,
30- Adjust , Adjustment , Binders , BindingMode , ChalkTraitId , ClosureId , DynTy , FnAbi , FnPointer ,
31- FnSig , Interner , Substitution , Ty , TyExt ,
29+ utils:: { self , elaborate_clause_supertraits, generics, Generics } ,
30+ Adjust , Adjustment , AliasEq , AliasTy , Binders , BindingMode , ChalkTraitId , ClosureId , DynTy ,
31+ DynTyExt , FnAbi , FnPointer , FnSig , Interner , OpaqueTy , ProjectionTyExt , Substitution , Ty ,
32+ TyExt , WhereClause ,
3233} ;
3334
3435use super :: { Expectation , InferenceContext } ;
@@ -47,6 +48,15 @@ impl InferenceContext<'_> {
4748 None => return ,
4849 } ;
4950
51+ if let TyKind :: Closure ( closure_id, _) = closure_ty. kind ( Interner ) {
52+ if let Some ( closure_kind) = self . deduce_closure_kind_from_expectations ( & expected_ty) {
53+ self . result
54+ . closure_info
55+ . entry ( * closure_id)
56+ . or_insert_with ( || ( Vec :: new ( ) , closure_kind) ) ;
57+ }
58+ }
59+
5060 // Deduction from where-clauses in scope, as well as fn-pointer coercion are handled here.
5161 let _ = self . coerce ( Some ( closure_expr) , closure_ty, & expected_ty) ;
5262
@@ -65,6 +75,60 @@ impl InferenceContext<'_> {
6575 }
6676 }
6777
78+ // Closure kind deductions are mostly from `rustc_hir_typeck/src/closure.rs`.
79+ // Might need to port closure sig deductions too.
80+ fn deduce_closure_kind_from_expectations ( & mut self , expected_ty : & Ty ) -> Option < FnTrait > {
81+ match expected_ty. kind ( Interner ) {
82+ TyKind :: Alias ( AliasTy :: Opaque ( OpaqueTy { .. } ) ) | TyKind :: OpaqueType ( ..) => {
83+ let clauses = expected_ty
84+ . impl_trait_bounds ( self . db )
85+ . into_iter ( )
86+ . flatten ( )
87+ . map ( |b| b. into_value_and_skipped_binders ( ) . 0 ) ;
88+ self . deduce_closure_kind_from_predicate_clauses ( clauses)
89+ }
90+ TyKind :: Dyn ( dyn_ty) => dyn_ty. principal ( ) . and_then ( |trait_ref| {
91+ self . fn_trait_kind_from_trait_id ( from_chalk_trait_id ( trait_ref. trait_id ) )
92+ } ) ,
93+ TyKind :: InferenceVar ( ty, chalk_ir:: TyVariableKind :: General ) => {
94+ let clauses = self . clauses_for_self_ty ( * ty) ;
95+ self . deduce_closure_kind_from_predicate_clauses ( clauses. into_iter ( ) )
96+ }
97+ TyKind :: Function ( _) => Some ( FnTrait :: Fn ) ,
98+ _ => None ,
99+ }
100+ }
101+
102+ fn deduce_closure_kind_from_predicate_clauses (
103+ & self ,
104+ clauses : impl DoubleEndedIterator < Item = WhereClause > ,
105+ ) -> Option < FnTrait > {
106+ let mut expected_kind = None ;
107+
108+ for clause in elaborate_clause_supertraits ( self . db , clauses. rev ( ) ) {
109+ let trait_id = match clause {
110+ WhereClause :: AliasEq ( AliasEq {
111+ alias : AliasTy :: Projection ( projection) , ..
112+ } ) => Some ( projection. trait_ ( self . db ) ) ,
113+ WhereClause :: Implemented ( trait_ref) => {
114+ Some ( from_chalk_trait_id ( trait_ref. trait_id ) )
115+ }
116+ _ => None ,
117+ } ;
118+ if let Some ( closure_kind) =
119+ trait_id. and_then ( |trait_id| self . fn_trait_kind_from_trait_id ( trait_id) )
120+ {
121+ // `FnX`'s variants order is opposite from rustc, so use `cmp::max` instead of `cmp::min`
122+ expected_kind = Some (
123+ expected_kind
124+ . map_or_else ( || closure_kind, |current| cmp:: max ( current, closure_kind) ) ,
125+ ) ;
126+ }
127+ }
128+
129+ expected_kind
130+ }
131+
68132 fn deduce_sig_from_dyn_ty ( & self , dyn_ty : & DynTy ) -> Option < FnPointer > {
69133 // Search for a predicate like `<$self as FnX<Args>>::Output == Ret`
70134
@@ -111,6 +175,10 @@ impl InferenceContext<'_> {
111175
112176 None
113177 }
178+
179+ fn fn_trait_kind_from_trait_id ( & self , trait_id : hir_def:: TraitId ) -> Option < FnTrait > {
180+ FnTrait :: from_lang_item ( self . db . lang_attr ( trait_id. into ( ) ) ?)
181+ }
114182}
115183
116184// The below functions handle capture and closure kind (Fn, FnMut, ..)
@@ -962,8 +1030,14 @@ impl InferenceContext<'_> {
9621030 }
9631031 }
9641032 self . restrict_precision_for_unsafe ( ) ;
965- // closure_kind should be done before adjust_for_move_closure
966- let closure_kind = self . closure_kind ( ) ;
1033+ // `closure_kind` should be done before adjust_for_move_closure
1034+ // If there exists pre-deduced kind of a closure, use it instead of one determined by capture, as rustc does.
1035+ // rustc also does diagnostics here if the latter is not a subtype of the former.
1036+ let closure_kind = self
1037+ . result
1038+ . closure_info
1039+ . get ( & closure)
1040+ . map_or_else ( || self . closure_kind ( ) , |info| info. 1 ) ;
9671041 match capture_by {
9681042 CaptureBy :: Value => self . adjust_for_move_closure ( ) ,
9691043 CaptureBy :: Ref => ( ) ,
0 commit comments