Skip to content

Commit

Permalink
Auto merge of rust-lang#107507 - BoxyUwU:deferred_projection_equality…
Browse files Browse the repository at this point in the history
…, r=lcnr

Implement `deferred_projection_equality` for erica solver

Somewhat of a revival of rust-lang#96912. When relating projections now emit an `AliasEq` obligation instead of attempting to determine equality of projections that may not be as normalized as possible (i.e. because of lazy norm, or just containing inference variables that prevent us from resolving an impl). Only do this when the new solver is enabled
  • Loading branch information
bors committed Feb 11, 2023
2 parents 5a8dfd9 + 4c98429 commit 1623ab0
Show file tree
Hide file tree
Showing 46 changed files with 585 additions and 163 deletions.
6 changes: 1 addition & 5 deletions compiler/rustc_borrowck/src/type_check/relate_tys.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use rustc_infer::infer::nll_relate::{NormalizationStrategy, TypeRelating, TypeRelatingDelegate};
use rustc_infer::infer::nll_relate::{TypeRelating, TypeRelatingDelegate};
use rustc_infer::infer::NllRegionVariableOrigin;
use rustc_infer::traits::PredicateObligations;
use rustc_middle::mir::ConstraintCategory;
Expand Down Expand Up @@ -140,10 +140,6 @@ impl<'tcx> TypeRelatingDelegate<'tcx> for NllTypeRelatingDelegate<'_, '_, 'tcx>
);
}

fn normalization() -> NormalizationStrategy {
NormalizationStrategy::Eager
}

fn forbid_inference_vars() -> bool {
true
}
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_hir_analysis/src/astconv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1320,6 +1320,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
ty::Clause::RegionOutlives(_) => bug!(),
},
ty::PredicateKind::WellFormed(_)
| ty::PredicateKind::AliasEq(..)
| ty::PredicateKind::ObjectSafe(_)
| ty::PredicateKind::ClosureKind(_, _, _)
| ty::PredicateKind::Subtype(_)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ fn trait_predicate_kind<'tcx>(
ty::PredicateKind::Clause(ty::Clause::RegionOutlives(_))
| ty::PredicateKind::Clause(ty::Clause::TypeOutlives(_))
| ty::PredicateKind::Clause(ty::Clause::Projection(_))
| ty::PredicateKind::AliasEq(..)
| ty::PredicateKind::WellFormed(_)
| ty::PredicateKind::Subtype(_)
| ty::PredicateKind::Coerce(_)
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_hir_analysis/src/outlives/explicit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ impl<'tcx> ExplicitPredicatesMap<'tcx> {
ty::PredicateKind::Clause(ty::Clause::Trait(..))
| ty::PredicateKind::Clause(ty::Clause::Projection(..))
| ty::PredicateKind::WellFormed(..)
| ty::PredicateKind::AliasEq(..)
| ty::PredicateKind::ObjectSafe(..)
| ty::PredicateKind::ClosureKind(..)
| ty::PredicateKind::Subtype(..)
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
| ty::PredicateKind::Clause(ty::Clause::TypeOutlives(..))
| ty::PredicateKind::WellFormed(..)
| ty::PredicateKind::ObjectSafe(..)
| ty::PredicateKind::AliasEq(..)
| ty::PredicateKind::ConstEvaluatable(..)
| ty::PredicateKind::ConstEquate(..)
// N.B., this predicate is created by breaking down a
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_hir_typeck/src/method/probe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,7 @@ impl<'a, 'tcx> ProbeContext<'a, 'tcx> {
| ty::PredicateKind::ConstEvaluatable(..)
| ty::PredicateKind::ConstEquate(..)
| ty::PredicateKind::Ambiguous
| ty::PredicateKind::AliasEq(..)
| ty::PredicateKind::TypeWellFormedFromEnv(..) => None,
}
});
Expand Down
6 changes: 1 addition & 5 deletions compiler/rustc_infer/src/infer/canonical/query_response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::infer::canonical::{
Canonical, CanonicalQueryResponse, CanonicalVarValues, Certainty, OriginalQueryValues,
QueryOutlivesConstraint, QueryRegionConstraints, QueryResponse,
};
use crate::infer::nll_relate::{NormalizationStrategy, TypeRelating, TypeRelatingDelegate};
use crate::infer::nll_relate::{TypeRelating, TypeRelatingDelegate};
use crate::infer::region_constraints::{Constraint, RegionConstraintData};
use crate::infer::{InferCtxt, InferOk, InferResult, NllRegionVariableOrigin};
use crate::traits::query::{Fallible, NoSolution};
Expand Down Expand Up @@ -717,10 +717,6 @@ impl<'tcx> TypeRelatingDelegate<'tcx> for QueryTypeRelatingDelegate<'_, 'tcx> {
});
}

