diff --git a/compiler/rustc_infer/src/infer/error_reporting/mod.rs b/compiler/rustc_infer/src/infer/error_reporting/mod.rs index d7505717bf3d2..6c57e7f4347f2 100644 --- a/compiler/rustc_infer/src/infer/error_reporting/mod.rs +++ b/compiler/rustc_infer/src/infer/error_reporting/mod.rs @@ -316,37 +316,6 @@ pub fn unexpected_hidden_region_diagnostic<'tcx>( err } -/// Structurally compares two types, modulo any inference variables. -/// -/// Returns `true` if two types are equal, or if one type is an inference variable compatible -/// with the other type. A TyVar inference type is compatible with any type, and an IntVar or -/// FloatVar inference type are compatible with themselves or their concrete types (Int and -/// Float types, respectively). When comparing two ADTs, these rules apply recursively. -pub fn same_type_modulo_infer<'tcx>(a: Ty<'tcx>, b: Ty<'tcx>) -> bool { - match (&a.kind(), &b.kind()) { - (&ty::Adt(did_a, substs_a), &ty::Adt(did_b, substs_b)) => { - if did_a != did_b { - return false; - } - - substs_a.types().zip(substs_b.types()).all(|(a, b)| same_type_modulo_infer(a, b)) - } - (&ty::Int(_), &ty::Infer(ty::InferTy::IntVar(_))) - | (&ty::Infer(ty::InferTy::IntVar(_)), &ty::Int(_) | &ty::Infer(ty::InferTy::IntVar(_))) - | (&ty::Float(_), &ty::Infer(ty::InferTy::FloatVar(_))) - | ( - &ty::Infer(ty::InferTy::FloatVar(_)), - &ty::Float(_) | &ty::Infer(ty::InferTy::FloatVar(_)), - ) - | (&ty::Infer(ty::InferTy::TyVar(_)), _) - | (_, &ty::Infer(ty::InferTy::TyVar(_))) => true, - (&ty::Ref(_, ty_a, mut_a), &ty::Ref(_, ty_b, mut_b)) => { - mut_a == mut_b && same_type_modulo_infer(*ty_a, *ty_b) - } - _ => a == b, - } -} - impl<'a, 'tcx> InferCtxt<'a, 'tcx> { pub fn report_region_errors( &self, @@ -1723,15 +1692,14 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { }; debug!("exp_found {:?} terr {:?} cause.code {:?}", exp_found, terr, cause.code()); if let Some(exp_found) = exp_found { - let should_suggest_fixes = if let ObligationCauseCode::Pattern { root_ty, .. } = - cause.code() - { - // Skip if the root_ty of the pattern is not the same as the expected_ty. - // If these types aren't equal then we've probably peeled off a layer of arrays. - same_type_modulo_infer(self.resolve_vars_if_possible(*root_ty), exp_found.expected) - } else { - true - }; + let should_suggest_fixes = + if let ObligationCauseCode::Pattern { root_ty, .. } = cause.code() { + // Skip if the root_ty of the pattern is not the same as the expected_ty. + // If these types aren't equal then we've probably peeled off a layer of arrays. + self.same_type_modulo_infer(*root_ty, exp_found.expected) + } else { + true + }; if should_suggest_fixes { self.suggest_tuple_pattern(cause, &exp_found, diag); @@ -1786,7 +1754,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { .filter_map(|variant| { let sole_field = &variant.fields[0]; let sole_field_ty = sole_field.ty(self.tcx, substs); - if same_type_modulo_infer(sole_field_ty, exp_found.found) { + if self.same_type_modulo_infer(sole_field_ty, exp_found.found) { let variant_path = with_no_trimmed_paths!(self.tcx.def_path_str(variant.def_id)); // FIXME #56861: DRYer prelude filtering @@ -1902,39 +1870,41 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { self.get_impl_future_output_ty(exp_found.expected).map(Binder::skip_binder), self.get_impl_future_output_ty(exp_found.found).map(Binder::skip_binder), ) { - (Some(exp), Some(found)) if same_type_modulo_infer(exp, found) => match cause.code() { - ObligationCauseCode::IfExpression(box IfExpressionCause { then, .. }) => { - diag.multipart_suggestion( - "consider `await`ing on both `Future`s", - vec![ - (then.shrink_to_hi(), ".await".to_string()), - (exp_span.shrink_to_hi(), ".await".to_string()), - ], - Applicability::MaybeIncorrect, - ); - } - ObligationCauseCode::MatchExpressionArm(box MatchExpressionArmCause { - prior_arms, - .. - }) => { - if let [.., arm_span] = &prior_arms[..] { + (Some(exp), Some(found)) if self.same_type_modulo_infer(exp, found) => { + match cause.code() { + ObligationCauseCode::IfExpression(box IfExpressionCause { then, .. }) => { diag.multipart_suggestion( "consider `await`ing on both `Future`s", vec![ - (arm_span.shrink_to_hi(), ".await".to_string()), + (then.shrink_to_hi(), ".await".to_string()), (exp_span.shrink_to_hi(), ".await".to_string()), ], Applicability::MaybeIncorrect, ); - } else { + } + ObligationCauseCode::MatchExpressionArm(box MatchExpressionArmCause { + prior_arms, + .. + }) => { + if let [.., arm_span] = &prior_arms[..] { + diag.multipart_suggestion( + "consider `await`ing on both `Future`s", + vec![ + (arm_span.shrink_to_hi(), ".await".to_string()), + (exp_span.shrink_to_hi(), ".await".to_string()), + ], + Applicability::MaybeIncorrect, + ); + } else { + diag.help("consider `await`ing on both `Future`s"); + } + } + _ => { diag.help("consider `await`ing on both `Future`s"); } } - _ => { - diag.help("consider `await`ing on both `Future`s"); - } - }, - (_, Some(ty)) if same_type_modulo_infer(exp_found.expected, ty) => { + } + (_, Some(ty)) if self.same_type_modulo_infer(exp_found.expected, ty) => { diag.span_suggestion_verbose( exp_span.shrink_to_hi(), "consider `await`ing on the `Future`", @@ -1942,7 +1912,8 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { Applicability::MaybeIncorrect, ); } - (Some(ty), _) if same_type_modulo_infer(ty, exp_found.found) => match cause.code() { + (Some(ty), _) if self.same_type_modulo_infer(ty, exp_found.found) => match cause.code() + { ObligationCauseCode::Pattern { span: Some(span), .. } | ObligationCauseCode::IfExpression(box IfExpressionCause { then: span, .. }) => { diag.span_suggestion_verbose( @@ -1992,7 +1963,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { .iter() .filter(|field| field.vis.is_accessible_from(field.did, self.tcx)) .map(|field| (field.name, field.ty(self.tcx, expected_substs))) - .find(|(_, ty)| same_type_modulo_infer(*ty, exp_found.found)) + .find(|(_, ty)| self.same_type_modulo_infer(*ty, exp_found.found)) { if let ObligationCauseCode::Pattern { span: Some(span), .. } = *cause.code() { if let Ok(snippet) = self.tcx.sess.source_map().span_to_snippet(span) { @@ -2057,7 +2028,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { | (_, ty::Infer(_)) | (ty::Param(_), _) | (ty::Infer(_), _) => {} - _ if same_type_modulo_infer(exp_ty, found_ty) => {} + _ if self.same_type_modulo_infer(exp_ty, found_ty) => {} _ => show_suggestion = false, }; } @@ -2179,7 +2150,7 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { ) { let [expected_tup_elem] = expected_fields[..] else { return }; - if !same_type_modulo_infer(expected_tup_elem, found) { + if !self.same_type_modulo_infer(expected_tup_elem, found) { return; } @@ -2647,6 +2618,45 @@ impl<'a, 'tcx> InferCtxt<'a, 'tcx> { span.is_desugaring(DesugaringKind::QuestionMark) && self.tcx.is_diagnostic_item(sym::From, trait_def_id) } + + /// Structurally compares two types, modulo any inference variables. + /// + /// Returns `true` if two types are equal, or if one type is an inference variable compatible + /// with the other type. A TyVar inference type is compatible with any type, and an IntVar or + /// FloatVar inference type are compatible with themselves or their concrete types (Int and + /// Float types, respectively). When comparing two ADTs, these rules apply recursively. + pub fn same_type_modulo_infer(&self, a: Ty<'tcx>, b: Ty<'tcx>) -> bool { + let (a, b) = self.resolve_vars_if_possible((a, b)); + match (&a.kind(), &b.kind()) { + (&ty::Adt(did_a, substs_a), &ty::Adt(did_b, substs_b)) => { + if did_a != did_b { + return false; + } + + substs_a + .types() + .zip(substs_b.types()) + .all(|(a, b)| self.same_type_modulo_infer(a, b)) + } + (&ty::Int(_) | &ty::Uint(_), &ty::Infer(ty::InferTy::IntVar(_))) + | ( + &ty::Infer(ty::InferTy::IntVar(_)), + &ty::Int(_) | &ty::Uint(_) | &ty::Infer(ty::InferTy::IntVar(_)), + ) + | (&ty::Float(_), &ty::Infer(ty::InferTy::FloatVar(_))) + | ( + &ty::Infer(ty::InferTy::FloatVar(_)), + &ty::Float(_) | &ty::Infer(ty::InferTy::FloatVar(_)), + ) + | (&ty::Infer(ty::InferTy::TyVar(_)), _) + | (_, &ty::Infer(ty::InferTy::TyVar(_))) => true, + (&ty::Ref(_, ty_a, mut_a), &ty::Ref(_, ty_b, mut_b)) => { + mut_a == mut_b && self.same_type_modulo_infer(*ty_a, *ty_b) + } + // FIXME(compiler-errors): This needs to be generalized more + _ => a == b, + } + } } impl<'a, 'tcx> InferCtxt<'a, 'tcx> { diff --git a/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs b/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs index 29df771b95780..1fbc904eb48e8 100644 --- a/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs @@ -22,7 +22,6 @@ use rustc_hir::intravisit::Visitor; use rustc_hir::GenericParam; use rustc_hir::Item; use rustc_hir::Node; -use rustc_infer::infer::error_reporting::same_type_modulo_infer; use rustc_infer::traits::TraitEngine; use rustc_middle::traits::select::OverflowError; use rustc_middle::ty::abstract_const::NotConstEvaluatable; @@ -640,7 +639,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> { if expected.len() == 1 { "" } else { "s" }, ) ); - } else if !same_type_modulo_infer(given_ty, expected_ty) { + } else if !self.same_type_modulo_infer(given_ty, expected_ty) { // Print type mismatch let (expected_args, given_args) = self.cmp(given_ty, expected_ty);