diff --git a/compiler/rustc_infer/src/infer/error_reporting/mod.rs b/compiler/rustc_infer/src/infer/error_reporting/mod.rs index abc25d51776f3..fb6adccdf497b 100644 --- a/compiler/rustc_infer/src/infer/error_reporting/mod.rs +++ b/compiler/rustc_infer/src/infer/error_reporting/mod.rs @@ -1086,7 +1086,11 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { /// Compares two given types, eliding parts that are the same between them and highlighting /// relevant differences, and return two representation of those types for highlighted printing. - fn cmp(&self, t1: Ty<'tcx>, t2: Ty<'tcx>) -> (DiagnosticStyledString, DiagnosticStyledString) { + pub fn cmp( + &self, + t1: Ty<'tcx>, + t2: Ty<'tcx>, + ) -> (DiagnosticStyledString, DiagnosticStyledString) { debug!("cmp(t1={}, t1.kind={:?}, t2={}, t2.kind={:?})", t1, t1.kind(), t2, t2.kind()); // helper functions diff --git a/compiler/rustc_middle/src/ty/closure.rs b/compiler/rustc_middle/src/ty/closure.rs index 1446f7dac3638..3bddf7fb6ffc4 100644 --- a/compiler/rustc_middle/src/ty/closure.rs +++ b/compiler/rustc_middle/src/ty/closure.rs @@ -119,9 +119,21 @@ impl<'tcx> ClosureKind { /// See `Ty::to_opt_closure_kind` for more details. pub fn to_ty(self, tcx: TyCtxt<'tcx>) -> Ty<'tcx> { match self { - ty::ClosureKind::Fn => tcx.types.i8, - ty::ClosureKind::FnMut => tcx.types.i16, - ty::ClosureKind::FnOnce => tcx.types.i32, + ClosureKind::Fn => tcx.types.i8, + ClosureKind::FnMut => tcx.types.i16, + ClosureKind::FnOnce => tcx.types.i32, + } + } + + pub fn from_def_id(tcx: TyCtxt<'_>, def_id: DefId) -> Option { + if Some(def_id) == tcx.lang_items().fn_once_trait() { + Some(ClosureKind::FnOnce) + } else if Some(def_id) == tcx.lang_items().fn_mut_trait() { + Some(ClosureKind::FnMut) + } else if Some(def_id) == tcx.lang_items().fn_trait() { + Some(ClosureKind::Fn) + } else { + None } } } diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs index 229e108d5d640..b727cd4a3cc8b 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs @@ -2,10 +2,10 @@ pub mod on_unimplemented; pub mod suggestions; use super::{ - EvaluationResult, FulfillmentError, FulfillmentErrorCode, MismatchedProjectionTypes, - Obligation, ObligationCause, ObligationCauseCode, OnUnimplementedDirective, - OnUnimplementedNote, OutputTypeParameterMismatch, Overflow, PredicateObligation, - SelectionContext, SelectionError, TraitNotObjectSafe, + EvaluationResult, FulfillmentContext, FulfillmentError, FulfillmentErrorCode, + MismatchedProjectionTypes, Obligation, ObligationCause, ObligationCauseCode, + OnUnimplementedDirective, OnUnimplementedNote, OutputTypeParameterMismatch, Overflow, + PredicateObligation, SelectionContext, SelectionError, TraitNotObjectSafe, }; use crate::infer::error_reporting::{TyCategory, TypeAnnotationNeeded as ErrorCode}; @@ -21,6 +21,8 @@ use rustc_hir::intravisit::Visitor; use rustc_hir::GenericParam; use rustc_hir::Item; use rustc_hir::Node; +use rustc_infer::infer::error_reporting::same_type_modulo_infer; +use rustc_infer::traits::TraitEngine; use rustc_middle::thir::abstract_const::NotConstEvaluatable; use rustc_middle::ty::error::ExpectedFound; use rustc_middle::ty::fold::TypeFolder; @@ -103,6 +105,17 @@ pub trait InferCtxtExt<'tcx> { found_args: Vec, is_closure: bool, ) -> DiagnosticBuilder<'tcx, ErrorGuaranteed>; + + /// Checks if the type implements one of `Fn`, `FnMut`, or `FnOnce` + /// in that order, and returns the generic type corresponding to the + /// argument of that trait (corresponding to the closure arguments). + fn type_implements_fn_trait( + &self, + param_env: ty::ParamEnv<'tcx>, + ty: ty::Binder<'tcx, Ty<'tcx>>, + constness: ty::BoundConstness, + polarity: ty::ImplPolarity, + ) -> Result<(ty::ClosureKind, ty::Binder<'tcx, Ty<'tcx>>), ()>; } impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> { @@ -563,7 +576,64 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> { } // Try to report a help message - if !trait_ref.has_infer_types_or_consts() + if is_fn_trait + && let Ok((implemented_kind, params)) = self.type_implements_fn_trait( + obligation.param_env, + trait_ref.self_ty(), + trait_predicate.skip_binder().constness, + trait_predicate.skip_binder().polarity, + ) + { + // If the type implements `Fn`, `FnMut`, or `FnOnce`, suppress the following + // suggestion to add trait bounds for the type, since we only typically implement + // these traits once. + + // Note if the `FnMut` or `FnOnce` is less general than the trait we're trying + // to implement. + let selected_kind = + ty::ClosureKind::from_def_id(self.tcx, trait_ref.def_id()) + .expect("expected to map DefId to ClosureKind"); + if !implemented_kind.extends(selected_kind) { + err.note( + &format!( + "`{}` implements `{}`, but it must implement `{}`, which is more general", + trait_ref.skip_binder().self_ty(), + implemented_kind, + selected_kind + ) + ); + } + + // Note any argument mismatches + let given_ty = params.skip_binder(); + let expected_ty = trait_ref.skip_binder().substs.type_at(1); + if let ty::Tuple(given) = given_ty.kind() + && let ty::Tuple(expected) = expected_ty.kind() + { + if expected.len() != given.len() { + // Note number of types that were expected and given + err.note( + &format!( + "expected a closure taking {} argument{}, but one taking {} argument{} was given", + given.len(), + if given.len() == 1 { "" } else { "s" }, + expected.len(), + if expected.len() == 1 { "" } else { "s" }, + ) + ); + } else if !same_type_modulo_infer(given_ty, expected_ty) { + // Print type mismatch + let (expected_args, given_args) = + self.cmp(given_ty, expected_ty); + err.note_expected_found( + &"a closure with arguments", + expected_args, + &"a closure with arguments", + given_args, + ); + } + } + } else if !trait_ref.has_infer_types_or_consts() && self.predicate_can_apply(obligation.param_env, trait_ref) { // If a where-clause may be useful, remind the @@ -1144,6 +1214,52 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> { err } + + fn type_implements_fn_trait( + &self, + param_env: ty::ParamEnv<'tcx>, + ty: ty::Binder<'tcx, Ty<'tcx>>, + constness: ty::BoundConstness, + polarity: ty::ImplPolarity, + ) -> Result<(ty::ClosureKind, ty::Binder<'tcx, Ty<'tcx>>), ()> { + self.commit_if_ok(|_| { + for trait_def_id in [ + self.tcx.lang_items().fn_trait(), + self.tcx.lang_items().fn_mut_trait(), + self.tcx.lang_items().fn_once_trait(), + ] { + let Some(trait_def_id) = trait_def_id else { continue }; + // Make a fresh inference variable so we can determine what the substitutions + // of the trait are. + let var = self.next_ty_var(TypeVariableOrigin { + span: DUMMY_SP, + kind: TypeVariableOriginKind::MiscVariable, + }); + let substs = self.tcx.mk_substs_trait(ty.skip_binder(), &[var.into()]); + let obligation = Obligation::new( + ObligationCause::dummy(), + param_env, + ty.rebind(ty::TraitPredicate { + trait_ref: ty::TraitRef::new(trait_def_id, substs), + constness, + polarity, + }) + .to_predicate(self.tcx), + ); + let mut fulfill_cx = FulfillmentContext::new_in_snapshot(); + fulfill_cx.register_predicate_obligation(self, obligation); + if fulfill_cx.select_all_or_error(self).is_empty() { + return Ok(( + ty::ClosureKind::from_def_id(self.tcx, trait_def_id) + .expect("expected to map DefId to ClosureKind"), + ty.rebind(self.resolve_vars_if_possible(var)), + )); + } + } + + Err(()) + }) + } } trait InferCtxtPrivExt<'hir, 'tcx> { diff --git a/src/test/ui/higher-rank-trait-bounds/normalize-under-binder/issue-62529-3.stderr b/src/test/ui/higher-rank-trait-bounds/normalize-under-binder/issue-62529-3.stderr index b110734642177..066bf431a83d7 100644 --- a/src/test/ui/higher-rank-trait-bounds/normalize-under-binder/issue-62529-3.stderr +++ b/src/test/ui/higher-rank-trait-bounds/normalize-under-binder/issue-62529-3.stderr @@ -4,6 +4,8 @@ error[E0277]: expected a `Fn<(<_ as ATC<'a>>::Type,)>` closure, found `F` LL | call(f, ()); | ^^^^ expected an `Fn<(<_ as ATC<'a>>::Type,)>` closure, found `F` | + = note: expected a closure with arguments `((),)` + found a closure with arguments `(<_ as ATC<'a>>::Type,)` note: required by a bound in `call` --> $DIR/issue-62529-3.rs:9:36 | diff --git a/src/test/ui/issues/issue-59494.stderr b/src/test/ui/issues/issue-59494.stderr index a9284535e4dc4..8b542bb69de2e 100644 --- a/src/test/ui/issues/issue-59494.stderr +++ b/src/test/ui/issues/issue-59494.stderr @@ -7,6 +7,8 @@ LL | let t8 = t8n(t7, t7p(f, g)); | required by a bound introduced by this call | = help: the trait `Fn<(_,)>` is not implemented for `impl Fn(((_, _), _))` + = note: expected a closure with arguments `(((_, _), _),)` + found a closure with arguments `(_,)` note: required by a bound in `t8n` --> $DIR/issue-59494.rs:5:45 | diff --git a/src/test/ui/trait-bounds/mismatch-fn-trait.rs b/src/test/ui/trait-bounds/mismatch-fn-trait.rs new file mode 100644 index 0000000000000..0ed64043a9a56 --- /dev/null +++ b/src/test/ui/trait-bounds/mismatch-fn-trait.rs @@ -0,0 +1,28 @@ +fn take(_f: impl FnMut(i32)) {} + +fn test1(f: impl FnMut(u32)) { + take(f) + //~^ ERROR [E0277] +} + +fn test2(f: impl FnMut(i32, i32)) { + take(f) + //~^ ERROR [E0277] +} + +fn test3(f: impl FnMut()) { + take(f) + //~^ ERROR [E0277] +} + +fn test4(f: impl FnOnce(i32)) { + take(f) + //~^ ERROR [E0277] +} + +fn test5(f: impl FnOnce(u32)) { + take(f) + //~^ ERROR [E0277] +} + +fn main() {} diff --git a/src/test/ui/trait-bounds/mismatch-fn-trait.stderr b/src/test/ui/trait-bounds/mismatch-fn-trait.stderr new file mode 100644 index 0000000000000..961e6d88fbef4 --- /dev/null +++ b/src/test/ui/trait-bounds/mismatch-fn-trait.stderr @@ -0,0 +1,81 @@ +error[E0277]: expected a `FnMut<(i32,)>` closure, found `impl FnMut(u32)` + --> $DIR/mismatch-fn-trait.rs:4:10 + | +LL | take(f) + | ---- ^ expected an `FnMut<(i32,)>` closure, found `impl FnMut(u32)` + | | + | required by a bound introduced by this call + | + = note: expected a closure with arguments `(u32,)` + found a closure with arguments `(i32,)` +note: required by a bound in `take` + --> $DIR/mismatch-fn-trait.rs:1:18 + | +LL | fn take(_f: impl FnMut(i32)) {} + | ^^^^^^^^^^ required by this bound in `take` + +error[E0277]: expected a `FnMut<(i32,)>` closure, found `impl FnMut(i32, i32)` + --> $DIR/mismatch-fn-trait.rs:9:10 + | +LL | take(f) + | ---- ^ expected an `FnMut<(i32,)>` closure, found `impl FnMut(i32, i32)` + | | + | required by a bound introduced by this call + | + = note: expected a closure taking 2 arguments, but one taking 1 argument was given +note: required by a bound in `take` + --> $DIR/mismatch-fn-trait.rs:1:18 + | +LL | fn take(_f: impl FnMut(i32)) {} + | ^^^^^^^^^^ required by this bound in `take` + +error[E0277]: expected a `FnMut<(i32,)>` closure, found `impl FnMut()` + --> $DIR/mismatch-fn-trait.rs:14:10 + | +LL | take(f) + | ---- ^ expected an `FnMut<(i32,)>` closure, found `impl FnMut()` + | | + | required by a bound introduced by this call + | + = note: expected a closure taking 0 arguments, but one taking 1 argument was given +note: required by a bound in `take` + --> $DIR/mismatch-fn-trait.rs:1:18 + | +LL | fn take(_f: impl FnMut(i32)) {} + | ^^^^^^^^^^ required by this bound in `take` + +error[E0277]: expected a `FnMut<(i32,)>` closure, found `impl FnOnce(i32)` + --> $DIR/mismatch-fn-trait.rs:19:10 + | +LL | take(f) + | ---- ^ expected an `FnMut<(i32,)>` closure, found `impl FnOnce(i32)` + | | + | required by a bound introduced by this call + | + = note: `impl FnOnce(i32)` implements `FnOnce`, but it must implement `FnMut`, which is more general +note: required by a bound in `take` + --> $DIR/mismatch-fn-trait.rs:1:18 + | +LL | fn take(_f: impl FnMut(i32)) {} + | ^^^^^^^^^^ required by this bound in `take` + +error[E0277]: expected a `FnMut<(i32,)>` closure, found `impl FnOnce(u32)` + --> $DIR/mismatch-fn-trait.rs:24:10 + | +LL | take(f) + | ---- ^ expected an `FnMut<(i32,)>` closure, found `impl FnOnce(u32)` + | | + | required by a bound introduced by this call + | + = note: `impl FnOnce(u32)` implements `FnOnce`, but it must implement `FnMut`, which is more general + = note: expected a closure with arguments `(u32,)` + found a closure with arguments `(i32,)` +note: required by a bound in `take` + --> $DIR/mismatch-fn-trait.rs:1:18 + | +LL | fn take(_f: impl FnMut(i32)) {} + | ^^^^^^^^^^ required by this bound in `take` + +error: aborting due to 5 previous errors + +For more information about this error, try `rustc --explain E0277`. diff --git a/src/test/ui/unboxed-closures/unboxed-closures-fnmut-as-fn.stderr b/src/test/ui/unboxed-closures/unboxed-closures-fnmut-as-fn.stderr index f379d73eecff7..0ea1c1dcd5bde 100644 --- a/src/test/ui/unboxed-closures/unboxed-closures-fnmut-as-fn.stderr +++ b/src/test/ui/unboxed-closures/unboxed-closures-fnmut-as-fn.stderr @@ -7,6 +7,7 @@ LL | let x = call_it(&S, 22); | required by a bound introduced by this call | = help: the trait `Fn<(isize,)>` is not implemented for `S` + = note: `S` implements `FnMut`, but it must implement `Fn`, which is more general note: required by a bound in `call_it` --> $DIR/unboxed-closures-fnmut-as-fn.rs:22:14 |