Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix infinite recursion on complex traits #106278

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 23 additions & 20 deletions compiler/rustc_trait_selection/src/traits/select/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ use rustc_middle::ty::print::with_no_trimmed_paths;

mod candidate_assembly;
mod confirmation;
mod only_fresh_differ;

#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum IntercrateAmbiguityCause {
Expand Down Expand Up @@ -1123,29 +1124,24 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
// unbound variable. When we evaluate this trait-reference, we
// will unify `$0` with `Vec<$1>` (for some fresh variable
// `$1`), on the condition that `$1 : Eq`. We will then wind
// up with many candidates (since that are other `Eq` impls
// up with many candidates (since there are other `Eq` impls
// that apply) and try to winnow things down. This results in
// a recursive evaluation that `$1 : Eq` -- as you can
// imagine, this is just where we started. To avoid that, we
// check for unbound variables and return an ambiguous (hence possible)
// match if we've seen this trait before.
// check if the stack contains an obligation that is identical
// up to fresh variables.
//
// This suffices to allow chains like `FnMut` implemented in
// terms of `Fn` etc, but we could probably make this more
// precise still.
let unbound_input_types =
stack.fresh_trait_pred.skip_binder().trait_ref.substs.types().any(|ty| ty.is_fresh());

if unbound_input_types
&& stack.iter().skip(1).any(|prev| {
stack.obligation.param_env == prev.obligation.param_env
&& self.match_fresh_trait_refs(
stack.fresh_trait_pred,
prev.fresh_trait_pred,
prev.obligation.param_env,
)
})
{
if stack.iter().skip(1).any(|prev| {
stack.obligation.param_env == prev.obligation.param_env
&& self.is_repetition(
stack.fresh_trait_pred,
prev.fresh_trait_pred,
prev.obligation.param_env,
)
}) {
debug!("evaluate_stack --> unbound argument, recursive --> giving up",);
return Ok(EvaluatedToUnknown);
}
Expand Down Expand Up @@ -2457,14 +2453,21 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
///////////////////////////////////////////////////////////////////////////
// Miscellany

fn match_fresh_trait_refs(
fn is_repetition(
&self,
previous: ty::PolyTraitPredicate<'tcx>,
current: ty::PolyTraitPredicate<'tcx>,
new: ty::PolyTraitPredicate<'tcx>,
old: ty::PolyTraitPredicate<'tcx>,
param_env: ty::ParamEnv<'tcx>,
) -> bool {
let mut matcher = only_fresh_differ::FreshDiffer::new(self.tcx(), param_env);
if matcher.relate(new, old).is_ok() {
return true;
}

let mut matcher = ty::_match::Match::new(self.tcx(), param_env);
matcher.relate(previous, current).is_ok()

new.skip_binder().trait_ref.substs.types().any(|ty| ty.is_fresh())
&& matcher.relate(new, old).is_ok()
}

fn push_stack<'o>(
Expand Down
116 changes: 116 additions & 0 deletions compiler/rustc_trait_selection/src/traits/select/only_fresh_differ.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use rustc_middle::ty::error::TypeError;
use rustc_middle::ty::relate::{self, Relate, RelateResult, TypeRelation};
use rustc_middle::ty::{self, InferConst, Ty, TyCtxt};

/// Requires that the types can be unified by substituting fresh variables
/// for fresh variable only.
pub struct FreshDiffer<'tcx> {
tcx: TyCtxt<'tcx>,
param_env: ty::ParamEnv<'tcx>,
}

impl<'tcx> FreshDiffer<'tcx> {
pub fn new(tcx: TyCtxt<'tcx>, param_env: ty::ParamEnv<'tcx>) -> Self {
Self { tcx, param_env }
}
}

impl<'tcx> TypeRelation<'tcx> for FreshDiffer<'tcx> {
fn tag(&self) -> &'static str {
"FreshDiffer"
}
fn tcx(&self) -> TyCtxt<'tcx> {
self.tcx
}

fn intercrate(&self) -> bool {
false
}

fn param_env(&self) -> ty::ParamEnv<'tcx> {
self.param_env
}
fn a_is_expected(&self) -> bool {
true
} // irrelevant

fn mark_ambiguous(&mut self) {
bug!()
}

fn relate_with_variance<T: Relate<'tcx>>(
&mut self,
_: ty::Variance,
_: ty::VarianceDiagInfo<'tcx>,
a: T,
b: T,
) -> RelateResult<'tcx, T> {
self.relate(a, b)
}

#[instrument(skip(self), level = "debug")]
fn regions(
&mut self,
a: ty::Region<'tcx>,
b: ty::Region<'tcx>,
) -> RelateResult<'tcx, ty::Region<'tcx>> {
Ok(a)
}

#[instrument(skip(self), level = "debug")]
fn tys(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
if a == b {
return Ok(a);
}

match (a.kind(), b.kind()) {
(&ty::Infer(ty::FreshTy(_)), &ty::Infer(ty::FreshTy(_)))
| (&ty::Infer(ty::FreshIntTy(_)), &ty::Infer(ty::FreshIntTy(_)))
| (&ty::Infer(ty::FreshFloatTy(_)), &ty::Infer(ty::FreshFloatTy(_))) => Ok(a),

(&ty::Infer(_), _) | (_, &ty::Infer(_)) => {
Err(TypeError::Sorts(relate::expected_found(self, a, b)))
}

(&ty::Error(_), _) | (_, &ty::Error(_)) => Ok(self.tcx().ty_error()),

_ => relate::super_relate_tys(self, a, b),
}
}

fn consts(
&mut self,
a: ty::Const<'tcx>,
b: ty::Const<'tcx>,
) -> RelateResult<'tcx, ty::Const<'tcx>> {
debug!("{}.consts({:?}, {:?})", self.tag(), a, b);
if a == b {
return Ok(a);
}

match (a.kind(), b.kind()) {
(_, ty::ConstKind::Infer(InferConst::Fresh(_))) => {
return Ok(a);
}

(ty::ConstKind::Infer(_), _) | (_, ty::ConstKind::Infer(_)) => {
return Err(TypeError::ConstMismatch(relate::expected_found(self, a, b)));
}

_ => {}
}

relate::super_relate_consts(self, a, b)
}

fn binders<T>(
&mut self,
a: ty::Binder<'tcx, T>,
b: ty::Binder<'tcx, T>,
) -> RelateResult<'tcx, ty::Binder<'tcx, T>>
where
T: Relate<'tcx>,
{
Ok(a.rebind(self.relate(a.skip_binder(), b.skip_binder())?))
}
}