Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Collect relevant item bounds from trait clauses for nested rigid projections #120752

Merged
merged 3 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 235 additions & 10 deletions compiler/rustc_hir_analysis/src/collect/item_bounds.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use rustc_data_structures::fx::FxIndexSet;
use rustc_data_structures::fx::{FxIndexMap, FxIndexSet};
use rustc_hir as hir;
use rustc_infer::traits::util;
use rustc_middle::ty::fold::shift_vars;
use rustc_middle::ty::{
self, GenericArgs, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable,
self, GenericArgs, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitableExt,
};
use rustc_middle::{bug, span_bug};
use rustc_span::Span;
Expand Down Expand Up @@ -42,14 +43,18 @@ fn associated_type_bounds<'tcx>(
let trait_def_id = tcx.local_parent(assoc_item_def_id);
let trait_predicates = tcx.trait_explicit_predicates_and_bounds(trait_def_id);

let bounds_from_parent = trait_predicates.predicates.iter().copied().filter(|(pred, _)| {
match pred.kind().skip_binder() {
ty::ClauseKind::Trait(tr) => tr.self_ty() == item_ty,
ty::ClauseKind::Projection(proj) => proj.projection_term.self_ty() == item_ty,
ty::ClauseKind::TypeOutlives(outlives) => outlives.0 == item_ty,
_ => false,
}
});
let item_trait_ref = ty::TraitRef::identity(tcx, tcx.parent(assoc_item_def_id.to_def_id()));
let bounds_from_parent =
trait_predicates.predicates.iter().copied().filter_map(|(clause, span)| {
remap_gat_vars_and_recurse_into_nested_projections(
tcx,
filter,
item_trait_ref,
assoc_item_def_id,
span,
clause,
)
});

let all_bounds = tcx.arena.alloc_from_iter(bounds.clauses(tcx).chain(bounds_from_parent));
debug!(
Expand All @@ -63,6 +68,226 @@ fn associated_type_bounds<'tcx>(
all_bounds
}

/// The code below is quite involved, so let me explain.
///
/// We loop here, because we also want to collect vars for nested associated items as
/// well. For example, given a clause like `Self::A::B`, we want to add that to the
/// item bounds for `A`, so that we may use that bound in the case that `Self::A::B` is
/// rigid.
///
/// Secondly, regarding bound vars, when we see a where clause that mentions a GAT
/// like `for<'a, ...> Self::Assoc<'a, ...>: Bound<'b, ...>`, we want to turn that into
/// an item bound on the GAT, where all of the GAT args are substituted with the GAT's
/// param regions, and then keep all of the other late-bound vars in the bound around.
/// We need to "compress" the binder so that it doesn't mention any of those vars that
/// were mapped to params.
fn remap_gat_vars_and_recurse_into_nested_projections<'tcx>(
tcx: TyCtxt<'tcx>,
filter: PredicateFilter,
item_trait_ref: ty::TraitRef<'tcx>,
assoc_item_def_id: LocalDefId,
span: Span,
clause: ty::Clause<'tcx>,
) -> Option<(ty::Clause<'tcx>, Span)> {
let mut clause_ty = match clause.kind().skip_binder() {
ty::ClauseKind::Trait(tr) => tr.self_ty(),
ty::ClauseKind::Projection(proj) => proj.projection_term.self_ty(),
ty::ClauseKind::TypeOutlives(outlives) => outlives.0,
_ => return None,
};

let gat_vars = loop {
if let ty::Alias(ty::Projection, alias_ty) = *clause_ty.kind() {
if alias_ty.trait_ref(tcx) == item_trait_ref
&& alias_ty.def_id == assoc_item_def_id.to_def_id()
{
// We have found the GAT in question...
// Return the vars, since we may need to remap them.
break &alias_ty.args[item_trait_ref.args.len()..];
} else {
// Only collect *self* type bounds if the filter is for self.
match filter {
PredicateFilter::SelfOnly | PredicateFilter::SelfThatDefines(_) => {
return None;
}
PredicateFilter::All | PredicateFilter::SelfAndAssociatedTypeBounds => {}
}

clause_ty = alias_ty.self_ty();
continue;
}
}

return None;
};

// Special-case: No GAT vars, no mapping needed.
if gat_vars.is_empty() {
return Some((clause, span));
}

// First, check that all of the GAT args are substituted with a unique late-bound arg.
// If we find a duplicate, then it can't be mapped to the definition's params.
let mut mapping = FxIndexMap::default();
let generics = tcx.generics_of(assoc_item_def_id);
for (param, var) in std::iter::zip(&generics.own_params, gat_vars) {
let existing = match var.unpack() {
ty::GenericArgKind::Lifetime(re) => {
if let ty::RegionKind::ReBound(ty::INNERMOST, bv) = re.kind() {
mapping.insert(bv.var, tcx.mk_param_from_def(param))
} else {
return None;
}
}
ty::GenericArgKind::Type(ty) => {
if let ty::Bound(ty::INNERMOST, bv) = *ty.kind() {
mapping.insert(bv.var, tcx.mk_param_from_def(param))
} else {
return None;
}
}
ty::GenericArgKind::Const(ct) => {
if let ty::ConstKind::Bound(ty::INNERMOST, bv) = ct.kind() {
mapping.insert(bv, tcx.mk_param_from_def(param))
} else {
return None;
}
}
};

if existing.is_some() {
return None;
}
}

// Finally, map all of the args in the GAT to the params we expect, and compress
// the remaining late-bound vars so that they count up from var 0.
let mut folder =
MapAndCompressBoundVars { tcx, binder: ty::INNERMOST, still_bound_vars: vec![], mapping };
let pred = clause.kind().skip_binder().fold_with(&mut folder);

Some((
ty::Binder::bind_with_vars(pred, tcx.mk_bound_variable_kinds(&folder.still_bound_vars))
.upcast(tcx),
span,
))
}

