Skip to content

Record impl args in the proof tree in new solver #124759

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions compiler/rustc_middle/src/traits/solve/inspect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ pub enum ProbeStep<'tcx> {
/// used whenever there are multiple candidates to prove the
/// current goalby .
NestedProbe(Probe<'tcx>),
/// A trait goal was satisfied by an impl candidate.
RecordImplArgs { impl_args: CanonicalState<'tcx, ty::GenericArgsRef<'tcx>> },
/// A call to `EvalCtxt::evaluate_added_goals_make_canonical_response` with
/// `Certainty` was made. This is the certainty passed in, so it's not unified
/// with the certainty of the `try_evaluate_added_goals` that is done within;
Expand Down
3 changes: 3 additions & 0 deletions compiler/rustc_middle/src/traits/solve/inspect/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ impl<'a, 'b> ProofTreeFormatter<'a, 'b> {
ProbeStep::MakeCanonicalResponse { shallow_certainty } => {
writeln!(this.f, "EVALUATE GOALS AND MAKE RESPONSE: {shallow_certainty:?}")?
}
ProbeStep::RecordImplArgs { impl_args } => {
writeln!(this.f, "RECORDED IMPL ARGS: {impl_args:?}")?
}
}
}
Ok(())
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,11 @@ impl<'a, 'tcx> EvalCtxt<'a, 'tcx> {

Ok(unchanged_certainty)
}

/// Record impl args in the proof tree for later access by `InspectCandidate`.
pub(crate) fn record_impl_args(&mut self, impl_args: ty::GenericArgsRef<'tcx>) {
self.inspect.record_impl_args(self.infcx, self.max_input_universe, impl_args)
}
}

impl<'tcx> EvalCtxt<'_, 'tcx> {
Expand Down
80 changes: 21 additions & 59 deletions compiler/rustc_trait_selection/src/solve/eval_ctxt/select.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use std::ops::ControlFlow;

use rustc_hir::def_id::DefId;
use rustc_infer::infer::{DefineOpaqueTypes, InferCtxt, InferOk};
use rustc_infer::infer::InferCtxt;
use rustc_infer::traits::solve::inspect::ProbeKind;
use rustc_infer::traits::solve::{CandidateSource, Certainty, Goal};
use rustc_infer::traits::{
BuiltinImplSource, ImplSource, ImplSourceUserDefinedData, Obligation, ObligationCause,
PolyTraitObligation, PredicateObligation, Selection, SelectionError, SelectionResult,
PolyTraitObligation, Selection, SelectionError, SelectionResult,
};
use rustc_macros::extension;
use rustc_span::Span;
Expand Down Expand Up @@ -133,32 +132,32 @@ fn to_selection<'tcx>(
return None;
}

let make_nested = || {
cand.instantiate_nested_goals(span)
.into_iter()
.map(|nested| {
Obligation::new(
nested.infcx().tcx,
ObligationCause::dummy_with_span(span),
nested.goal().param_env,
nested.goal().predicate,
)
})
.collect()
};
let (nested, impl_args) = cand.instantiate_nested_goals_and_opt_impl_args(span);
let nested = nested
.into_iter()
.map(|nested| {
Obligation::new(
nested.infcx().tcx,
ObligationCause::dummy_with_span(span),
nested.goal().param_env,
nested.goal().predicate,
)
})
.collect();

Some(match cand.kind() {
ProbeKind::TraitCandidate { source, result: _ } => match source {
CandidateSource::Impl(impl_def_id) => {
// FIXME: Remove this in favor of storing this in the tree
// For impl candidates, we do the rematch manually to compute the args.
ImplSource::UserDefined(rematch_impl(cand.goal(), impl_def_id, span))
}
CandidateSource::BuiltinImpl(builtin) => ImplSource::Builtin(builtin, make_nested()),
CandidateSource::ParamEnv(_) => ImplSource::Param(make_nested()),
CandidateSource::AliasBound => {
ImplSource::Builtin(BuiltinImplSource::Misc, make_nested())
ImplSource::UserDefined(ImplSourceUserDefinedData {
impl_def_id,
args: impl_args.expect("expected recorded impl args for impl candidate"),
nested,
})
}
CandidateSource::BuiltinImpl(builtin) => ImplSource::Builtin(builtin, nested),
CandidateSource::ParamEnv(_) | CandidateSource::AliasBound => ImplSource::Param(nested),
CandidateSource::CoherenceUnknowable => {
span_bug!(span, "didn't expect to select an unknowable candidate")
}
Expand All @@ -173,40 +172,3 @@ fn to_selection<'tcx>(
}
})
}