fn normalization() -> NormalizationStrategy {
NormalizationStrategy::Eager
}

fn forbid_inference_vars() -> bool {
true
}
Expand Down
87 changes: 58 additions & 29 deletions compiler/rustc_infer/src/infer/combine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ use rustc_middle::ty::error::{ExpectedFound, TypeError};
use rustc_middle::ty::relate::{self, Relate, RelateResult, TypeRelation};
use rustc_middle::ty::subst::SubstsRef;
use rustc_middle::ty::{
self, FallibleTypeFolder, InferConst, Ty, TyCtxt, TypeFoldable, TypeSuperFoldable,
TypeVisitable,
self, AliasKind, FallibleTypeFolder, InferConst, ToPredicate, Ty, TyCtxt, TypeFoldable,
TypeSuperFoldable, TypeVisitable,
};
use rustc_middle::ty::{IntType, UintType};
use rustc_span::{Span, DUMMY_SP};
Expand Down Expand Up @@ -74,7 +74,7 @@ impl<'tcx> InferCtxt<'tcx> {
b: Ty<'tcx>,
) -> RelateResult<'tcx, Ty<'tcx>>
where
R: TypeRelation<'tcx>,
R: ObligationEmittingRelation<'tcx>,
{
let a_is_expected = relation.a_is_expected();

Expand Down Expand Up @@ -122,6 +122,15 @@ impl<'tcx> InferCtxt<'tcx> {
Err(TypeError::Sorts(ty::relate::expected_found(relation, a, b)))
}

(ty::Alias(AliasKind::Projection, _), _) if self.tcx.trait_solver_next() => {
relation.register_type_equate_obligation(a.into(), b.into());
Ok(b)
}
(_, ty::Alias(AliasKind::Projection, _)) if self.tcx.trait_solver_next() => {
relation.register_type_equate_obligation(b.into(), a.into());
Ok(a)
}

_ => ty::relate::super_relate_tys(relation, a, b),
}
}
Expand All @@ -133,7 +142,7 @@ impl<'tcx> InferCtxt<'tcx> {
b: ty::Const<'tcx>,
) -> RelateResult<'tcx, ty::Const<'tcx>>
where
R: ConstEquateRelation<'tcx>,
R: ObligationEmittingRelation<'tcx>,
{
debug!("{}.consts({:?}, {:?})", relation.tag(), a, b);
if a == b {
Expand Down Expand Up @@ -169,15 +178,15 @@ impl<'tcx> InferCtxt<'tcx> {
// FIXME(#59490): Need to remove the leak check to accommodate
// escaping bound variables here.
if !a.has_escaping_bound_vars() && !b.has_escaping_bound_vars() {
relation.const_equate_obligation(a, b);
relation.register_const_equate_obligation(a, b);
}
return Ok(b);
}
(_, ty::ConstKind::Unevaluated(..)) if self.tcx.lazy_normalization() => {
// FIXME(#59490): Need to remove the leak check to accommodate
// escaping bound variables here.
if !a.has_escaping_bound_vars() && !b.has_escaping_bound_vars() {
relation.const_equate_obligation(a, b);
relation.register_const_equate_obligation(a, b);
}
return Ok(a);
}
Expand Down Expand Up @@ -435,32 +444,21 @@ impl<'infcx, 'tcx> CombineFields<'infcx, 'tcx> {
Ok(Generalization { ty, needs_wf })
}

pub fn add_const_equate_obligation(
pub fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
self.obligations.extend(obligations.into_iter());
}

pub fn register_predicates(
&mut self,
a_is_expected: bool,
a: ty::Const<'tcx>,
b: ty::Const<'tcx>,
obligations: impl IntoIterator<Item = impl ToPredicate<'tcx>>,
) {
let predicate = if a_is_expected {
ty::PredicateKind::ConstEquate(a, b)
} else {
ty::PredicateKind::ConstEquate(b, a)
};
self.obligations.push(Obligation::new(
self.tcx(),
self.trace.cause.clone(),
self.param_env,
ty::Binder::dummy(predicate),
));
self.obligations.extend(obligations.into_iter().map(|to_pred| {
Obligation::new(self.infcx.tcx, self.trace.cause.clone(), self.param_env, to_pred)
}))
}

pub fn mark_ambiguous(&mut self) {
self.obligations.push(Obligation::new(
self.tcx(),
self.trace.cause.clone(),
self.param_env,
ty::Binder::dummy(ty::PredicateKind::Ambiguous),
));
self.register_predicates([ty::Binder::dummy(ty::PredicateKind::Ambiguous)]);
}
}

Expand Down Expand Up @@ -779,11 +777,42 @@ impl<'tcx> TypeRelation<'tcx> for Generalizer<'_, 'tcx> {
}
}

