@@ -2,10 +2,10 @@ pub mod on_unimplemented;
2
2
pub mod suggestions;
3
3
4
4
use super :: {
5
- EvaluationResult , FulfillmentError , FulfillmentErrorCode , MismatchedProjectionTypes ,
6
- Obligation , ObligationCause , ObligationCauseCode , OnUnimplementedDirective ,
7
- OnUnimplementedNote , OutputTypeParameterMismatch , Overflow , PredicateObligation ,
8
- SelectionContext , SelectionError , TraitNotObjectSafe ,
5
+ EvaluationResult , FulfillmentContext , FulfillmentError , FulfillmentErrorCode ,
6
+ MismatchedProjectionTypes , Obligation , ObligationCause , ObligationCauseCode ,
7
+ OnUnimplementedDirective , OnUnimplementedNote , OutputTypeParameterMismatch , Overflow ,
8
+ PredicateObligation , SelectionContext , SelectionError , TraitNotObjectSafe ,
9
9
} ;
10
10
11
11
use crate :: infer:: error_reporting:: { TyCategory , TypeAnnotationNeeded as ErrorCode } ;
@@ -21,6 +21,8 @@ use rustc_hir::intravisit::Visitor;
21
21
use rustc_hir:: GenericParam ;
22
22
use rustc_hir:: Item ;
23
23
use rustc_hir:: Node ;
24
+ use rustc_infer:: infer:: error_reporting:: same_type_modulo_infer;
25
+ use rustc_infer:: traits:: TraitEngine ;
24
26
use rustc_middle:: thir:: abstract_const:: NotConstEvaluatable ;
25
27
use rustc_middle:: traits:: select:: OverflowError ;
26
28
use rustc_middle:: ty:: error:: ExpectedFound ;
@@ -103,6 +105,17 @@ pub trait InferCtxtExt<'tcx> {
103
105
found_args : Vec < ArgKind > ,
104
106
is_closure : bool ,
105
107
) -> DiagnosticBuilder < ' tcx , ErrorGuaranteed > ;
108
+
109
+ /// Checks if the type implements one of `Fn`, `FnMut`, or `FnOnce`
110
+ /// in that order, and returns the generic type corresponding to the
111
+ /// argument of that trait (corresponding to the closure arguments).
112
+ fn type_implements_fn_trait (
113
+ & self ,
114
+ param_env : ty:: ParamEnv < ' tcx > ,
115
+ ty : ty:: Binder < ' tcx , Ty < ' tcx > > ,
116
+ constness : ty:: BoundConstness ,
117
+ polarity : ty:: ImplPolarity ,
118
+ ) -> Result < ( ty:: ClosureKind , ty:: Binder < ' tcx , Ty < ' tcx > > ) , ( ) > ;
106
119
}
107
120
108
121
impl < ' a , ' tcx > InferCtxtExt < ' tcx > for InferCtxt < ' a , ' tcx > {
@@ -563,7 +576,64 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
563
576
}
564
577
565
578
// Try to report a help message
566
- if !trait_ref. has_infer_types_or_consts ( )
579
+ if is_fn_trait
580
+ && let Ok ( ( implemented_kind, params) ) = self . type_implements_fn_trait (
581
+ obligation. param_env ,
582
+ trait_ref. self_ty ( ) ,
583
+ trait_predicate. skip_binder ( ) . constness ,
584
+ trait_predicate. skip_binder ( ) . polarity ,
585
+ )
586
+ {
587
+ // If the type implements `Fn`, `FnMut`, or `FnOnce`, suppress the following
588
+ // suggestion to add trait bounds for the type, since we only typically implement
589
+ // these traits once.
590
+
591
+ // Note if the `FnMut` or `FnOnce` is less general than the trait we're trying
592
+ // to implement.
593
+ let selected_kind =
594
+ ty:: ClosureKind :: from_def_id ( self . tcx , trait_ref. def_id ( ) )
595
+ . expect ( "expected to map DefId to ClosureKind" ) ;
596
+ if !implemented_kind. extends ( selected_kind) {
597
+ err. note (
598
+ & format ! (
599
+ "`{}` implements `{}`, but it must implement `{}`, which is more general" ,
600
+ trait_ref. skip_binder( ) . self_ty( ) ,
601
+ implemented_kind,
602
+ selected_kind
603
+ )
604
+ ) ;
605
+ }
606
+
607
+ // Note any argument mismatches
608
+ let given_ty = params. skip_binder ( ) ;
609
+ let expected_ty = trait_ref. skip_binder ( ) . substs . type_at ( 1 ) ;
610
+ if let ty:: Tuple ( given) = given_ty. kind ( )
611
+ && let ty:: Tuple ( expected) = expected_ty. kind ( )
612
+ {
613
+ if expected. len ( ) != given. len ( ) {
614
+ // Note number of types that were expected and given
615
+ err. note (
616
+ & format ! (
617
+ "expected a closure taking {} argument{}, but one taking {} argument{} was given" ,
618
+ given. len( ) ,
619
+ if given. len( ) == 1 { "" } else { "s" } ,
620
+ expected. len( ) ,
621
+ if expected. len( ) == 1 { "" } else { "s" } ,
622
+ )
623
+ ) ;
624
+ } else if !same_type_modulo_infer ( given_ty, expected_ty) {
625
+ // Print type mismatch
626
+ let ( expected_args, given_args) =
627
+ self . cmp ( given_ty, expected_ty) ;
628
+ err. note_expected_found (
629
+ & "a closure with arguments" ,
630
+ expected_args,
631
+ & "a closure with arguments" ,
632
+ given_args,
633
+ ) ;
634
+ }
635
+ }
636
+ } else if !trait_ref. has_infer_types_or_consts ( )
567
637
&& self . predicate_can_apply ( obligation. param_env , trait_ref)
568
638
{
569
639
// If a where-clause may be useful, remind the
@@ -1148,6 +1218,52 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
1148
1218
1149
1219
err
1150
1220
}
1221
+
1222
+ fn type_implements_fn_trait (
1223
+ & self ,
1224
+ param_env : ty:: ParamEnv < ' tcx > ,
1225
+ ty : ty:: Binder < ' tcx , Ty < ' tcx > > ,
1226
+ constness : ty:: BoundConstness ,
1227
+ polarity : ty:: ImplPolarity ,
1228
+ ) -> Result < ( ty:: ClosureKind , ty:: Binder < ' tcx , Ty < ' tcx > > ) , ( ) > {
1229
+ self . commit_if_ok ( |_| {
1230
+ for trait_def_id in [
1231
+ self . tcx . lang_items ( ) . fn_trait ( ) ,
1232
+ self . tcx . lang_items ( ) . fn_mut_trait ( ) ,
1233
+ self . tcx . lang_items ( ) . fn_once_trait ( ) ,
1234
+ ] {
1235
+ let Some ( trait_def_id) = trait_def_id else { continue } ;
1236
+ // Make a fresh inference variable so we can determine what the substitutions
1237
+ // of the trait are.
1238
+ let var = self . next_ty_var ( TypeVariableOrigin {
1239
+ span : DUMMY_SP ,
1240
+ kind : TypeVariableOriginKind :: MiscVariable ,
1241
+ } ) ;
1242
+ let substs = self . tcx . mk_substs_trait ( ty. skip_binder ( ) , & [ var. into ( ) ] ) ;
1243
+ let obligation = Obligation :: new (
1244
+ ObligationCause :: dummy ( ) ,
1245
+ param_env,
1246
+ ty. rebind ( ty:: TraitPredicate {
1247
+ trait_ref : ty:: TraitRef :: new ( trait_def_id, substs) ,
1248
+ constness,
1249
+ polarity,
1250
+ } )
1251
+ . to_predicate ( self . tcx ) ,
1252
+ ) ;
1253
+ let mut fulfill_cx = FulfillmentContext :: new_in_snapshot ( ) ;
1254
+ fulfill_cx. register_predicate_obligation ( self , obligation) ;
1255
+ if fulfill_cx. select_all_or_error ( self ) . is_empty ( ) {
1256
+ return Ok ( (
1257
+ ty:: ClosureKind :: from_def_id ( self . tcx , trait_def_id)
1258
+ . expect ( "expected to map DefId to ClosureKind" ) ,
1259
+ ty. rebind ( self . resolve_vars_if_possible ( var) ) ,
1260
+ ) ) ;
1261
+ }
1262
+ }
1263
+
1264
+ Err ( ( ) )
1265
+ } )
1266
+ }
1151
1267
}
1152
1268
1153
1269
trait InferCtxtPrivExt < ' hir , ' tcx > {
0 commit comments