@@ -5,7 +5,7 @@ use std::{cmp, convert::Infallible, mem};
55use chalk_ir:: {
66 cast:: Cast ,
77 fold:: { FallibleTypeFolder , TypeFoldable } ,
8- AliasEq , AliasTy , BoundVar , DebruijnIndex , FnSubst , Mutability , TyKind , WhereClause ,
8+ BoundVar , DebruijnIndex , FnSubst , Mutability , TyKind ,
99} ;
1010use either:: Either ;
1111use hir_def:: {
@@ -22,13 +22,14 @@ use stdx::never;
2222
2323use crate :: {
2424 db:: { HirDatabase , InternedClosure } ,
25- from_placeholder_idx, make_binders,
25+ from_chalk_trait_id , from_placeholder_idx, make_binders,
2626 mir:: { BorrowKind , MirSpan , MutBorrowKind , ProjectionElem } ,
2727 static_lifetime, to_chalk_trait_id,
2828 traits:: FnTrait ,
29- utils:: { self , generics, Generics } ,
30- Adjust , Adjustment , Binders , BindingMode , ChalkTraitId , ClosureId , DynTy , FnAbi , FnPointer ,
31- FnSig , Interner , Substitution , Ty , TyExt ,
29+ utils:: { self , elaborate_clause_supertraits, generics, Generics } ,
30+ Adjust , Adjustment , AliasEq , AliasTy , Binders , BindingMode , ChalkTraitId , ClosureId , DynTy ,
31+ DynTyExt , FnAbi , FnPointer , FnSig , Interner , OpaqueTy , ProjectionTyExt , Substitution , Ty ,
32+ TyExt , WhereClause ,
3233} ;
3334
3435use super :: { Expectation , InferenceContext } ;
@@ -47,6 +48,15 @@ impl InferenceContext<'_> {
4748 None => return ,
4849 } ;
4950
51+ if let TyKind :: Closure ( closure_id, _) = closure_ty. kind ( Interner ) {
52+ if let Some ( closure_kind) = self . deduce_closure_kind_from_expectations ( & expected_ty) {
53+ self . result
54+ . closure_info
55+ . entry ( * closure_id)
56+ . or_insert_with ( || ( Vec :: new ( ) , closure_kind) ) ;
57+ }
58+ }
59+
5060 // Deduction from where-clauses in scope, as well as fn-pointer coercion are handled here.
5161 let _ = self . coerce ( Some ( closure_expr) , closure_ty, & expected_ty) ;
5262
@@ -65,6 +75,60 @@ impl InferenceContext<'_> {
6575 }
6676 }
6777
78+ // Closure kind deductions are mostly from `rustc_hir_typeck/src/closure.rs`.
79+ // Might need to port closure sig deductions too.
80+ fn deduce_closure_kind_from_expectations ( & mut self , expected_ty : & Ty ) -> Option < FnTrait > {
81+ match expected_ty. kind ( Interner ) {
82+ TyKind :: Alias ( AliasTy :: Opaque ( OpaqueTy { .. } ) ) | TyKind :: OpaqueType ( ..) => {
83+ let clauses = expected_ty
84+ . impl_trait_bounds ( self . db )
85+ . into_iter ( )
86+ . flatten ( )
87+ . map ( |b| b. into_value_and_skipped_binders ( ) . 0 ) ;
88+ self . deduce_closure_kind_from_predicate_clauses ( clauses)
89+ }
90+ TyKind :: Dyn ( dyn_ty) => dyn_ty. principal ( ) . and_then ( |trait_ref| {
91+ self . fn_trait_kind_from_trait_id ( from_chalk_trait_id ( trait_ref. trait_id ) )
92+ } ) ,
93+ TyKind :: InferenceVar ( ty, chalk_ir:: TyVariableKind :: General ) => {
94+ let clauses = self . clauses_for_self_ty ( * ty) ;
95+ self . deduce_closure_kind_from_predicate_clauses ( clauses. into_iter ( ) )
96+ }
97+ TyKind :: Function ( _) => Some ( FnTrait :: Fn ) ,
98+ _ => None ,
99+ }
100+ }
101+
102+ fn deduce_closure_kind_from_predicate_clauses (
103+ & self ,
104+ clauses : impl DoubleEndedIterator < Item = WhereClause > ,
105+ ) -> Option < FnTrait > {
106+ let mut expected_kind = None ;
107+
108+ for clause in elaborate_clause_supertraits ( self . db , clauses. rev ( ) ) {
109+ let trait_id = match clause {
110+ WhereClause :: AliasEq ( AliasEq {
111+ alias : AliasTy :: Projection ( projection) , ..
112+ } ) => Some ( projection. trait_ ( self . db ) ) ,
113+ WhereClause :: Implemented ( trait_ref) => {
114+ Some ( from_chalk_trait_id ( trait_ref. trait_id ) )
115+ }
116+ _ => None ,
117+ } ;
118+ if let Some ( closure_kind) =
119+ trait_id. and_then ( |trait_id| self . fn_trait_kind_from_trait_id ( trait_id) )
120+ {
121+ // `FnX`'s variants order is opposite from rustc, so use `cmp::max` instead of `cmp::min`
122+ expected_kind = Some (
123+ expected_kind
124+ . map_or_else ( || closure_kind, |current| cmp:: max ( current, closure_kind) ) ,
125+ ) ;
126+ }
127+ }
128+
129+ expected_kind
130+ }
131+
68132 fn deduce_sig_from_dyn_ty ( & self , dyn_ty : & DynTy ) -> Option < FnPointer > {
69133 // Search for a predicate like `<$self as FnX<Args>>::Output == Ret`
70134
@@ -111,6 +175,18 @@ impl InferenceContext<'_> {
111175
112176 None
113177 }
178+
179+ fn fn_trait_kind_from_trait_id ( & self , trait_id : hir_def:: TraitId ) -> Option < FnTrait > {
180+ utils:: fn_traits ( self . db . upcast ( ) , self . owner . module ( self . db . upcast ( ) ) . krate ( ) )
181+ . enumerate ( )
182+ . find_map ( |( i, t) | ( t == trait_id) . then_some ( i) )
183+ . map ( |i| match i {
184+ 0 => FnTrait :: Fn ,
185+ 1 => FnTrait :: FnMut ,
186+ 2 => FnTrait :: FnOnce ,
187+ _ => unreachable ! ( ) ,
188+ } )
189+ }
114190}
115191
116192// The below functions handle capture and closure kind (Fn, FnMut, ..)
@@ -962,8 +1038,14 @@ impl InferenceContext<'_> {
9621038 }
9631039 }
9641040 self . restrict_precision_for_unsafe ( ) ;
965- // closure_kind should be done before adjust_for_move_closure
966- let closure_kind = self . closure_kind ( ) ;
1041+ // `closure_kind` should be done before adjust_for_move_closure
1042+ // If there exists pre-deduced kind of a closure, use it instead of one determined by capture, as rustc does.
1043+ // rustc also does diagnostics here if the latter is not a subtype of the former.
1044+ let closure_kind = self
1045+ . result
1046+ . closure_info
1047+ . get ( & closure)
1048+ . map_or_else ( || self . closure_kind ( ) , |info| info. 1 ) ;
9671049 match capture_by {
9681050 CaptureBy :: Value => self . adjust_for_move_closure ( ) ,
9691051 CaptureBy :: Ref => ( ) ,
0 commit comments