11//! Unification and canonicalization logic.
22
3- use std:: { fmt, iter, mem} ;
3+ use std:: { cmp , fmt, iter, mem} ;
44
55use chalk_ir:: {
66 cast:: Cast , fold:: TypeFoldable , interner:: HasInterner , zip:: Zip , CanonicalVarKind , FloatTy ,
7- IntTy , TyVariableKind , UniverseIndex ,
7+ IntTy , TyVariableKind , UniverseIndex , WhereClause ,
88} ;
99use chalk_solve:: infer:: ParameterEnaVariableExt ;
1010use either:: Either ;
@@ -14,11 +14,12 @@ use triomphe::Arc;
1414
1515use super :: { InferOk , InferResult , InferenceContext , TypeError } ;
1616use crate :: {
17- consteval:: unknown_const, db:: HirDatabase , fold_tys_and_consts, static_lifetime,
18- to_chalk_trait_id, traits:: FnTrait , AliasEq , AliasTy , BoundVar , Canonical , Const , ConstValue ,
19- DebruijnIndex , GenericArg , GenericArgData , Goal , Guidance , InEnvironment , InferenceVar ,
20- Interner , Lifetime , ParamKind , ProjectionTy , ProjectionTyExt , Scalar , Solution , Substitution ,
21- TraitEnvironment , Ty , TyBuilder , TyExt , TyKind , VariableKind ,
17+ chalk_db:: TraitId , consteval:: unknown_const, db:: HirDatabase , fold_tys_and_consts,
18+ static_lifetime, to_chalk_trait_id, traits:: FnTrait , AliasEq , AliasTy , BoundVar , Canonical ,
19+ ClosureId , Const , ConstValue , DebruijnIndex , DomainGoal , GenericArg , GenericArgData , Goal ,
20+ GoalData , Guidance , InEnvironment , InferenceVar , Interner , Lifetime , ParamKind , ProjectionTy ,
21+ ProjectionTyExt , Scalar , Solution , Substitution , TraitEnvironment , Ty , TyBuilder , TyExt ,
22+ TyKind , VariableKind ,
2223} ;
2324
2425impl InferenceContext < ' _ > {
@@ -146,6 +147,8 @@ pub(crate) struct InferenceTable<'a> {
146147 /// Double buffer used in [`Self::resolve_obligations_as_possible`] to cut down on
147148 /// temporary allocations.
148149 resolve_obligations_buffer : Vec < Canonicalized < InEnvironment < Goal > > > ,
150+ fn_trait_predicates : Vec < ( Ty , FnTrait ) > ,
151+ cached_fn_trait_ids : Option < CachedFnTraitIds > ,
149152}
150153
151154pub ( crate ) struct InferenceTableSnapshot {
@@ -154,15 +157,36 @@ pub(crate) struct InferenceTableSnapshot {
154157 type_variable_table_snapshot : Vec < TypeVariableFlags > ,
155158}
156159
160+ #[ derive( Clone ) ]
161+ struct CachedFnTraitIds {
162+ fn_trait : TraitId ,
163+ fn_mut_trait : TraitId ,
164+ fn_once_trait : TraitId ,
165+ }
166+
167+ impl CachedFnTraitIds {
168+ fn new ( db : & dyn HirDatabase , trait_env : & Arc < TraitEnvironment > ) -> Option < Self > {
169+ let fn_trait = FnTrait :: Fn . get_id ( db, trait_env. krate . clone ( ) ) . map ( to_chalk_trait_id) ?;
170+ let fn_mut_trait =
171+ FnTrait :: FnMut . get_id ( db, trait_env. krate . clone ( ) ) . map ( to_chalk_trait_id) ?;
172+ let fn_once_trait =
173+ FnTrait :: FnOnce . get_id ( db, trait_env. krate . clone ( ) ) . map ( to_chalk_trait_id) ?;
174+ Some ( Self { fn_trait, fn_mut_trait, fn_once_trait } )
175+ }
176+ }
177+
157178impl < ' a > InferenceTable < ' a > {
158179 pub ( crate ) fn new ( db : & ' a dyn HirDatabase , trait_env : Arc < TraitEnvironment > ) -> Self {
180+ let cached_fn_trait_ids = CachedFnTraitIds :: new ( db, & trait_env) ;
159181 InferenceTable {
160182 db,
161183 trait_env,
162184 var_unification_table : ChalkInferenceTable :: new ( ) ,
163185 type_variable_table : Vec :: new ( ) ,
164186 pending_obligations : Vec :: new ( ) ,
165187 resolve_obligations_buffer : Vec :: new ( ) ,
188+ fn_trait_predicates : Vec :: new ( ) ,
189+ cached_fn_trait_ids,
166190 }
167191 }
168192
@@ -498,6 +522,22 @@ impl<'a> InferenceTable<'a> {
498522 }
499523
500524 fn register_obligation_in_env ( & mut self , goal : InEnvironment < Goal > ) {
525+ if let Some ( fn_trait_ids) = & self . cached_fn_trait_ids {
526+ if let GoalData :: DomainGoal ( DomainGoal :: Holds ( WhereClause :: Implemented ( trait_ref) ) ) =
527+ goal. goal . data ( Interner )
528+ {
529+ if let Some ( ty) = trait_ref. substitution . type_parameters ( Interner ) . next ( ) {
530+ if trait_ref. trait_id == fn_trait_ids. fn_trait {
531+ self . fn_trait_predicates . push ( ( ty, FnTrait :: Fn ) ) ;
532+ } else if trait_ref. trait_id == fn_trait_ids. fn_mut_trait {
533+ self . fn_trait_predicates . push ( ( ty, FnTrait :: FnMut ) ) ;
534+ } else if trait_ref. trait_id == fn_trait_ids. fn_once_trait {
535+ self . fn_trait_predicates . push ( ( ty, FnTrait :: FnOnce ) ) ;
536+ }
537+ }
538+ }
539+ }
540+
501541 let canonicalized = self . canonicalize ( goal) ;
502542 if !self . try_resolve_obligation ( & canonicalized) {
503543 self . pending_obligations . push ( canonicalized) ;
@@ -791,6 +831,23 @@ impl<'a> InferenceTable<'a> {
791831 _ => c,
792832 }
793833 }
834+
835+ pub ( super ) fn get_closure_fn_trait_predicate (
836+ & mut self ,
837+ closure_id : ClosureId ,
838+ ) -> Option < FnTrait > {
839+ let predicates = mem:: take ( & mut self . fn_trait_predicates ) ;
840+ let res = predicates. iter ( ) . filter_map ( |( ty, fn_trait) | {
841+ if matches ! ( self . resolve_completely( ty. clone( ) ) . kind( Interner ) , TyKind :: Closure ( c, ..) if * c == closure_id) {
842+ Some ( * fn_trait)
843+ } else {
844+ None
845+ }
846+ } ) . fold ( None , |acc, x| Some ( cmp:: max ( acc. unwrap_or ( FnTrait :: FnOnce ) , x) ) ) ;
847+ self . fn_trait_predicates = predicates;
848+
849+ return res;
850+ }
794851}
795852
796853impl fmt:: Debug for InferenceTable < ' _ > {
0 commit comments