Skip to content

Commit

Permalink
Auto merge of rust-lang#136074 - compiler-errors:deeply-normalize-nex…
Browse files Browse the repository at this point in the history
…t-solver, r=lcnr

Properly deeply normalize in the next solver

Turn deep normalization into a `TypeOp`. In the old solver, just dispatch to the `Normalize` type op, but in the new solver call `deeply_normalize`. I chose to separate it into a different type op b/c some normalization is a no-op in the new solver, so this distinguishes just the normalization we need for correctness.

Then use `DeeplyNormalize` in the callsites we used to be using a `CustomTypeOp` (for normalizing known type outlives obligations), and also use it to normalize function args and impl headers in the new solver.

Finally, use it to normalize signatures for WF checks in the new solver as well. This addresses rust-lang/trait-system-refactor-initiative#146.
  • Loading branch information
bors committed Feb 12, 2025
2 parents 34a5ea9 + d5be3ba commit 672e3aa
Show file tree
Hide file tree
Showing 12 changed files with 250 additions and 47 deletions.
59 changes: 57 additions & 2 deletions compiler/rustc_borrowck/src/diagnostics/bound_region_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use rustc_infer::infer::{
};
use rustc_infer::traits::ObligationCause;
use rustc_infer::traits::query::{
CanonicalTypeOpAscribeUserTypeGoal, CanonicalTypeOpNormalizeGoal,
CanonicalTypeOpProvePredicateGoal,
CanonicalTypeOpAscribeUserTypeGoal, CanonicalTypeOpDeeplyNormalizeGoal,
CanonicalTypeOpNormalizeGoal, CanonicalTypeOpProvePredicateGoal,
};
use rustc_middle::ty::error::TypeError;
use rustc_middle::ty::{
Expand Down Expand Up @@ -109,6 +109,14 @@ impl<'tcx, T: Copy + fmt::Display + TypeFoldable<TyCtxt<'tcx>> + 'tcx> ToUnivers
}
}

impl<'tcx, T: Copy + fmt::Display + TypeFoldable<TyCtxt<'tcx>> + 'tcx> ToUniverseInfo<'tcx>
for CanonicalTypeOpDeeplyNormalizeGoal<'tcx, T>
{
fn to_universe_info(self, base_universe: ty::UniverseIndex) -> UniverseInfo<'tcx> {
UniverseInfo::TypeOp(Rc::new(DeeplyNormalizeQuery { canonical_query: self, base_universe }))
}
}

impl<'tcx> ToUniverseInfo<'tcx> for CanonicalTypeOpAscribeUserTypeGoal<'tcx> {
fn to_universe_info(self, base_universe: ty::UniverseIndex) -> UniverseInfo<'tcx> {
UniverseInfo::TypeOp(Rc::new(AscribeUserTypeQuery { canonical_query: self, base_universe }))
Expand Down Expand Up @@ -285,6 +293,53 @@ where
}
}

struct DeeplyNormalizeQuery<'tcx, T> {
canonical_query: CanonicalTypeOpDeeplyNormalizeGoal<'tcx, T>,
base_universe: ty::UniverseIndex,
}

impl<'tcx, T> TypeOpInfo<'tcx> for DeeplyNormalizeQuery<'tcx, T>
where
T: Copy + fmt::Display + TypeFoldable<TyCtxt<'tcx>> + 'tcx,
{
fn fallback_error(&self, tcx: TyCtxt<'tcx>, span: Span) -> Diag<'tcx> {
tcx.dcx().create_err(HigherRankedLifetimeError {
cause: Some(HigherRankedErrorCause::CouldNotNormalize {
value: self.canonical_query.canonical.value.value.value.to_string(),
}),
span,
})
}

fn base_universe(&self) -> ty::UniverseIndex {
self.base_universe
}

fn nice_error<'infcx>(
&self,
mbcx: &mut MirBorrowckCtxt<'_, 'infcx, 'tcx>,
cause: ObligationCause<'tcx>,
placeholder_region: ty::Region<'tcx>,
error_region: Option<ty::Region<'tcx>>,
) -> Option<Diag<'infcx>> {
let (infcx, key, _) =
mbcx.infcx.tcx.infer_ctxt().build_with_canonical(cause.span, &self.canonical_query);
let ocx = ObligationCtxt::new(&infcx);

let (param_env, value) = key.into_parts();
let _ = ocx.deeply_normalize(&cause, param_env, value.value);

let diag = try_extract_error_from_fulfill_cx(
&ocx,
mbcx.mir_def_id(),
placeholder_region,
error_region,
)?
.with_dcx(mbcx.dcx());
Some(diag)
}
}