fn rematch_impl<'tcx>(
goal: &inspect::InspectGoal<'_, 'tcx>,
impl_def_id: DefId,
span: Span,
) -> ImplSourceUserDefinedData<'tcx, PredicateObligation<'tcx>> {
let infcx = goal.infcx();
let goal_trait_ref = infcx
.enter_forall_and_leak_universe(goal.goal().predicate.to_opt_poly_trait_pred().unwrap())
.trait_ref;

let args = infcx.fresh_args_for_item(span, impl_def_id);
let impl_trait_ref =
infcx.tcx.impl_trait_ref(impl_def_id).unwrap().instantiate(infcx.tcx, args);

let InferOk { value: (), obligations: mut nested } = infcx
.at(&ObligationCause::dummy_with_span(span), goal.goal().param_env)
.eq(DefineOpaqueTypes::Yes, goal_trait_ref, impl_trait_ref)
.expect("rematching impl failed");

// FIXME(-Znext-solver=coinductive): We need to add supertraits here eventually.

nested.extend(
infcx.tcx.predicates_of(impl_def_id).instantiate(infcx.tcx, args).into_iter().map(
|(clause, _)| {
Obligation::new(
infcx.tcx,
ObligationCause::dummy_with_span(span),
goal.goal().param_env,
clause,
)
},
),
);

ImplSourceUserDefinedData { impl_def_id, nested, args }
}
42 changes: 37 additions & 5 deletions compiler/rustc_trait_selection/src/solve/inspect/analyse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ pub struct InspectCandidate<'a, 'tcx> {
kind: inspect::ProbeKind<'tcx>,
nested_goals: Vec<(GoalSource, inspect::CanonicalState<'tcx, Goal<'tcx, ty::Predicate<'tcx>>>)>,
final_state: inspect::CanonicalState<'tcx, ()>,
impl_args: Option<inspect::CanonicalState<'tcx, ty::GenericArgsRef<'tcx>>>,
result: QueryResult<'tcx>,
shallow_certainty: Certainty,
}
Expand Down Expand Up @@ -135,7 +136,20 @@ impl<'a, 'tcx> InspectCandidate<'a, 'tcx> {

/// Instantiate the nested goals for the candidate without rolling back their
/// inference constraints. This function modifies the state of the `infcx`.
///
/// See [`Self::instantiate_nested_goals_and_opt_impl_args`] if you need the impl args too.
pub fn instantiate_nested_goals(&self, span: Span) -> Vec<InspectGoal<'a, 'tcx>> {
self.instantiate_nested_goals_and_opt_impl_args(span).0
}

/// Instantiate the nested goals for the candidate without rolling back their
/// inference constraints, and optionally the args of an impl if this candidate
/// came from a `CandidateSource::Impl`. This function modifies the state of the
/// `infcx`.
pub fn instantiate_nested_goals_and_opt_impl_args(
&self,
span: Span,
) -> (Vec<InspectGoal<'a, 'tcx>>, Option<ty::GenericArgsRef<'tcx>>) {
let infcx = self.goal.infcx;
let param_env = self.goal.goal.param_env;
let mut orig_values = self.goal.orig_values.to_vec();
Expand Down Expand Up @@ -164,14 +178,25 @@ impl<'a, 'tcx> InspectCandidate<'a, 'tcx> {
self.final_state,
);

let impl_args = self.impl_args.map(|impl_args| {
canonical::instantiate_canonical_state(
infcx,
span,
param_env,
&mut orig_values,
impl_args,
)
.fold_with(&mut EagerResolver::new(infcx))
});

if let Some(term_hack) = self.goal.normalizes_to_term_hack {
// FIXME: We ignore the expected term of `NormalizesTo` goals
// when computing the result of its candidates. This is
// scuffed.
let _ = term_hack.constrain(infcx, span, param_env);
}

instantiated_goals
let goals = instantiated_goals
.into_iter()
.map(|(source, goal)| match goal.predicate.kind().no_bound_vars() {
Some(ty::PredicateKind::NormalizesTo(ty::NormalizesTo { alias, term })) => {
Expand Down Expand Up @@ -208,7 +233,9 @@ impl<'a, 'tcx> InspectCandidate<'a, 'tcx> {
source,
),
})
.collect()
.collect();

(goals, impl_args)
}

/// Visit all nested goals of this candidate, rolling back
Expand Down Expand Up @@ -245,9 +272,10 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> {
probe: &inspect::Probe<'tcx>,
) {
let mut shallow_certainty = None;
let mut impl_args = None;
for step in &probe.steps {
match step {
&inspect::ProbeStep::AddGoal(source, goal) => nested_goals.push((source, goal)),
match *step {
inspect::ProbeStep::AddGoal(source, goal) => nested_goals.push((source, goal)),
inspect::ProbeStep::NestedProbe(ref probe) => {
// Nested probes have to prove goals added in their parent
// but do not leak them, so we truncate the added goals
Expand All @@ -257,7 +285,10 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> {
nested_goals.truncate(num_goals);
}
inspect::ProbeStep::MakeCanonicalResponse { shallow_certainty: c } => {
assert_eq!(shallow_certainty.replace(*c), None);
assert_eq!(shallow_certainty.replace(c), None);
}
inspect::ProbeStep::RecordImplArgs { impl_args: i } => {
assert_eq!(impl_args.replace(i), None);
}
inspect::ProbeStep::EvaluateGoals(_) => (),
}
Expand All @@ -284,6 +315,7 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> {
final_state: probe.final_state,
result,
shallow_certainty,
impl_args,
});
}
}
Expand Down
30 changes: 29 additions & 1 deletion compiler/rustc_trait_selection/src/solve/inspect/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ enum WipProbeStep<'tcx> {
EvaluateGoals(WipAddedGoalsEvaluation<'tcx>),
NestedProbe(WipProbe<'tcx>),
MakeCanonicalResponse { shallow_certainty: Certainty },
RecordImplArgs { impl_args: inspect::CanonicalState<'tcx, ty::GenericArgsRef<'tcx>> },
}

impl<'tcx> WipProbeStep<'tcx> {
Expand All @@ -250,6 +251,9 @@ impl<'tcx> WipProbeStep<'tcx> {
WipProbeStep::AddGoal(source, goal) => inspect::ProbeStep::AddGoal(source, goal),
WipProbeStep::EvaluateGoals(eval) => inspect::ProbeStep::EvaluateGoals(eval.finalize()),
WipProbeStep::NestedProbe(probe) => inspect::ProbeStep::NestedProbe(probe.finalize()),
WipProbeStep::RecordImplArgs { impl_args } => {
inspect::ProbeStep::RecordImplArgs { impl_args }
}
WipProbeStep::MakeCanonicalResponse { shallow_certainty } => {
inspect::ProbeStep::MakeCanonicalResponse { shallow_certainty }
}
Expand Down Expand Up @@ -534,6 +538,30 @@ impl<'tcx> ProofTreeBuilder<'tcx> {
}
}

pub(crate) fn record_impl_args(
&mut self,
infcx: &InferCtxt<'tcx>,
max_input_universe: ty::UniverseIndex,
impl_args: ty::GenericArgsRef<'tcx>,
) {
match self.as_mut() {
Some(DebugSolver::GoalEvaluationStep(state)) => {
let impl_args = canonical::make_canonical_state(
infcx,
&state.var_values,
max_input_universe,
impl_args,
);
state
.current_evaluation_scope()
.steps
.push(WipProbeStep::RecordImplArgs { impl_args });
}
None => {}
_ => bug!(),
}
}

pub fn make_canonical_response(&mut self, shallow_certainty: Certainty) {
match self.as_mut() {
Some(DebugSolver::GoalEvaluationStep(state)) => {
Expand All @@ -543,7 +571,7 @@ impl<'tcx> ProofTreeBuilder<'tcx> {
.push(WipProbeStep::MakeCanonicalResponse { shallow_certainty });
}
None => {}
_ => {}
_ => bug!(),
}
}

Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_trait_selection/src/solve/trait_goals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ impl<'tcx> assembly::GoalKind<'tcx> for TraitPredicate<'tcx> {

ecx.probe_trait_candidate(CandidateSource::Impl(impl_def_id)).enter(|ecx| {
let impl_args = ecx.fresh_args_for_item(impl_def_id);
ecx.record_impl_args(impl_args);
let impl_trait_ref = impl_trait_header.trait_ref.instantiate(tcx, impl_args);

ecx.eq(goal.param_env, goal.predicate.trait_ref, impl_trait_ref)?;
Expand Down
13 changes: 13 additions & 0 deletions tests/ui/traits/next-solver/select-alias-bound-as-param.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//@ check-pass
//@ compile-flags: -Znext-solver

pub(crate) fn y() -> impl FnMut() {
|| {}
}

pub(crate) fn x(a: (), b: ()) {
let x = ();
y()()
}

fn main() {}
Loading