Skip to content

Tweaks to writeback and Obligation -> Goal conversion #138846

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions compiler/rustc_hir_typeck/src/fn_ctxt/inspect_obligations.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! A utility module to inspect currently ambiguous obligations in the current context.

use rustc_infer::traits::{self, ObligationCause, PredicateObligations};
use rustc_middle::traits::solve::{Goal, GoalSource};
use rustc_middle::traits::solve::GoalSource;
use rustc_middle::ty::{self, Ty, TypeVisitableExt};
use rustc_span::Span;
use rustc_trait_selection::solve::inspect::{
Expand Down Expand Up @@ -85,7 +85,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
root_cause: &obligation.cause,
};

let goal = Goal::new(self.tcx, obligation.param_env, obligation.predicate);
let goal = obligation.as_goal();
self.visit_proof_tree(goal, &mut visitor);
}

Expand Down
50 changes: 32 additions & 18 deletions compiler/rustc_hir_typeck/src/writeback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,8 @@ impl<'cx, 'tcx> WritebackCx<'cx, 'tcx> {
let fcx_typeck_results = self.fcx.typeck_results.borrow();
assert_eq!(fcx_typeck_results.hir_owner, self.typeck_results.hir_owner);
for (predicate, cause) in &fcx_typeck_results.coroutine_stalled_predicates {
let (predicate, cause) = self.resolve((*predicate, cause.clone()), &cause.span);
let (predicate, cause) =
self.resolve_coroutine_predicate((*predicate, cause.clone()), &cause.span);
self.typeck_results.coroutine_stalled_predicates.insert((predicate, cause));
}
}
Expand Down Expand Up @@ -730,7 +731,25 @@ impl<'cx, 'tcx> WritebackCx<'cx, 'tcx> {
T: TypeFoldable<TyCtxt<'tcx>>,
{
let value = self.fcx.resolve_vars_if_possible(value);
let value = value.fold_with(&mut Resolver::new(self.fcx, span, self.body));
let value = value.fold_with(&mut Resolver::new(self.fcx, span, self.body, true));
assert!(!value.has_infer());

// We may have introduced e.g. `ty::Error`, if inference failed, make sure
// to mark the `TypeckResults` as tainted in that case, so that downstream
// users of the typeck results don't produce extra errors, or worse, ICEs.
if let Err(guar) = value.error_reported() {
self.typeck_results.tainted_by_errors = Some(guar);
}

value
}

fn resolve_coroutine_predicate<T>(&mut self, value: T, span: &dyn Locatable) -> T
where
T: TypeFoldable<TyCtxt<'tcx>>,
{
let value = self.fcx.resolve_vars_if_possible(value);
let value = value.fold_with(&mut Resolver::new(self.fcx, span, self.body, false));
assert!(!value.has_infer());

// We may have introduced e.g. `ty::Error`, if inference failed, make sure
Expand Down Expand Up @@ -774,8 +793,9 @@ impl<'cx, 'tcx> Resolver<'cx, 'tcx> {
fcx: &'cx FnCtxt<'cx, 'tcx>,
span: &'cx dyn Locatable,
body: &'tcx hir::Body<'tcx>,
should_normalize: bool,
) -> Resolver<'cx, 'tcx> {
Resolver { fcx, span, body, should_normalize: fcx.next_trait_solver() }
Resolver { fcx, span, body, should_normalize }
}

fn report_error(&self, p: impl Into<ty::GenericArg<'tcx>>) -> ErrorGuaranteed {
Expand Down Expand Up @@ -805,10 +825,9 @@ impl<'cx, 'tcx> Resolver<'cx, 'tcx> {
T: Into<ty::GenericArg<'tcx>> + TypeSuperFoldable<TyCtxt<'tcx>> + Copy,
{
let tcx = self.fcx.tcx;
// We must deeply normalize in the new solver, since later lints
// expect that types that show up in the typeck are fully
// normalized.
let mut value = if self.should_normalize {
// We must deeply normalize in the new solver, since later lints expect
// that types that show up in the typeck are fully normalized.
let mut value = if self.should_normalize && self.fcx.next_trait_solver() {
let body_id = tcx.hir_body_owner_def_id(self.body.id());
let cause = ObligationCause::misc(self.span.to_span(tcx), body_id);
let at = self.fcx.at(&cause, self.fcx.param_env);
Expand Down Expand Up @@ -864,20 +883,15 @@ impl<'cx, 'tcx> TypeFolder<TyCtxt<'tcx>> for Resolver<'cx, 'tcx> {
}

fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
self.handle_term(ct, ty::Const::outer_exclusive_binder, |tcx, guar| {
ty::Const::new_error(tcx, guar)
})
.super_fold_with(self)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to super fold b/c it should be fully resolved :)

self.handle_term(ct, ty::Const::outer_exclusive_binder, ty::Const::new_error)
}

fn fold_predicate(&mut self, predicate: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
// Do not normalize predicates in the new solver. The new solver is
// supposed to handle unnormalized predicates and incorrectly normalizing
// them can be unsound, e.g. for `WellFormed` predicates.
let prev = mem::replace(&mut self.should_normalize, false);
let predicate = predicate.super_fold_with(self);
self.should_normalize = prev;
predicate
assert!(
!self.should_normalize,
"normalizing predicates in writeback is not generally sound"
);
predicate.super_fold_with(self)
}
}

Expand Down
3 changes: 1 addition & 2 deletions compiler/rustc_infer/src/infer/opaque_types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,7 @@ impl<'tcx> InferCtxt<'tcx> {
.eq(DefineOpaqueTypes::Yes, prev, hidden_ty)?
.obligations
.into_iter()
// FIXME: Shuttling between obligations and goals is awkward.
.map(Goal::from),
.map(|obligation| obligation.as_goal()),
);
}
}
Expand Down
12 changes: 6 additions & 6 deletions compiler/rustc_infer/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ pub struct Obligation<'tcx, T> {
pub recursion_depth: usize,
}

