Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split implied and super predicate queries, then allow elaborator to filter only supertraits #107614

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 37 additions & 31 deletions compiler/rustc_hir_analysis/src/astconv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1663,39 +1663,45 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
})
});

let existential_projections = projection_bounds.iter().map(|(bound, _)| {
bound.map_bound(|mut b| {
assert_eq!(b.projection_ty.self_ty(), dummy_self);

// Like for trait refs, verify that `dummy_self` did not leak inside default type
// parameters.
let references_self = b.projection_ty.substs.iter().skip(1).any(|arg| {
if arg.walk().any(|arg| arg == dummy_self.into()) {
return true;
let existential_projections = projection_bounds
.iter()
// We filter out traits that don't have `Self` as their self type above,
// we need to do the same for projections.
.filter(|(bound, _)| bound.skip_binder().self_ty() == dummy_self)
.map(|(bound, _)| {
bound.map_bound(|mut b| {
assert_eq!(b.projection_ty.self_ty(), dummy_self);

// Like for trait refs, verify that `dummy_self` did not leak inside default type
// parameters.
let references_self = b.projection_ty.substs.iter().skip(1).any(|arg| {
if arg.walk().any(|arg| arg == dummy_self.into()) {
return true;
}
false
});
if references_self {
let guar = tcx.sess.delay_span_bug(
span,
"trait object projection bounds reference `Self`",
);
let substs: Vec<_> = b
.projection_ty
.substs
.iter()
.map(|arg| {
if arg.walk().any(|arg| arg == dummy_self.into()) {
return tcx.ty_error(guar).into();
}
arg
})
.collect();
b.projection_ty.substs = tcx.mk_substs(&substs);
}
false
});
if references_self {
let guar = tcx
.sess
.delay_span_bug(span, "trait object projection bounds reference `Self`");
let substs: Vec<_> = b
.projection_ty
.substs
.iter()
.map(|arg| {
if arg.walk().any(|arg| arg == dummy_self.into()) {
return tcx.ty_error(guar).into();
}
arg
})
.collect();
b.projection_ty.substs = tcx.mk_substs(&substs);
}

ty::ExistentialProjection::erase_self_ty(tcx, b)
})
});
ty::ExistentialProjection::erase_self_ty(tcx, b)
})
});

let regular_trait_predicates = existential_trait_refs
.map(|trait_ref| trait_ref.map_bound(ty::ExistentialPredicate::Trait));
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_hir_analysis/src/collect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ pub fn provide(providers: &mut Providers) {
predicates_defined_on,
explicit_predicates_of: predicates_of::explicit_predicates_of,
super_predicates_of: predicates_of::super_predicates_of,
implied_predicates_of: predicates_of::implied_predicates_of,
super_predicates_that_define_assoc_type:
predicates_of::super_predicates_that_define_assoc_type,
trait_explicit_predicates_and_bounds: predicates_of::trait_explicit_predicates_and_bounds,
Expand Down Expand Up @@ -596,6 +597,7 @@ fn convert_item(tcx: TyCtxt<'_>, item_id: hir::ItemId) {
}
hir::ItemKind::TraitAlias(..) => {
tcx.ensure().generics_of(def_id);
tcx.at(it.span).implied_predicates_of(def_id);
tcx.at(it.span).super_predicates_of(def_id);
tcx.ensure().predicates_of(def_id);
}
Expand Down
113 changes: 81 additions & 32 deletions compiler/rustc_hir_analysis/src/collect/predicates_of.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ fn gather_explicit_predicates_of(tcx: TyCtxt<'_>, def_id: LocalDefId) -> ty::Gen
// on a trait we need to add in the supertrait bounds and bounds found on
// associated types.
if let Some(_trait_ref) = is_trait {
predicates.extend(tcx.super_predicates_of(def_id).predicates.iter().cloned());
predicates.extend(tcx.implied_predicates_of(def_id).predicates.iter().cloned());
}

// In default impls, we can assume that the self type implements
Expand Down Expand Up @@ -534,31 +534,62 @@ pub(super) fn explicit_predicates_of<'tcx>(
}
}

#[derive(Copy, Clone, Debug)]
pub enum PredicateFilter {
/// All predicates may be implied by the trait
All,

/// Only traits that reference `Self: ..` are implied by the trait
SelfOnly,

/// Only traits that reference `Self: ..` and define an associated type
/// with the given ident are implied by the trait
SelfThatDefines(Ident),
}

/// Ensures that the super-predicates of the trait with a `DefId`
/// of `trait_def_id` are converted and stored. This also ensures that
/// the transitive super-predicates are converted.
pub(super) fn super_predicates_of(
tcx: TyCtxt<'_>,
trait_def_id: LocalDefId,
) -> ty::GenericPredicates<'_> {
tcx.super_predicates_that_define_assoc_type((trait_def_id.to_def_id(), None))
implied_predicates_with_filter(tcx, trait_def_id.to_def_id(), PredicateFilter::SelfOnly)
}

