Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions compiler/rustc_next_trait_solver/src/solve/search_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,28 @@ where
}
}

fn is_initial_provisional_result(
cx: Self::Cx,
kind: PathKind,
input: CanonicalInput<I>,
result: QueryResult<I>,
) -> bool {
Self::initial_provisional_result(cx, kind, input) == result
fn is_initial_provisional_result(result: QueryResult<I>) -> Option<PathKind> {
match result {
Ok(response) => {
if has_no_inference_or_external_constraints(response) {
if response.value.certainty == Certainty::Yes {
return Some(PathKind::Coinductive);
} else if response.value.certainty == Certainty::overflow(false) {
return Some(PathKind::Unknown);
}
}

None
}
Err(NoSolution) => Some(PathKind::Inductive),
}
}

fn on_stack_overflow(cx: I, input: CanonicalInput<I>) -> QueryResult<I> {
fn stack_overflow_result(cx: I, input: CanonicalInput<I>) -> QueryResult<I> {
response_no_constraints(cx, input, Certainty::overflow(true))
}

fn on_fixpoint_overflow(cx: I, input: CanonicalInput<I>) -> QueryResult<I> {
fn fixpoint_overflow_result(cx: I, input: CanonicalInput<I>) -> QueryResult<I> {
response_no_constraints(cx, input, Certainty::overflow(false))
}

Expand Down
148 changes: 101 additions & 47 deletions compiler/rustc_type_ir/src/search_graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,12 @@ pub trait Delegate: Sized {
kind: PathKind,
input: <Self::Cx as Cx>::Input,
) -> <Self::Cx as Cx>::Result;
fn is_initial_provisional_result(
fn is_initial_provisional_result(result: <Self::Cx as Cx>::Result) -> Option<PathKind>;
fn stack_overflow_result(
cx: Self::Cx,
kind: PathKind,
input: <Self::Cx as Cx>::Input,
result: <Self::Cx as Cx>::Result,
) -> bool;
fn on_stack_overflow(cx: Self::Cx, input: <Self::Cx as Cx>::Input) -> <Self::Cx as Cx>::Result;
fn on_fixpoint_overflow(
) -> <Self::Cx as Cx>::Result;
fn fixpoint_overflow_result(
cx: Self::Cx,
input: <Self::Cx as Cx>::Input,
) -> <Self::Cx as Cx>::Result;
Expand Down Expand Up @@ -215,6 +213,27 @@ impl HeadUsages {
let HeadUsages { inductive, unknown, coinductive, forced_ambiguity } = self;
inductive == 0 && unknown == 0 && coinductive == 0 && forced_ambiguity == 0
}

fn is_single(self, path_kind: PathKind) -> bool {
match path_kind {
PathKind::Inductive => matches!(
self,
HeadUsages { inductive: _, unknown: 0, coinductive: 0, forced_ambiguity: 0 },
),
PathKind::Unknown => matches!(
self,
HeadUsages { inductive: 0, unknown: _, coinductive: 0, forced_ambiguity: 0 },
),
PathKind::Coinductive => matches!(
self,
HeadUsages { inductive: 0, unknown: 0, coinductive: _, forced_ambiguity: 0 },
),
PathKind::ForcedAmbiguity => matches!(
self,
HeadUsages { inductive: 0, unknown: 0, coinductive: 0, forced_ambiguity: _ },
),
}
}
}

#[derive(Debug, Default)]
Expand Down Expand Up @@ -869,7 +888,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
}

debug!("encountered stack overflow");
D::on_stack_overflow(cx, input)
D::stack_overflow_result(cx, input)
}

/// When reevaluating a goal with a changed provisional result, all provisional cache entry
Expand All @@ -888,7 +907,29 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
!entries.is_empty()
});
}
}

