From 575f129faa5126869f11ef945276072b097a2b2a Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Sat, 22 Mar 2025 22:09:16 +0000 Subject: [PATCH 1/3] Obligation::as_goal --- .../src/fn_ctxt/inspect_obligations.rs | 4 ++-- compiler/rustc_infer/src/infer/opaque_types/mod.rs | 3 +-- compiler/rustc_infer/src/traits/mod.rs | 12 ++++++------ compiler/rustc_trait_selection/src/solve/delegate.rs | 2 +- compiler/rustc_trait_selection/src/solve/fulfill.rs | 4 ++-- .../src/solve/fulfill/derive_errors.rs | 8 ++++---- .../rustc_trait_selection/src/traits/coherence.rs | 2 +- 7 files changed, 17 insertions(+), 18 deletions(-) diff --git a/compiler/rustc_hir_typeck/src/fn_ctxt/inspect_obligations.rs b/compiler/rustc_hir_typeck/src/fn_ctxt/inspect_obligations.rs index 95b9cb3be627c..e068e60790277 100644 --- a/compiler/rustc_hir_typeck/src/fn_ctxt/inspect_obligations.rs +++ b/compiler/rustc_hir_typeck/src/fn_ctxt/inspect_obligations.rs @@ -1,7 +1,7 @@ //! A utility module to inspect currently ambiguous obligations in the current context. use rustc_infer::traits::{self, ObligationCause, PredicateObligations}; -use rustc_middle::traits::solve::{Goal, GoalSource}; +use rustc_middle::traits::solve::GoalSource; use rustc_middle::ty::{self, Ty, TypeVisitableExt}; use rustc_span::Span; use rustc_trait_selection::solve::inspect::{ @@ -85,7 +85,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> { root_cause: &obligation.cause, }; - let goal = Goal::new(self.tcx, obligation.param_env, obligation.predicate); + let goal = obligation.as_goal(); self.visit_proof_tree(goal, &mut visitor); } diff --git a/compiler/rustc_infer/src/infer/opaque_types/mod.rs b/compiler/rustc_infer/src/infer/opaque_types/mod.rs index 3fa1923121a23..215b133372664 100644 --- a/compiler/rustc_infer/src/infer/opaque_types/mod.rs +++ b/compiler/rustc_infer/src/infer/opaque_types/mod.rs @@ -246,8 +246,7 @@ impl<'tcx> InferCtxt<'tcx> { .eq(DefineOpaqueTypes::Yes, prev, hidden_ty)? .obligations .into_iter() - // FIXME: Shuttling between obligations and goals is awkward. - .map(Goal::from), + .map(|obligation| obligation.as_goal()), ); } } diff --git a/compiler/rustc_infer/src/traits/mod.rs b/compiler/rustc_infer/src/traits/mod.rs index ac641ef565228..b537750f1b51b 100644 --- a/compiler/rustc_infer/src/traits/mod.rs +++ b/compiler/rustc_infer/src/traits/mod.rs @@ -54,6 +54,12 @@ pub struct Obligation<'tcx, T> { pub recursion_depth: usize, } +impl<'tcx, T: Copy> Obligation<'tcx, T> { + pub fn as_goal(&self) -> solve::Goal<'tcx, T> { + solve::Goal { param_env: self.param_env, predicate: self.predicate } + } +} + impl<'tcx, T: PartialEq> PartialEq> for Obligation<'tcx, T> { #[inline] fn eq(&self, other: &Obligation<'tcx, T>) -> bool { @@ -75,12 +81,6 @@ impl Hash for Obligation<'_, T> { } } -impl<'tcx, P> From> for solve::Goal<'tcx, P> { - fn from(value: Obligation<'tcx, P>) -> Self { - solve::Goal { param_env: value.param_env, predicate: value.predicate } - } -} - pub type PredicateObligation<'tcx> = Obligation<'tcx, ty::Predicate<'tcx>>; pub type TraitObligation<'tcx> = Obligation<'tcx, ty::TraitPredicate<'tcx>>; pub type PolyTraitObligation<'tcx> = Obligation<'tcx, ty::PolyTraitPredicate<'tcx>>; diff --git a/compiler/rustc_trait_selection/src/solve/delegate.rs b/compiler/rustc_trait_selection/src/solve/delegate.rs index af5a60027ba4a..3d9a90eb74e7a 100644 --- a/compiler/rustc_trait_selection/src/solve/delegate.rs +++ b/compiler/rustc_trait_selection/src/solve/delegate.rs @@ -96,7 +96,7 @@ impl<'tcx> rustc_next_trait_solver::delegate::SolverDelegate for SolverDelegate< ) -> Option>>> { crate::traits::wf::unnormalized_obligations(&self.0, param_env, arg, DUMMY_SP, CRATE_DEF_ID) .map(|obligations| { - obligations.into_iter().map(|obligation| obligation.into()).collect() + obligations.into_iter().map(|obligation| obligation.as_goal()).collect() }) } diff --git a/compiler/rustc_trait_selection/src/solve/fulfill.rs b/compiler/rustc_trait_selection/src/solve/fulfill.rs index 704ba6e501d8c..192e632a2d5b9 100644 --- a/compiler/rustc_trait_selection/src/solve/fulfill.rs +++ b/compiler/rustc_trait_selection/src/solve/fulfill.rs @@ -80,7 +80,7 @@ impl<'tcx> ObligationStorage<'tcx> { // change. // FIXME: is merged, this can be removed. self.overflowed.extend(ExtractIf::new(&mut self.pending, |o| { - let goal = o.clone().into(); + let goal = o.as_goal(); let result = <&SolverDelegate<'tcx>>::from(infcx) .evaluate_root_goal(goal, GenerateProofTree::No, o.cause.span) .0; @@ -161,7 +161,7 @@ where let mut has_changed = false; for obligation in self.obligations.unstalled_for_select() { - let goal = obligation.clone().into(); + let goal = obligation.as_goal(); let result = <&SolverDelegate<'tcx>>::from(infcx) .evaluate_root_goal(goal, GenerateProofTree::No, obligation.cause.span) .0; diff --git a/compiler/rustc_trait_selection/src/solve/fulfill/derive_errors.rs b/compiler/rustc_trait_selection/src/solve/fulfill/derive_errors.rs index 352ac7c1a4e6f..3a939df25e07b 100644 --- a/compiler/rustc_trait_selection/src/solve/fulfill/derive_errors.rs +++ b/compiler/rustc_trait_selection/src/solve/fulfill/derive_errors.rs @@ -10,7 +10,7 @@ use rustc_middle::ty::error::{ExpectedFound, TypeError}; use rustc_middle::ty::{self, Ty, TyCtxt}; use rustc_middle::{bug, span_bug}; use rustc_next_trait_solver::solve::{GenerateProofTree, SolverDelegateEvalExt as _}; -use rustc_type_ir::solve::{Goal, NoSolution}; +use rustc_type_ir::solve::NoSolution; use tracing::{instrument, trace}; use crate::solve::Certainty; @@ -89,7 +89,7 @@ pub(super) fn fulfillment_error_for_stalled<'tcx>( let (code, refine_obligation) = infcx.probe(|_| { match <&SolverDelegate<'tcx>>::from(infcx) .evaluate_root_goal( - root_obligation.clone().into(), + root_obligation.as_goal(), GenerateProofTree::No, root_obligation.cause.span, ) @@ -155,7 +155,7 @@ fn find_best_leaf_obligation<'tcx>( .fudge_inference_if_ok(|| { infcx .visit_proof_tree( - obligation.clone().into(), + obligation.as_goal(), &mut BestObligation { obligation: obligation.clone(), consider_ambiguities }, ) .break_value() @@ -245,7 +245,7 @@ impl<'tcx> BestObligation<'tcx> { { let nested_goal = candidate.instantiate_proof_tree_for_nested_goal( GoalSource::Misc, - Goal::new(infcx.tcx, obligation.param_env, obligation.predicate), + obligation.as_goal(), self.span(), ); // Skip nested goals that aren't the *reason* for our goal's failure. diff --git a/compiler/rustc_trait_selection/src/traits/coherence.rs b/compiler/rustc_trait_selection/src/traits/coherence.rs index 4c7172c32781a..bcc247ba53c2b 100644 --- a/compiler/rustc_trait_selection/src/traits/coherence.rs +++ b/compiler/rustc_trait_selection/src/traits/coherence.rs @@ -625,7 +625,7 @@ fn compute_intercrate_ambiguity_causes<'tcx>( let mut causes: FxIndexSet> = Default::default(); for obligation in obligations { - search_ambiguity_causes(infcx, obligation.clone().into(), &mut causes); + search_ambiguity_causes(infcx, obligation.as_goal(), &mut causes); } causes From d588bc2a2b49cb9f18fe051774a3f4a41ab5068e Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Sat, 22 Mar 2025 21:01:11 +0000 Subject: [PATCH 2/3] Don't super fold const in Resolver --- compiler/rustc_hir_typeck/src/writeback.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/compiler/rustc_hir_typeck/src/writeback.rs b/compiler/rustc_hir_typeck/src/writeback.rs index 6a3417ae5d6fb..748cb11290d93 100644 --- a/compiler/rustc_hir_typeck/src/writeback.rs +++ b/compiler/rustc_hir_typeck/src/writeback.rs @@ -864,10 +864,7 @@ impl<'cx, 'tcx> TypeFolder> for Resolver<'cx, 'tcx> { } fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> { - self.handle_term(ct, ty::Const::outer_exclusive_binder, |tcx, guar| { - ty::Const::new_error(tcx, guar) - }) - .super_fold_with(self) + self.handle_term(ct, ty::Const::outer_exclusive_binder, ty::Const::new_error) } fn fold_predicate(&mut self, predicate: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> { From fad34c603cd6c7e432f297e682bb68a2cf55df0b Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Sun, 23 Mar 2025 18:34:33 +0000 Subject: [PATCH 3/3] Explicitly don't fold coroutine obligations in writeback --- compiler/rustc_hir_typeck/src/writeback.rs | 45 +++++++++++++++------- 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/compiler/rustc_hir_typeck/src/writeback.rs b/compiler/rustc_hir_typeck/src/writeback.rs index 748cb11290d93..b63c0b6ab7e09 100644 --- a/compiler/rustc_hir_typeck/src/writeback.rs +++ b/compiler/rustc_hir_typeck/src/writeback.rs @@ -548,7 +548,8 @@ impl<'cx, 'tcx> WritebackCx<'cx, 'tcx> { let fcx_typeck_results = self.fcx.typeck_results.borrow(); assert_eq!(fcx_typeck_results.hir_owner, self.typeck_results.hir_owner); for (predicate, cause) in &fcx_typeck_results.coroutine_stalled_predicates { - let (predicate, cause) = self.resolve((*predicate, cause.clone()), &cause.span); + let (predicate, cause) = + self.resolve_coroutine_predicate((*predicate, cause.clone()), &cause.span); self.typeck_results.coroutine_stalled_predicates.insert((predicate, cause)); } } @@ -730,7 +731,25 @@ impl<'cx, 'tcx> WritebackCx<'cx, 'tcx> { T: TypeFoldable>, { let value = self.fcx.resolve_vars_if_possible(value); - let value = value.fold_with(&mut Resolver::new(self.fcx, span, self.body)); + let value = value.fold_with(&mut Resolver::new(self.fcx, span, self.body, true)); + assert!(!value.has_infer()); + + // We may have introduced e.g. `ty::Error`, if inference failed, make sure + // to mark the `TypeckResults` as tainted in that case, so that downstream + // users of the typeck results don't produce extra errors, or worse, ICEs. + if let Err(guar) = value.error_reported() { + self.typeck_results.tainted_by_errors = Some(guar); + } + + value + } + + fn resolve_coroutine_predicate(&mut self, value: T, span: &dyn Locatable) -> T + where + T: TypeFoldable>, + { + let value = self.fcx.resolve_vars_if_possible(value); + let value = value.fold_with(&mut Resolver::new(self.fcx, span, self.body, false)); assert!(!value.has_infer()); // We may have introduced e.g. `ty::Error`, if inference failed, make sure @@ -774,8 +793,9 @@ impl<'cx, 'tcx> Resolver<'cx, 'tcx> { fcx: &'cx FnCtxt<'cx, 'tcx>, span: &'cx dyn Locatable, body: &'tcx hir::Body<'tcx>, + should_normalize: bool, ) -> Resolver<'cx, 'tcx> { - Resolver { fcx, span, body, should_normalize: fcx.next_trait_solver() } + Resolver { fcx, span, body, should_normalize } } fn report_error(&self, p: impl Into>) -> ErrorGuaranteed { @@ -805,10 +825,9 @@ impl<'cx, 'tcx> Resolver<'cx, 'tcx> { T: Into> + TypeSuperFoldable> + Copy, { let tcx = self.fcx.tcx; - // We must deeply normalize in the new solver, since later lints - // expect that types that show up in the typeck are fully - // normalized. - let mut value = if self.should_normalize { + // We must deeply normalize in the new solver, since later lints expect + // that types that show up in the typeck are fully normalized. + let mut value = if self.should_normalize && self.fcx.next_trait_solver() { let body_id = tcx.hir_body_owner_def_id(self.body.id()); let cause = ObligationCause::misc(self.span.to_span(tcx), body_id); let at = self.fcx.at(&cause, self.fcx.param_env); @@ -868,13 +887,11 @@ impl<'cx, 'tcx> TypeFolder> for Resolver<'cx, 'tcx> { } fn fold_predicate(&mut self, predicate: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> { - // Do not normalize predicates in the new solver. The new solver is - // supposed to handle unnormalized predicates and incorrectly normalizing - // them can be unsound, e.g. for `WellFormed` predicates. - let prev = mem::replace(&mut self.should_normalize, false); - let predicate = predicate.super_fold_with(self); - self.should_normalize = prev; - predicate + assert!( + !self.should_normalize, + "normalizing predicates in writeback is not generally sound" + ); + predicate.super_fold_with(self) } }