struct AscribeUserTypeQuery<'tcx> {
canonical_query: CanonicalTypeOpAscribeUserTypeGoal<'tcx>,
base_universe: ty::UniverseIndex,
Expand Down
12 changes: 12 additions & 0 deletions compiler/rustc_borrowck/src/type_check/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,18 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
self.normalize_with_category(value, location, ConstraintCategory::Boring)
}

pub(super) fn deeply_normalize<T>(&mut self, value: T, location: impl NormalizeLocation) -> T
where
T: type_op::normalize::Normalizable<'tcx> + fmt::Display + Copy + 'tcx,
{
let result: Result<_, ErrorGuaranteed> = self.fully_perform_op(
location.to_locations(),
ConstraintCategory::Boring,
self.infcx.param_env.and(type_op::normalize::DeeplyNormalize { value }),
);
result.unwrap_or(value)
}

#[instrument(skip(self), level = "debug")]
pub(super) fn normalize_with_category<T>(
&mut self,
Expand Down
21 changes: 3 additions & 18 deletions compiler/rustc_borrowck/src/type_check/constraint_conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@ use rustc_infer::infer::outlives::env::RegionBoundPairs;
use rustc_infer::infer::outlives::obligations::{TypeOutlives, TypeOutlivesDelegate};
use rustc_infer::infer::region_constraints::{GenericKind, VerifyBound};
use rustc_infer::infer::{self, InferCtxt, SubregionOrigin};
use rustc_infer::traits::query::type_op::DeeplyNormalize;
use rustc_middle::bug;
use rustc_middle::mir::{ClosureOutlivesSubject, ClosureRegionRequirements, ConstraintCategory};
use rustc_middle::traits::ObligationCause;
use rustc_middle::traits::query::NoSolution;
use rustc_middle::ty::fold::fold_regions;
use rustc_middle::ty::{self, GenericArgKind, Ty, TyCtxt, TypeFoldable, TypeVisitableExt};
use rustc_span::Span;
use rustc_trait_selection::traits::ScrubbedTraitError;
use rustc_trait_selection::traits::query::type_op::custom::CustomTypeOp;
use rustc_trait_selection::traits::query::type_op::{TypeOp, TypeOpOutput};
use tracing::{debug, instrument};

Expand Down Expand Up @@ -270,20 +267,8 @@ impl<'a, 'tcx> ConstraintConversion<'a, 'tcx> {
ConstraintCategory<'tcx>,
)>,
) -> Ty<'tcx> {
let result = CustomTypeOp::new(
|ocx| {
ocx.deeply_normalize(
&ObligationCause::dummy_with_span(self.span),
self.param_env,
ty,
)
.map_err(|_: Vec<ScrubbedTraitError<'tcx>>| NoSolution)
},
"normalize type outlives obligation",
)
.fully_perform(self.infcx, self.span);

match result {
match self.param_env.and(DeeplyNormalize { value: ty }).fully_perform(self.infcx, self.span)
{
Ok(TypeOpOutput { output: ty, constraints, .. }) => {
if let Some(QueryRegionConstraints { outlives }) = constraints {
next_outlives_predicates.extend(outlives.iter().copied());
Expand Down
28 changes: 8 additions & 20 deletions compiler/rustc_borrowck/src/type_check/free_region_relations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@ use rustc_infer::infer::canonical::QueryRegionConstraints;
use rustc_infer::infer::outlives::env::RegionBoundPairs;
use rustc_infer::infer::region_constraints::GenericKind;
use rustc_infer::infer::{InferCtxt, outlives};
use rustc_infer::traits::ScrubbedTraitError;
use rustc_infer::traits::query::type_op::DeeplyNormalize;
use rustc_middle::mir::ConstraintCategory;
use rustc_middle::traits::ObligationCause;
use rustc_middle::traits::query::OutlivesBound;
use rustc_middle::ty::{self, RegionVid, Ty, TypeVisitableExt};
use rustc_span::{ErrorGuaranteed, Span};
use rustc_trait_selection::solve::NoSolution;
use rustc_trait_selection::traits::query::type_op::custom::CustomTypeOp;
use rustc_trait_selection::traits::query::type_op::{self, TypeOp};
use tracing::{debug, instrument};
use type_op::TypeOpOutput;
Expand Down Expand Up @@ -267,7 +264,7 @@ impl<'tcx> UniversalRegionRelationsBuilder<'_, 'tcx> {
}
let TypeOpOutput { output: norm_ty, constraints: constraints_normalize, .. } =
param_env
.and(type_op::normalize::Normalize { value: ty })
.and(DeeplyNormalize { value: ty })
.fully_perform(self.infcx, span)
.unwrap_or_else(|guar| TypeOpOutput {
output: Ty::new_error(self.infcx.tcx, guar),
Expand Down Expand Up @@ -303,9 +300,8 @@ impl<'tcx> UniversalRegionRelationsBuilder<'_, 'tcx> {
// Add implied bounds from impl header.
if matches!(tcx.def_kind(defining_ty_def_id), DefKind::AssocFn | DefKind::AssocConst) {
for &(ty, _) in tcx.assumed_wf_types(tcx.local_parent(defining_ty_def_id)) {
let result: Result<_, ErrorGuaranteed> = param_env
.and(type_op::normalize::Normalize { value: ty })
.fully_perform(self.infcx, span);
let result: Result<_, ErrorGuaranteed> =
param_env.and(DeeplyNormalize { value: ty }).fully_perform(self.infcx, span);
let Ok(TypeOpOutput { output: norm_ty, constraints: c, .. }) = result else {
continue;
};
Expand Down Expand Up @@ -360,18 +356,10 @@ impl<'tcx> UniversalRegionRelationsBuilder<'_, 'tcx> {
output: normalized_outlives,
constraints: constraints_normalize,
error_info: _,
}) = CustomTypeOp::new(
|ocx| {
ocx.deeply_normalize(
&ObligationCause::dummy_with_span(span),
self.param_env,
outlives,
)
.map_err(|_: Vec<ScrubbedTraitError<'tcx>>| NoSolution)
},
"normalize type outlives obligation",
)
.fully_perform(self.infcx, span)
}) = self
.param_env
.and(DeeplyNormalize { value: outlives })
.fully_perform(self.infcx, span)
else {
self.infcx.dcx().delayed_bug(format!("could not normalize {outlives:?}"));
return;
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_borrowck/src/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1116,7 +1116,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
ConstraintCategory::Boring,
);

let sig = self.normalize(unnormalized_sig, term_location);
let sig = self.deeply_normalize(unnormalized_sig, term_location);
// HACK(#114936): `WF(sig)` does not imply `WF(normalized(sig))`
// with built-in `Fn` implementations, since the impl may not be
// well-formed itself.
Expand Down
10 changes: 10 additions & 0 deletions compiler/rustc_middle/src/traits/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,18 @@ pub mod type_op {
pub predicate: Predicate<'tcx>,
}

/// Normalizes, but not in the new solver.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, HashStable, TypeFoldable, TypeVisitable)]
pub struct Normalize<T> {
pub value: T,
}

/// Normalizes, and deeply normalizes in the new solver.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, HashStable, TypeFoldable, TypeVisitable)]
pub struct DeeplyNormalize<T> {
pub value: T,
}

#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, HashStable, TypeFoldable, TypeVisitable)]
pub struct ImpliedOutlivesBounds<'tcx> {
pub ty: Ty<'tcx>,
Expand Down Expand Up @@ -80,6 +87,9 @@ pub type CanonicalTypeOpProvePredicateGoal<'tcx> =
pub type CanonicalTypeOpNormalizeGoal<'tcx, T> =
CanonicalQueryInput<'tcx, ty::ParamEnvAnd<'tcx, type_op::Normalize<T>>>;

pub type CanonicalTypeOpDeeplyNormalizeGoal<'tcx, T> =
CanonicalQueryInput<'tcx, ty::ParamEnvAnd<'tcx, type_op::DeeplyNormalize<T>>>;

pub type CanonicalImpliedOutlivesBoundsGoal<'tcx> =
CanonicalQueryInput<'tcx, ty::ParamEnvAnd<'tcx, type_op::ImpliedOutlivesBounds<'tcx>>>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::fmt;

use rustc_middle::traits::ObligationCause;
use rustc_middle::traits::query::NoSolution;
pub use rustc_middle::traits::query::type_op::Normalize;
pub use rustc_middle::traits::query::type_op::{DeeplyNormalize, Normalize};
use rustc_middle::ty::fold::TypeFoldable;
use rustc_middle::ty::{self, Lift, ParamEnvAnd, Ty, TyCtxt, TypeVisitableExt};
use rustc_span::Span;
Expand All @@ -27,13 +27,54 @@ where
T::type_op_method(tcx, canonicalized)
}

fn perform_locally_with_next_solver(
_ocx: &ObligationCtxt<'_, 'tcx>,
key: ParamEnvAnd<'tcx, Self>,
_span: Span,
) -> Result<Self::QueryResponse, NoSolution> {
Ok(key.value.value)
}
}

impl<'tcx, T> super::QueryTypeOp<'tcx> for DeeplyNormalize<T>
where
T: Normalizable<'tcx> + 'tcx,
{
type QueryResponse = T;

fn try_fast_path(_tcx: TyCtxt<'tcx>, key: &ParamEnvAnd<'tcx, Self>) -> Option<T> {
if !key.value.value.has_aliases() { Some(key.value.value) } else { None }
}

fn perform_query(
tcx: TyCtxt<'tcx>,
canonicalized: CanonicalQueryInput<'tcx, ParamEnvAnd<'tcx, Self>>,
) -> Result<CanonicalQueryResponse<'tcx, Self::QueryResponse>, NoSolution> {
T::type_op_method(
tcx,
CanonicalQueryInput {
typing_mode: canonicalized.typing_mode,
canonical: canonicalized.canonical.unchecked_map(
|ty::ParamEnvAnd { param_env, value }| ty::ParamEnvAnd {
param_env,
value: Normalize { value: value.value },
},
),
},
)
}

fn perform_locally_with_next_solver(
ocx: &ObligationCtxt<'_, 'tcx>,
key: ParamEnvAnd<'tcx, Self>,
span: Span,
) -> Result<Self::QueryResponse, NoSolution> {
// FIXME(-Znext-solver): shouldn't be using old normalizer
Ok(ocx.normalize(&ObligationCause::dummy_with_span(span), key.param_env, key.value.value))
ocx.deeply_normalize(
&ObligationCause::dummy_with_span(span),
key.param_env,
key.value.value,
)
.map_err(|_| NoSolution)
}
}

Expand Down Expand Up @@ -81,3 +122,14 @@ impl<'tcx> Normalizable<'tcx> for ty::FnSig<'tcx> {
tcx.type_op_normalize_fn_sig(canonicalized)
}
}