impl<'tcx, T: Copy> Obligation<'tcx, T> {
pub fn as_goal(&self) -> solve::Goal<'tcx, T> {
solve::Goal { param_env: self.param_env, predicate: self.predicate }
}
}

impl<'tcx, T: PartialEq> PartialEq<Obligation<'tcx, T>> for Obligation<'tcx, T> {
#[inline]
fn eq(&self, other: &Obligation<'tcx, T>) -> bool {
Expand All @@ -75,12 +81,6 @@ impl<T: Hash> Hash for Obligation<'_, T> {
}
}

impl<'tcx, P> From<Obligation<'tcx, P>> for solve::Goal<'tcx, P> {
fn from(value: Obligation<'tcx, P>) -> Self {
solve::Goal { param_env: value.param_env, predicate: value.predicate }
}
}

pub type PredicateObligation<'tcx> = Obligation<'tcx, ty::Predicate<'tcx>>;
pub type TraitObligation<'tcx> = Obligation<'tcx, ty::TraitPredicate<'tcx>>;
pub type PolyTraitObligation<'tcx> = Obligation<'tcx, ty::PolyTraitPredicate<'tcx>>;
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_trait_selection/src/solve/delegate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ impl<'tcx> rustc_next_trait_solver::delegate::SolverDelegate for SolverDelegate<
) -> Option<Vec<Goal<'tcx, ty::Predicate<'tcx>>>> {
crate::traits::wf::unnormalized_obligations(&self.0, param_env, arg, DUMMY_SP, CRATE_DEF_ID)
.map(|obligations| {
obligations.into_iter().map(|obligation| obligation.into()).collect()
obligations.into_iter().map(|obligation| obligation.as_goal()).collect()
})
}

Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_trait_selection/src/solve/fulfill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ impl<'tcx> ObligationStorage<'tcx> {
// change.
// FIXME: <https://github.com/Gankra/thin-vec/pull/66> is merged, this can be removed.
self.overflowed.extend(ExtractIf::new(&mut self.pending, |o| {
let goal = o.clone().into();
let goal = o.as_goal();
let result = <&SolverDelegate<'tcx>>::from(infcx)
.evaluate_root_goal(goal, GenerateProofTree::No, o.cause.span)
.0;
Expand Down Expand Up @@ -161,7 +161,7 @@ where

let mut has_changed = false;
for obligation in self.obligations.unstalled_for_select() {
let goal = obligation.clone().into();
let goal = obligation.as_goal();
let result = <&SolverDelegate<'tcx>>::from(infcx)
.evaluate_root_goal(goal, GenerateProofTree::No, obligation.cause.span)
.0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use rustc_middle::ty::error::{ExpectedFound, TypeError};
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_middle::{bug, span_bug};
use rustc_next_trait_solver::solve::{GenerateProofTree, SolverDelegateEvalExt as _};
use rustc_type_ir::solve::{Goal, NoSolution};
use rustc_type_ir::solve::NoSolution;
use tracing::{instrument, trace};

use crate::solve::Certainty;
Expand Down Expand Up @@ -89,7 +89,7 @@ pub(super) fn fulfillment_error_for_stalled<'tcx>(
let (code, refine_obligation) = infcx.probe(|_| {
match <&SolverDelegate<'tcx>>::from(infcx)
.evaluate_root_goal(
root_obligation.clone().into(),
root_obligation.as_goal(),
GenerateProofTree::No,
root_obligation.cause.span,
)
Expand Down Expand Up @@ -155,7 +155,7 @@ fn find_best_leaf_obligation<'tcx>(
.fudge_inference_if_ok(|| {
infcx
.visit_proof_tree(
obligation.clone().into(),
obligation.as_goal(),
&mut BestObligation { obligation: obligation.clone(), consider_ambiguities },
)
.break_value()
Expand Down Expand Up @@ -245,7 +245,7 @@ impl<'tcx> BestObligation<'tcx> {
{
let nested_goal = candidate.instantiate_proof_tree_for_nested_goal(
GoalSource::Misc,
Goal::new(infcx.tcx, obligation.param_env, obligation.predicate),
obligation.as_goal(),
self.span(),
);
// Skip nested goals that aren't the *reason* for our goal's failure.
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_trait_selection/src/traits/coherence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ fn compute_intercrate_ambiguity_causes<'tcx>(
let mut causes: FxIndexSet<IntercrateAmbiguityCause<'tcx>> = Default::default();

for obligation in obligations {
search_ambiguity_causes(infcx, obligation.clone().into(), &mut causes);
search_ambiguity_causes(infcx, obligation.as_goal(), &mut causes);
}

causes
Expand Down
Loading