/// Given some where clause like `for<'b, 'c> <Self as Trait<'a_identity>>::Gat<'b>: Bound<'c>`,
/// the mapping will map `'b` back to the GAT's `'b_identity`. Then we need to compress the
/// remaining bound var `'c` to index 0.
///
/// This folder gives us: `for<'c> <Self as Trait<'a_identity>>::Gat<'b_identity>: Bound<'c>`,
/// which is sufficient for an item bound for `Gat`, since all of the GAT's args are identity.
struct MapAndCompressBoundVars<'tcx> {
tcx: TyCtxt<'tcx>,
compiler-errors marked this conversation as resolved.
Show resolved Hide resolved
/// How deep are we? Makes sure we don't touch the vars of nested binders.
binder: ty::DebruijnIndex,
/// List of bound vars that remain unsubstituted because they were not
/// mentioned in the GAT's args.
still_bound_vars: Vec<ty::BoundVariableKind>,
/// Subtle invariant: If the `GenericArg` is bound, then it should be
/// stored with the debruijn index of `INNERMOST` so it can be shifted
/// correctly during substitution.
mapping: FxIndexMap<ty::BoundVar, ty::GenericArg<'tcx>>,
}

impl<'tcx> TypeFolder<TyCtxt<'tcx>> for MapAndCompressBoundVars<'tcx> {
fn cx(&self) -> TyCtxt<'tcx> {
self.tcx
}

fn fold_binder<T>(&mut self, t: ty::Binder<'tcx, T>) -> ty::Binder<'tcx, T>
where
ty::Binder<'tcx, T>: TypeSuperFoldable<TyCtxt<'tcx>>,
{
self.binder.shift_in(1);
let out = t.super_fold_with(self);
self.binder.shift_out(1);
out
}

fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
if !ty.has_bound_vars() {
return ty;
}

if let ty::Bound(binder, old_bound) = *ty.kind()
&& self.binder == binder
{
let mapped = if let Some(mapped) = self.mapping.get(&old_bound.var) {
mapped.expect_ty()
} else {
// If we didn't find a mapped generic, then make a new one.
// Allocate a new var idx, and insert a new bound ty.
let var = ty::BoundVar::from_usize(self.still_bound_vars.len());
self.still_bound_vars.push(ty::BoundVariableKind::Ty(old_bound.kind));
let mapped = Ty::new_bound(self.tcx, ty::INNERMOST, ty::BoundTy {
var,
kind: old_bound.kind,
});
self.mapping.insert(old_bound.var, mapped.into());
mapped
};

shift_vars(self.tcx, mapped, self.binder.as_u32())
} else {
ty.super_fold_with(self)
}
}