/// This impl is not needed, since we never normalize type outlives predicates
/// in the old solver, but is required by trait bounds to be happy.
impl<'tcx> Normalizable<'tcx> for ty::PolyTypeOutlivesPredicate<'tcx> {
fn type_op_method(
_tcx: TyCtxt<'tcx>,
_canonicalized: CanonicalQueryInput<'tcx, ParamEnvAnd<'tcx, Normalize<Self>>>,
) -> Result<CanonicalQueryResponse<'tcx, Self>, NoSolution> {
unreachable!("we never normalize PolyTypeOutlivesPredicate")
}
}
28 changes: 28 additions & 0 deletions tests/ui/borrowck/implied-bound-from-impl-header.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//@ check-pass
//@ revisions: current next
//@ ignore-compare-mode-next-solver (explicit revisions)
//@[next] compile-flags: -Znext-solver

// Make sure that we can normalize `<T as Ref<'a>>::Assoc` to `&'a T` and get
// its implied bounds in impl header.

trait Ref<'a> {
type Assoc;
}
impl<'a, T> Ref<'a> for T where T: 'a {
type Assoc = &'a T;
}

fn outlives<'a, T: 'a>() {}

trait Trait<'a, T> {
fn test();
}

impl<'a, T> Trait<'a, T> for <T as Ref<'a>>::Assoc {
fn test() {
outlives::<'a, T>();
}
}