pub trait ConstEquateRelation<'tcx>: TypeRelation<'tcx> {
pub trait ObligationEmittingRelation<'tcx>: TypeRelation<'tcx> {
/// Register obligations that must hold in order for this relation to hold
fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>);

/// Register predicates that must hold in order for this relation to hold. Uses
/// a default obligation cause, [`ObligationEmittingRelation::register_obligations`] should
/// be used if control over the obligaton causes is required.
fn register_predicates(
&mut self,
obligations: impl IntoIterator<Item = impl ToPredicate<'tcx>>,
);

/// Register an obligation that both constants must be equal to each other.
///
/// If they aren't equal then the relation doesn't hold.
fn const_equate_obligation(&mut self, a: ty::Const<'tcx>, b: ty::Const<'tcx>);
fn register_const_equate_obligation(&mut self, a: ty::Const<'tcx>, b: ty::Const<'tcx>) {
let (a, b) = if self.a_is_expected() { (a, b) } else { (b, a) };

self.register_predicates([ty::Binder::dummy(if self.tcx().trait_solver_next() {
ty::PredicateKind::AliasEq(a.into(), b.into())
} else {
ty::PredicateKind::ConstEquate(a, b)
})]);
}

/// Register an obligation that both types must be equal to each other.
///
/// If they aren't equal then the relation doesn't hold.
fn register_type_equate_obligation(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) {
let (a, b) = if self.a_is_expected() { (a, b) } else { (b, a) };

self.register_predicates([ty::Binder::dummy(ty::PredicateKind::AliasEq(
a.into(),
b.into(),
))]);
}
}

fn int_unification_error<'tcx>(
Expand Down
17 changes: 13 additions & 4 deletions compiler/rustc_infer/src/infer/equate.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use super::combine::{CombineFields, ConstEquateRelation, RelationDir};
use crate::traits::PredicateObligations;

use super::combine::{CombineFields, ObligationEmittingRelation, RelationDir};
use super::Subtype;

use rustc_middle::ty::relate::{self, Relate, RelateResult, TypeRelation};
Expand Down Expand Up @@ -198,8 +200,15 @@ impl<'tcx> TypeRelation<'tcx> for Equate<'_, '_, 'tcx> {
}
}

impl<'tcx> ConstEquateRelation<'tcx> for Equate<'_, '_, 'tcx> {
fn const_equate_obligation(&mut self, a: ty::Const<'tcx>, b: ty::Const<'tcx>) {
self.fields.add_const_equate_obligation(self.a_is_expected, a, b);
impl<'tcx> ObligationEmittingRelation<'tcx> for Equate<'_, '_, 'tcx> {
fn register_predicates(
&mut self,
obligations: impl IntoIterator<Item = impl ty::ToPredicate<'tcx>>,
) {
self.fields.register_predicates(obligations);
}

fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
self.fields.register_obligations(obligations);
}
}
22 changes: 12 additions & 10 deletions compiler/rustc_infer/src/infer/glb.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
//! Greatest lower bound. See [`lattice`].

use super::combine::CombineFields;
use super::combine::{CombineFields, ObligationEmittingRelation};
use super::lattice::{self, LatticeDir};
use super::InferCtxt;
use super::Subtype;

use crate::infer::combine::ConstEquateRelation;
use crate::traits::{ObligationCause, PredicateObligation};
use crate::traits::{ObligationCause, PredicateObligations};
use rustc_middle::ty::relate::{Relate, RelateResult, TypeRelation};
use rustc_middle::ty::{self, Ty, TyCtxt};

Expand Down Expand Up @@ -136,10 +135,6 @@ impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Glb<'combine, 'infcx,
&self.fields.trace.cause
}

fn add_obligations(&mut self, obligations: Vec<PredicateObligation<'tcx>>) {
self.fields.obligations.extend(obligations)
}

fn relate_bound(&mut self, v: Ty<'tcx>, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, ()> {
let mut sub = self.fields.sub(self.a_is_expected);
sub.relate(v, a)?;
Expand All @@ -152,8 +147,15 @@ impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Glb<'combine, 'infcx,
}
}