/// We need to rebase provisional cache entries when popping one of their cycle
/// heads from the stack. This may not necessarily mean that we've actually
/// reached a fixpoint for that cycle head, which impacts the way we rebase
/// provisional cache entries.
enum RebaseReason {
NoCycleUsages,
Ambiguity,
Overflow,
/// We've actually reached a fixpoint.
///
/// This either happens in the first evaluation step for the cycle head.
/// In this case the used provisional result depends on the cycle `PathKind`.
/// We store this path kind to check whether the the provisional cache entry
/// we're rebasing relied on the same cycles.
///
/// In later iterations cycles always return `stack_entry.provisional_result`
/// so we no longer depend on the `PathKind`. We store `None` in that case.
ReachedFixpoint(Option<PathKind>),
}

impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D, X> {
/// A necessary optimization to handle complex solver cycles. A provisional cache entry
/// relies on a set of cycle heads and the path towards these heads. When popping a cycle
/// head from the stack after we've finished computing it, we can't be sure that the
Expand All @@ -908,8 +949,9 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
/// to me.
fn rebase_provisional_cache_entries(
&mut self,
cx: X,
stack_entry: &StackEntry<X>,
mut mutate_result: impl FnMut(X::Input, X::Result) -> X::Result,
rebase_reason: RebaseReason,
) {
let popped_head_index = self.stack.next_index();
#[allow(rustc::potential_query_instability)]
Expand All @@ -927,6 +969,10 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
return true;
};

let Some(new_highest_head_index) = heads.opt_highest_cycle_head_index() else {
return false;
};
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moved earlier for potential perf gains. no need to do anything more involved if we need to pop anyways


// We're rebasing an entry `e` over a head `p`. This head
// has a number of own heads `h` it depends on.
//
Expand Down Expand Up @@ -977,22 +1023,37 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
let eph = ep.extend_with_paths(ph);
heads.insert(head_index, eph, head.usages);
}
}

let Some(head_index) = heads.opt_highest_cycle_head_index() else {
return false;
};
// The provisional cache entry does depend on the provisional result
// of the popped cycle head. We need to mutate the result of our
// provisional cache entry in case we did not reach a fixpoint.
match rebase_reason {
// If the cycle head does not actually depend on itself, then
// the provisional result used by the provisional cache entry
// is not actually equal to the final provisional result. We
// need to discard the provisional cache entry in this case.
RebaseReason::NoCycleUsages => return false,
RebaseReason::Ambiguity => {
*result = D::propagate_ambiguity(cx, input, *result);
}
RebaseReason::Overflow => *result = D::fixpoint_overflow_result(cx, input),
RebaseReason::ReachedFixpoint(None) => {}
RebaseReason::ReachedFixpoint(Some(path_kind)) => {
if !popped_head.usages.is_single(path_kind) {
return false;
}
}
};
}

// We now care about the path from the next highest cycle head to the
// provisional cache entry.
*path_from_head = path_from_head.extend(Self::cycle_path_kind(
&self.stack,
stack_entry.step_kind_from_parent,
head_index,
new_highest_head_index,
));
// Mutate the result of the provisional cache entry in case we did
// not reach a fixpoint.
*result = mutate_result(input, *result);