fn main() {}
22 changes: 22 additions & 0 deletions tests/ui/borrowck/implied-bound-from-normalized-arg.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//@ check-pass
//@ revisions: current next
//@ ignore-compare-mode-next-solver (explicit revisions)
//@[next] compile-flags: -Znext-solver

// Make sure that we can normalize `<T as Ref<'a>>::Assoc` to `&'a T` and get
// its implied bounds.

trait Ref<'a> {
type Assoc;
}
impl<'a, T> Ref<'a> for T where T: 'a {
type Assoc = &'a T;
}

fn outlives<'a, T: 'a>() {}

fn test<'a, T>(_: <T as Ref<'a>>::Assoc) {
outlives::<'a, T>();
}

fn main() {}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
error[E0597]: `s` does not live long enough
--> $DIR/check-normalized-sig-for-wf.rs:7:7
--> $DIR/check-normalized-sig-for-wf.rs:11:7
|
LL | s: String,
| - binding `s` declared here
Expand All @@ -14,7 +14,7 @@ LL | }
| - `s` dropped here while still borrowed

error[E0521]: borrowed data escapes outside of function
--> $DIR/check-normalized-sig-for-wf.rs:15:5
--> $DIR/check-normalized-sig-for-wf.rs:19:5
|
LL | fn extend<T>(input: &T) -> &'static T {
| ----- - let's call the lifetime of this reference `'1`
Expand All @@ -28,7 +28,7 @@ LL | n(input).0
| argument requires that `'1` must outlive `'static`

error[E0521]: borrowed data escapes outside of function
--> $DIR/check-normalized-sig-for-wf.rs:23:5
--> $DIR/check-normalized-sig-for-wf.rs:27:5
|
LL | fn extend_mut<'a, T>(input: &'a mut T) -> &'static mut T {
| -- ----- `input` is a reference that is only valid in the function body
Expand Down
Loading

0 comments on commit 672e3aa

Please sign in to comment.