impl<'tcx> ConstEquateRelation<'tcx> for Glb<'_, '_, 'tcx> {
fn const_equate_obligation(&mut self, a: ty::Const<'tcx>, b: ty::Const<'tcx>) {
self.fields.add_const_equate_obligation(self.a_is_expected, a, b);
impl<'tcx> ObligationEmittingRelation<'tcx> for Glb<'_, '_, 'tcx> {
fn register_predicates(
&mut self,
obligations: impl IntoIterator<Item = impl ty::ToPredicate<'tcx>>,
) {
self.fields.register_predicates(obligations);
}

fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
self.fields.register_obligations(obligations);
}
}
11 changes: 5 additions & 6 deletions compiler/rustc_infer/src/infer/lattice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
//!
//! [lattices]: https://en.wikipedia.org/wiki/Lattice_(order)

use super::combine::ObligationEmittingRelation;
use super::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
use super::InferCtxt;

use crate::traits::{ObligationCause, PredicateObligation};
use rustc_middle::ty::relate::{RelateResult, TypeRelation};
use crate::traits::ObligationCause;
use rustc_middle::ty::relate::RelateResult;
use rustc_middle::ty::TyVar;
use rustc_middle::ty::{self, Ty};

Expand All @@ -30,13 +31,11 @@ use rustc_middle::ty::{self, Ty};
///
/// GLB moves "down" the lattice (to smaller values); LUB moves
/// "up" the lattice (to bigger values).
pub trait LatticeDir<'f, 'tcx>: TypeRelation<'tcx> {
pub trait LatticeDir<'f, 'tcx>: ObligationEmittingRelation<'tcx> {
fn infcx(&self) -> &'f InferCtxt<'tcx>;

fn cause(&self) -> &ObligationCause<'tcx>;

fn add_obligations(&mut self, obligations: Vec<PredicateObligation<'tcx>>);

fn define_opaque_types(&self) -> bool;

// Relates the type `v` to `a` and `b` such that `v` represents
Expand Down Expand Up @@ -113,7 +112,7 @@ where
| (_, &ty::Alias(ty::Opaque, ty::AliasTy { def_id, .. }))
if this.define_opaque_types() && def_id.is_local() =>
{
this.add_obligations(
this.register_obligations(
infcx
.handle_opaque_type(a, b, this.a_is_expected(), this.cause(), this.param_env())?
.obligations,
Expand Down
28 changes: 15 additions & 13 deletions compiler/rustc_infer/src/infer/lub.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
//! Least upper bound. See [`lattice`].

use super::combine::CombineFields;
use super::combine::{CombineFields, ObligationEmittingRelation};
use super::lattice::{self, LatticeDir};
use super::InferCtxt;
use super::Subtype;

use crate::infer::combine::ConstEquateRelation;
use crate::traits::{ObligationCause, PredicateObligation};
use crate::traits::{ObligationCause, PredicateObligations};
use rustc_middle::ty::relate::{Relate, RelateResult, TypeRelation};
use rustc_middle::ty::{self, Ty, TyCtxt};

Expand Down Expand Up @@ -127,12 +126,6 @@ impl<'tcx> TypeRelation<'tcx> for Lub<'_, '_, 'tcx> {
}
}

impl<'tcx> ConstEquateRelation<'tcx> for Lub<'_, '_, 'tcx> {
fn const_equate_obligation(&mut self, a: ty::Const<'tcx>, b: ty::Const<'tcx>) {
self.fields.add_const_equate_obligation(self.a_is_expected, a, b);
}
}

impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Lub<'combine, 'infcx, 'tcx> {
fn infcx(&self) -> &'infcx InferCtxt<'tcx> {
self.fields.infcx
Expand All @@ -142,10 +135,6 @@ impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Lub<'combine, 'infcx,
&self.fields.trace.cause
}

fn add_obligations(&mut self, obligations: Vec<PredicateObligation<'tcx>>) {
self.fields.obligations.extend(obligations)
}

fn relate_bound(&mut self, v: Ty<'tcx>, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, ()> {
let mut sub = self.fields.sub(self.a_is_expected);
sub.relate(a, v)?;
Expand All @@ -157,3 +146,16 @@ impl<'combine, 'infcx, 'tcx> LatticeDir<'infcx, 'tcx> for Lub<'combine, 'infcx,
self.fields.define_opaque_types
}
}

impl<'tcx> ObligationEmittingRelation<'tcx> for Lub<'_, '_, 'tcx> {
fn register_predicates(
&mut self,
obligations: impl IntoIterator<Item = impl ty::ToPredicate<'tcx>>,
) {
self.fields.register_predicates(obligations);
}

fn register_obligations(&mut self, obligations: PredicateObligations<'tcx>) {
self.fields.register_obligations(obligations)
}
}
Loading

0 comments on commit 1623ab0

Please sign in to comment.