fn fold_region(&mut self, re: ty::Region<'tcx>) -> ty::Region<'tcx> {
if let ty::ReBound(binder, old_bound) = re.kind()
&& self.binder == binder
{
let mapped = if let Some(mapped) = self.mapping.get(&old_bound.var) {
mapped.expect_region()
} else {
let var = ty::BoundVar::from_usize(self.still_bound_vars.len());
self.still_bound_vars.push(ty::BoundVariableKind::Region(old_bound.kind));
let mapped = ty::Region::new_bound(self.tcx, ty::INNERMOST, ty::BoundRegion {
var,
kind: old_bound.kind,
});
self.mapping.insert(old_bound.var, mapped.into());
mapped
};

shift_vars(self.tcx, mapped, self.binder.as_u32())
} else {
re
}
}

fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
if !ct.has_bound_vars() {
return ct;
}

if let ty::ConstKind::Bound(binder, old_var) = ct.kind()
&& self.binder == binder
{
let mapped = if let Some(mapped) = self.mapping.get(&old_var) {
mapped.expect_const()
} else {
let var = ty::BoundVar::from_usize(self.still_bound_vars.len());
self.still_bound_vars.push(ty::BoundVariableKind::Const);
let mapped = ty::Const::new_bound(self.tcx, ty::INNERMOST, var);
self.mapping.insert(old_var, mapped.into());
mapped
};

shift_vars(self.tcx, mapped, self.binder.as_u32())
} else {
ct.super_fold_with(self)
}
}

fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
if !p.has_bound_vars() { p } else { p.super_fold_with(self) }
}
}

/// Opaque types don't inherit bounds from their parent: for return position
/// impl trait it isn't possible to write a suitable predicate on the
/// containing function and for type-alias impl trait we don't have a backwards
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Demonstrates a mostly-theoretical inference guidance now that we turn the where
// clause on `Trait` into an item bound, given that we prefer item bounds somewhat
// greedily in trait selection.

trait Bound<T> {}
impl<T, U> Bound<T> for U {}

trait Trait
where
<<Self as Trait>::Assoc as Other>::Assoc: Bound<u32>,
{
type Assoc: Other;
}

trait Other {
type Assoc;
}

fn impls_trait<T: Bound<U>, U>() -> Vec<U> { vec![] }

fn foo<T: Trait>() {
let mut vec_u = impls_trait::<<<T as Trait>::Assoc as Other>::Assoc, _>();
vec_u.sort();
drop::<Vec<u8>>(vec_u);
//~^ ERROR mismatched types
}

fn main() {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
error[E0308]: mismatched types
--> $DIR/nested-associated-type-bound-incompleteness.rs:24:21
|
LL | drop::<Vec<u8>>(vec_u);
| --------------- ^^^^^ expected `Vec<u8>`, found `Vec<u32>`
| |
| arguments to this function are incorrect
|
= note: expected struct `Vec<u8>`
found struct `Vec<u32>`
note: function defined here
--> $SRC_DIR/core/src/mem/mod.rs:LL:COL

error: aborting due to 1 previous error

For more information about this error, try `rustc --explain E0308`.
31 changes: 31 additions & 0 deletions tests/ui/associated-type-bounds/nested-gat-projection.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//@ check-pass

trait Trait
where
for<'a> Self::Gat<'a>: OtherTrait,
for<'a, 'b, 'c> <Self::Gat<'a> as OtherTrait>::OtherGat<'b>: HigherRanked<'c>,
{
type Gat<'a>;
}

trait OtherTrait {
type OtherGat<'b>;
}

trait HigherRanked<'c> {}

fn lower_ranked<T: for<'b, 'c> OtherTrait<OtherGat<'b>: HigherRanked<'c>>>() {}

fn higher_ranked<T: Trait>()
where
for<'a> T::Gat<'a>: OtherTrait,
for<'a, 'b, 'c> <T::Gat<'a> as OtherTrait>::OtherGat<'b>: HigherRanked<'c>,
{
}

fn test<T: Trait>() {
lower_ranked::<T::Gat<'_>>();
higher_ranked::<T>();
}

fn main() {}
28 changes: 28 additions & 0 deletions tests/ui/associated-types/imply-relevant-nested-item-bounds-2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//@ check-pass
//@ revisions: current next
//@[next] compile-flags: -Znext-solver

trait Trait
where
Self::Assoc: Clone,
{
type Assoc;
}

fn foo<T: Trait>(x: &T::Assoc) -> T::Assoc {
x.clone()
}

trait Trait2
where
Self::Assoc: Iterator,
<Self::Assoc as Iterator>::Item: Clone,
{
type Assoc;
}

fn foo2<T: Trait2>(x: &<T::Assoc as Iterator>::Item) -> <T::Assoc as Iterator>::Item {
x.clone()
}

fn main() {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//@ check-pass

// Test that `for<'a> Self::Gat<'a>: Debug` is implied in the definition of `Foo`,
// just as it would be if it weren't a GAT but just a regular associated type.

use std::fmt::Debug;

trait Foo
where
for<'a> Self::Gat<'a>: Debug,
{
type Gat<'a>;
}

fn test<T: Foo>(x: T::Gat<'static>) {
println!("{:?}", x);
}

fn main() {}
Loading
Loading