true
});
!entries.is_empty()
Expand Down Expand Up @@ -1209,33 +1270,19 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
/// Whether we've reached a fixpoint when evaluating a cycle head.
fn reached_fixpoint(
&mut self,
cx: X,
stack_entry: &StackEntry<X>,
usages: HeadUsages,
result: X::Result,
) -> bool {
) -> Result<Option<PathKind>, ()> {
let provisional_result = stack_entry.provisional_result;
if usages.is_empty() {
true
} else if let Some(provisional_result) = provisional_result {
provisional_result == result
if let Some(provisional_result) = provisional_result {
if provisional_result == result { Ok(None) } else { Err(()) }
} else if let Some(path_kind) = D::is_initial_provisional_result(result)
.filter(|&path_kind| usages.is_single(path_kind))
{
Ok(Some(path_kind))
} else {
let check = |k| D::is_initial_provisional_result(cx, k, stack_entry.input, result);
match usages {
HeadUsages { inductive: _, unknown: 0, coinductive: 0, forced_ambiguity: 0 } => {
check(PathKind::Inductive)
}
HeadUsages { inductive: 0, unknown: _, coinductive: 0, forced_ambiguity: 0 } => {
check(PathKind::Unknown)
}
HeadUsages { inductive: 0, unknown: 0, coinductive: _, forced_ambiguity: 0 } => {
check(PathKind::Coinductive)
}
HeadUsages { inductive: 0, unknown: 0, coinductive: 0, forced_ambiguity: _ } => {
check(PathKind::ForcedAmbiguity)
}
_ => false,
}
Err(())
}
}

Expand Down Expand Up @@ -1280,8 +1327,19 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
// is equal to the provisional result of the previous iteration, or because
// this was only the head of either coinductive or inductive cycles, and the
// final result is equal to the initial response for that case.
if self.reached_fixpoint(cx, &stack_entry, usages, result) {
self.rebase_provisional_cache_entries(&stack_entry, |_, result| result);
if let Ok(fixpoint) = self.reached_fixpoint(&stack_entry, usages, result) {
self.rebase_provisional_cache_entries(
cx,
&stack_entry,
RebaseReason::ReachedFixpoint(fixpoint),
);
return EvaluationResult::finalize(stack_entry, encountered_overflow, result);
} else if usages.is_empty() {
self.rebase_provisional_cache_entries(
cx,
&stack_entry,
RebaseReason::NoCycleUsages,
);
return EvaluationResult::finalize(stack_entry, encountered_overflow, result);
}

Expand All @@ -1298,9 +1356,7 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
// we also taint all provisional cache entries which depend on the
// current goal.
if D::is_ambiguous_result(result) {
self.rebase_provisional_cache_entries(&stack_entry, |input, _| {
D::propagate_ambiguity(cx, input, result)
});
self.rebase_provisional_cache_entries(cx, &stack_entry, RebaseReason::Ambiguity);
return EvaluationResult::finalize(stack_entry, encountered_overflow, result);
};

Expand All @@ -1309,10 +1365,8 @@ impl<D: Delegate<Cx = X>, X: Cx> SearchGraph<D> {
i += 1;
if i >= D::FIXPOINT_STEP_LIMIT {
debug!("canonical cycle overflow");
let result = D::on_fixpoint_overflow(cx, input);
self.rebase_provisional_cache_entries(&stack_entry, |input, _| {
D::on_fixpoint_overflow(cx, input)
});
let result = D::fixpoint_overflow_result(cx, input);
self.rebase_provisional_cache_entries(cx, &stack_entry, RebaseReason::Overflow);
return EvaluationResult::finalize(stack_entry, encountered_overflow, result);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
//@ compile-flags: -Znext-solver
//@ check-pass

// A regression test for trait-system-refactor-initiative#232. We've
// previously incorrectly rebased provisional cache entries even if
// the cycle head didn't reach a fixpoint as it did not depend on any
// cycles itself.
//
// Just because the result of a goal does not depend on its own provisional
// result, it does not mean its nested goals don't depend on its result.
struct B;
struct C;
struct D;

pub trait Trait {
type Output;
}
macro_rules! k {
($t:ty) => {
<$t as Trait>::Output
};
}

trait CallB<T1, T2> {
type Output;
type Return;
}

trait CallC<T1> {
type Output;
type Return;
}

trait CallD<T1, T2> {
type Output;
}

fn foo<X, Y>()
where
X: Trait,
Y: Trait,
D: CallD<k![X], k![Y]>,
C: CallC<<D as CallD<k![X], k![Y]>>::Output>,
<C as CallC<<D as CallD<k![X], k![Y]>>::Output>>::Output: Trait,
B: CallB<
<C as CallC<<D as CallD<k![X], k![Y]>>::Output>>::Return,
<C as CallC<<D as CallD<k![X], k![Y]>>::Output>>::Output,
>,
<B as CallB<
<C as CallC<<D as CallD<k![X], k![Y]>>::Output>>::Return,
<C as CallC<<D as CallD<k![X], k![Y]>>::Output>>::Output,
>>::Output: Trait<Output = ()>,
{
}
fn main() {}
Loading