pub(super) fn super_predicates_that_define_assoc_type(
tcx: TyCtxt<'_>,
(trait_def_id, assoc_name): (DefId, Ident),
) -> ty::GenericPredicates<'_> {
implied_predicates_with_filter(tcx, trait_def_id, PredicateFilter::SelfThatDefines(assoc_name))
}

pub(super) fn implied_predicates_of(
tcx: TyCtxt<'_>,
trait_def_id: LocalDefId,
) -> ty::GenericPredicates<'_> {
if tcx.is_trait_alias(trait_def_id.to_def_id()) {
implied_predicates_with_filter(tcx, trait_def_id.to_def_id(), PredicateFilter::All)
} else {
tcx.super_predicates_of(trait_def_id)
}
}

/// Ensures that the super-predicates of the trait with a `DefId`
/// of `trait_def_id` are converted and stored. This also ensures that
/// the transitive super-predicates are converted.
pub(super) fn super_predicates_that_define_assoc_type(
pub(super) fn implied_predicates_with_filter(
tcx: TyCtxt<'_>,
(trait_def_id, assoc_name): (DefId, Option<Ident>),
trait_def_id: DefId,
filter: PredicateFilter,
) -> ty::GenericPredicates<'_> {
let Some(trait_def_id) = trait_def_id.as_local() else {
// if `assoc_name` is None, then the query should've been redirected to an
// external provider
assert!(assoc_name.is_some());
assert!(matches!(filter, PredicateFilter::SelfThatDefines(_)));
return tcx.super_predicates_of(trait_def_id);
};

debug!("local trait");
let trait_hir_id = tcx.hir().local_def_id_to_hir_id(trait_def_id);

let Node::Item(item) = tcx.hir().get(trait_hir_id) else {
Expand All @@ -573,48 +604,66 @@ pub(super) fn super_predicates_that_define_assoc_type(

let icx = ItemCtxt::new(tcx, trait_def_id);

// Convert the bounds that follow the colon, e.g., `Bar + Zed` in `trait Foo: Bar + Zed`.
let self_param_ty = tcx.types.self_param;
let superbounds1 = if let Some(assoc_name) = assoc_name {
icx.astconv().compute_bounds_that_match_assoc_type(self_param_ty, bounds, assoc_name)
} else {
icx.astconv().compute_bounds(self_param_ty, bounds)
let (superbounds, where_bounds_that_match) = match filter {
PredicateFilter::All => (
// Convert the bounds that follow the colon (or equal in trait aliases)
icx.astconv().compute_bounds(self_param_ty, bounds),
// Also include all where clause bounds
icx.type_parameter_bounds_in_generics(
generics,
item.owner_id.def_id,
self_param_ty,
OnlySelfBounds(false),
None,
),
),
PredicateFilter::SelfOnly => (
// Convert the bounds that follow the colon (or equal in trait aliases)
icx.astconv().compute_bounds(self_param_ty, bounds),
// Include where clause bounds for `Self`
icx.type_parameter_bounds_in_generics(
generics,
item.owner_id.def_id,
self_param_ty,
OnlySelfBounds(true),
None,
),
),
PredicateFilter::SelfThatDefines(assoc_name) => (
// Convert the bounds that follow the colon (or equal) that reference the associated name
icx.astconv().compute_bounds_that_match_assoc_type(self_param_ty, bounds, assoc_name),
// Include where clause bounds for `Self` that reference the associated name
icx.type_parameter_bounds_in_generics(
generics,
item.owner_id.def_id,
self_param_ty,
OnlySelfBounds(true),
Some(assoc_name),
),
),
};

let superbounds1 = superbounds1.predicates();

// Convert any explicit superbounds in the where-clause,
// e.g., `trait Foo where Self: Bar`.
// In the case of trait aliases, however, we include all bounds in the where-clause,
// so e.g., `trait Foo = where u32: PartialEq<Self>` would include `u32: PartialEq<Self>`
// as one of its "superpredicates".
let is_trait_alias = tcx.is_trait_alias(trait_def_id.to_def_id());
let superbounds2 = icx.type_parameter_bounds_in_generics(
generics,
item.owner_id.def_id,
self_param_ty,
OnlySelfBounds(!is_trait_alias),
assoc_name,
);

// Combine the two lists to form the complete set of superbounds:
let superbounds = &*tcx.arena.alloc_from_iter(superbounds1.into_iter().chain(superbounds2));
debug!(?superbounds);
let implied_bounds = &*tcx
.arena
.alloc_from_iter(superbounds.predicates().into_iter().chain(where_bounds_that_match));
debug!(?implied_bounds);

// Now require that immediate supertraits are converted,
// which will, in turn, reach indirect supertraits.
if assoc_name.is_none() {
if matches!(filter, PredicateFilter::SelfOnly) {
// Now require that immediate supertraits are converted,
// which will, in turn, reach indirect supertraits.
for &(pred, span) in superbounds {
for &(pred, span) in implied_bounds {
debug!("superbound: {:?}", pred);
if let ty::PredicateKind::Clause(ty::Clause::Trait(bound)) = pred.kind().skip_binder() {
tcx.at(span).super_predicates_of(bound.def_id());
}
}
}

ty::GenericPredicates { parent: None, predicates: superbounds }
ty::GenericPredicates { parent: None, predicates: implied_bounds }
}

/// Returns the predicates defined on `item_def_id` of the form
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1749,8 +1749,7 @@ impl<'a, 'tcx> BoundVarContext<'a, 'tcx> {
if trait_defines_associated_type_named(def_id) {
break Some(bound_vars.into_iter().collect());
}
let predicates =
tcx.super_predicates_that_define_assoc_type((def_id, Some(assoc_name)));
let predicates = tcx.super_predicates_that_define_assoc_type((def_id, assoc_name));
let obligations = predicates.predicates.iter().filter_map(|&(pred, _)| {
let bound_predicate = pred.kind();
match bound_predicate.skip_binder() {
Expand Down
5 changes: 4 additions & 1 deletion compiler/rustc_hir_typeck/src/closure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
// and we want to keep inference generally in the same order of
// the registered obligations.
predicates.rev(),
) {
)
// We only care about self bounds
.filter_only_self()
{
debug!(?pred);
let bound_predicate = pred.kind();

Expand Down
47 changes: 28 additions & 19 deletions compiler/rustc_infer/src/traits/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ impl<'tcx> Extend<ty::Predicate<'tcx>> for PredicateSet<'tcx> {
pub struct Elaborator<'tcx, O> {
stack: Vec<O>,
visited: PredicateSet<'tcx>,
only_self: bool,
}

/// Describes how to elaborate an obligation into a sub-obligation.
Expand Down Expand Up @@ -170,7 +171,8 @@ pub fn elaborate<'tcx, O: Elaboratable<'tcx>>(
tcx: TyCtxt<'tcx>,
obligations: impl IntoIterator<Item = O>,
) -> Elaborator<'tcx, O> {
let mut elaborator = Elaborator { stack: Vec::new(), visited: PredicateSet::new(tcx) };
let mut elaborator =
Elaborator { stack: Vec::new(), visited: PredicateSet::new(tcx), only_self: false };
elaborator.extend_deduped(obligations);
elaborator
}
Expand All @@ -185,14 +187,25 @@ impl<'tcx, O: Elaboratable<'tcx>> Elaborator<'tcx, O> {
self.stack.extend(obligations.into_iter().filter(|o| self.visited.insert(o.predicate())));
}

/// Filter to only the supertraits of trait predicates, i.e. only the predicates
/// that have `Self` as their self type, instead of all implied predicates.
pub fn filter_only_self(mut self) -> Self {
self.only_self = true;
self
}

fn elaborate(&mut self, elaboratable: &O) {
let tcx = self.visited.tcx;

let bound_predicate = elaboratable.predicate().kind();
match bound_predicate.skip_binder() {
ty::PredicateKind::Clause(ty::Clause::Trait(data)) => {
// Get predicates declared on the trait.
let predicates = tcx.super_predicates_of(data.def_id());
// Get predicates implied by the trait, or only super predicates if we only care about self predicates.
let predicates = if self.only_self {
tcx.super_predicates_of(data.def_id())
} else {
tcx.implied_predicates_of(data.def_id())
};

let obligations =
predicates.predicates.iter().enumerate().map(|(index, &(mut pred, span))| {
Expand Down Expand Up @@ -350,18 +363,16 @@ pub fn supertraits<'tcx>(
tcx: TyCtxt<'tcx>,
trait_ref: ty::PolyTraitRef<'tcx>,
) -> impl Iterator<Item = ty::PolyTraitRef<'tcx>> {
let pred: ty::Predicate<'tcx> = trait_ref.to_predicate(tcx);
FilterToTraits::new(elaborate(tcx, [pred]))
elaborate(tcx, [trait_ref.to_predicate(tcx)]).filter_only_self().filter_to_traits()
}

pub fn transitive_bounds<'tcx>(
tcx: TyCtxt<'tcx>,
trait_refs: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
) -> impl Iterator<Item = ty::PolyTraitRef<'tcx>> {
FilterToTraits::new(elaborate(
tcx,
trait_refs.map(|trait_ref| -> ty::Predicate<'tcx> { trait_ref.to_predicate(tcx) }),
))
elaborate(tcx, trait_refs.map(|trait_ref| trait_ref.to_predicate(tcx)))
.filter_only_self()
.filter_to_traits()
}

/// A specialized variant of `elaborate` that only elaborates trait references that may
Expand All @@ -381,10 +392,8 @@ pub fn transitive_bounds_that_define_assoc_type<'tcx>(
while let Some(trait_ref) = stack.pop() {
let anon_trait_ref = tcx.anonymize_bound_vars(trait_ref);
if visited.insert(anon_trait_ref) {
let super_predicates = tcx.super_predicates_that_define_assoc_type((
trait_ref.def_id(),
Some(assoc_name),
));
let super_predicates =
tcx.super_predicates_that_define_assoc_type((trait_ref.def_id(), assoc_name));
for (super_predicate, _) in super_predicates.predicates {
let subst_predicate = super_predicate.subst_supertrait(tcx, &trait_ref);
if let Some(binder) = subst_predicate.to_opt_poly_trait_pred() {
Expand All @@ -404,18 +413,18 @@ pub fn transitive_bounds_that_define_assoc_type<'tcx>(
// Other
///////////////////////////////////////////////////////////////////////////

impl<'tcx> Elaborator<'tcx, ty::Predicate<'tcx>> {
fn filter_to_traits(self) -> FilterToTraits<Self> {
FilterToTraits { base_iterator: self }
}
}

/// A filter around an iterator of predicates that makes it yield up
/// just trait references.
pub struct FilterToTraits<I> {
base_iterator: I,
}

impl<I> FilterToTraits<I> {
fn new(base: I) -> FilterToTraits<I> {
FilterToTraits { base_iterator: base }
}
}

impl<'tcx, I: Iterator<Item = ty::Predicate<'tcx>>> Iterator for FilterToTraits<I> {
type Item = ty::PolyTraitRef<'tcx>;

Expand Down
Loading