Skip to content

Commit 0dddad0

Browse files
committed
Auto merge of #111161 - compiler-errors:rtn-super, r=cjgillot
Support return-type bounds on associated methods from supertraits Support `T: Trait<method(): Bound>` when `method` comes from a supertrait, aligning it with the behavior of associated type bounds (both equality and trait bounds). The only wrinkle is that I have to extend `super_predicates_that_define_assoc_type` to look for *all* items, not just `AssocKind::Ty`. This will also be needed to support `feature(associated_const_equality)` as well, which is subtly broken when it comes to supertraits, though this PR does not fix those yet. There's a slight chance there's a perf regression here, in which case I guess I could split it out into a separate query.
2 parents 8660707 + 76802e3 commit 0dddad0

File tree

17 files changed

+210
-50
lines changed

17 files changed

+210
-50
lines changed

compiler/rustc_hir_analysis/messages.ftl

+5-1
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,11 @@ hir_analysis_return_type_notation_equality_bound =
192192
return type notation is not allowed to use type equality
193193
194194
hir_analysis_return_type_notation_missing_method =
195-
cannot find associated function `{$assoc_name}` in trait `{$trait_name}`
195+
cannot find associated function `{$assoc_name}` for `{$ty_name}`
196+
197+
hir_analysis_return_type_notation_conflicting_bound =
198+
ambiguous associated function `{$assoc_name}` for `{$ty_name}`
199+
.note = `{$assoc_name}` is declared in two supertraits: `{$first_bound}` and `{$second_bound}`
196200
197201
hir_analysis_placeholder_not_allowed_item_signatures = the placeholder `_` is not allowed within types on item signatures for {$kind}
198202
.label = not allowed in type signatures

compiler/rustc_hir_analysis/src/astconv/mod.rs

+49-8
Original file line numberDiff line numberDiff line change
@@ -1062,7 +1062,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
10621062

10631063
/// Convert the bounds in `ast_bounds` that refer to traits which define an associated type
10641064
/// named `assoc_name` into ty::Bounds. Ignore the rest.
1065-
pub(crate) fn compute_bounds_that_match_assoc_type(
1065+
pub(crate) fn compute_bounds_that_match_assoc_item(
10661066
&self,
10671067
param_ty: Ty<'tcx>,
10681068
ast_bounds: &[hir::GenericBound<'_>],
@@ -1073,7 +1073,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
10731073
for ast_bound in ast_bounds {
10741074
if let Some(trait_ref) = ast_bound.trait_ref()
10751075
&& let Some(trait_did) = trait_ref.trait_def_id()
1076-
&& self.tcx().trait_may_define_assoc_type(trait_did, assoc_name)
1076+
&& self.tcx().trait_may_define_assoc_item(trait_did, assoc_name)
10771077
{
10781078
result.push(ast_bound.clone());
10791079
}
@@ -1141,11 +1141,12 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
11411141
) {
11421142
trait_ref
11431143
} else {
1144-
return Err(tcx.sess.emit_err(crate::errors::ReturnTypeNotationMissingMethod {
1145-
span: binding.span,
1146-
trait_name: tcx.item_name(trait_ref.def_id()),
1147-
assoc_name: binding.item_name.name,
1148-
}));
1144+
self.one_bound_for_assoc_method(
1145+
traits::supertraits(tcx, trait_ref),
1146+
trait_ref.print_only_trait_path(),
1147+
binding.item_name,
1148+
path_span,
1149+
)?
11491150
}
11501151
} else if self.trait_defines_associated_item_named(
11511152
trait_ref.def_id(),
@@ -1946,7 +1947,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
19461947
let param_name = tcx.hir().ty_param_name(ty_param_def_id);
19471948
self.one_bound_for_assoc_type(
19481949
|| {
1949-
traits::transitive_bounds_that_define_assoc_type(
1950+
traits::transitive_bounds_that_define_assoc_item(
19501951
tcx,
19511952
predicates.iter().filter_map(|(p, _)| {
19521953
Some(p.to_opt_poly_trait_pred()?.map_bound(|t| t.trait_ref))
@@ -2081,6 +2082,46 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
20812082
Ok(bound)
20822083
}
20832084

2085+
#[instrument(level = "debug", skip(self, all_candidates, ty_name), ret)]
2086+
fn one_bound_for_assoc_method(
2087+
&self,
2088+
all_candidates: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
2089+
ty_name: impl Display,
2090+
assoc_name: Ident,
2091+
span: Span,
2092+
) -> Result<ty::PolyTraitRef<'tcx>, ErrorGuaranteed> {
2093+
let mut matching_candidates = all_candidates.filter(|r| {
2094+
self.trait_defines_associated_item_named(r.def_id(), ty::AssocKind::Fn, assoc_name)
2095+
});
2096+
2097+
let candidate = match matching_candidates.next() {
2098+
Some(candidate) => candidate,
2099+
None => {
2100+
return Err(self.tcx().sess.emit_err(
2101+
crate::errors::ReturnTypeNotationMissingMethod {
2102+
span,
2103+
ty_name: ty_name.to_string(),
2104+
assoc_name: assoc_name.name,
2105+
},
2106+
));
2107+
}
2108+
};
2109+
2110+
if let Some(conflicting_candidate) = matching_candidates.next() {
2111+
return Err(self.tcx().sess.emit_err(
2112+
crate::errors::ReturnTypeNotationConflictingBound {
2113+
span,
2114+
ty_name: ty_name.to_string(),
2115+
assoc_name: assoc_name.name,
2116+
first_bound: candidate.print_only_trait_path(),
2117+
second_bound: conflicting_candidate.print_only_trait_path(),
2118+
},
2119+
));
2120+
}
2121+
2122+
Ok(candidate)
2123+
}
2124+
20842125
// Create a type from a path to an associated type or to an enum variant.
20852126
// For a path `A::B::C::D`, `qself_ty` and `qself_def` are the type and def for `A::B::C`
20862127
// and item_segment is the path segment for `D`. We return a type and a def for

compiler/rustc_hir_analysis/src/collect.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ pub fn provide(providers: &mut Providers) {
6464
explicit_predicates_of: predicates_of::explicit_predicates_of,
6565
super_predicates_of: predicates_of::super_predicates_of,
6666
implied_predicates_of: predicates_of::implied_predicates_of,
67-
super_predicates_that_define_assoc_type:
68-
predicates_of::super_predicates_that_define_assoc_type,
67+
super_predicates_that_define_assoc_item:
68+
predicates_of::super_predicates_that_define_assoc_item,
6969
trait_explicit_predicates_and_bounds: predicates_of::trait_explicit_predicates_and_bounds,
7070
type_param_predicates: predicates_of::type_param_predicates,
7171
trait_def,

compiler/rustc_hir_analysis/src/collect/predicates_of.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ pub(super) fn super_predicates_of(
565565
implied_predicates_with_filter(tcx, trait_def_id.to_def_id(), PredicateFilter::SelfOnly)
566566
}
567567

568-
pub(super) fn super_predicates_that_define_assoc_type(
568+
pub(super) fn super_predicates_that_define_assoc_item(
569569
tcx: TyCtxt<'_>,
570570
(trait_def_id, assoc_name): (DefId, Ident),
571571
) -> ty::GenericPredicates<'_> {
@@ -640,7 +640,7 @@ pub(super) fn implied_predicates_with_filter(
640640
),
641641
PredicateFilter::SelfThatDefines(assoc_name) => (
642642
// Convert the bounds that follow the colon (or equal) that reference the associated name
643-
icx.astconv().compute_bounds_that_match_assoc_type(self_param_ty, bounds, assoc_name),
643+
icx.astconv().compute_bounds_that_match_assoc_item(self_param_ty, bounds, assoc_name),
644644
// Include where clause bounds for `Self` that reference the associated name
645645
icx.type_parameter_bounds_in_generics(
646646
generics,
@@ -819,7 +819,7 @@ impl<'tcx> ItemCtxt<'tcx> {
819819
hir::GenericBound::Trait(poly_trait_ref, _) => {
820820
let trait_ref = &poly_trait_ref.trait_ref;
821821
if let Some(trait_did) = trait_ref.trait_def_id() {
822-
self.tcx.trait_may_define_assoc_type(trait_did, assoc_name)
822+
self.tcx.trait_may_define_assoc_item(trait_did, assoc_name)
823823
} else {
824824
false
825825
}

compiler/rustc_hir_analysis/src/collect/resolve_bound_vars.rs

+33-23
Original file line numberDiff line numberDiff line change
@@ -1652,27 +1652,28 @@ impl<'a, 'tcx> BoundVarContext<'a, 'tcx> {
16521652
if binding.gen_args.parenthesized == hir::GenericArgsParentheses::ReturnTypeNotation {
16531653
let bound_vars = if let Some(type_def_id) = type_def_id
16541654
&& self.tcx.def_kind(type_def_id) == DefKind::Trait
1655-
// FIXME(return_type_notation): We could bound supertrait methods.
1656-
&& let Some(assoc_fn) = self
1657-
.tcx
1658-
.associated_items(type_def_id)
1659-
.find_by_name_and_kind(self.tcx, binding.ident, ty::AssocKind::Fn, type_def_id)
1655+
&& let Some((mut bound_vars, assoc_fn)) =
1656+
BoundVarContext::supertrait_hrtb_vars(
1657+
self.tcx,
1658+
type_def_id,
1659+
binding.ident,
1660+
ty::AssocKind::Fn,
1661+
)
16601662
{
1661-
self.tcx
1662-
.generics_of(assoc_fn.def_id)
1663-
.params
1664-
.iter()
1665-
.map(|param| match param.kind {
1663+
bound_vars.extend(self.tcx.generics_of(assoc_fn.def_id).params.iter().map(
1664+
|param| match param.kind {
16661665
ty::GenericParamDefKind::Lifetime => ty::BoundVariableKind::Region(
16671666
ty::BoundRegionKind::BrNamed(param.def_id, param.name),
16681667
),
16691668
ty::GenericParamDefKind::Type { .. } => ty::BoundVariableKind::Ty(
16701669
ty::BoundTyKind::Param(param.def_id, param.name),
16711670
),
16721671
ty::GenericParamDefKind::Const { .. } => ty::BoundVariableKind::Const,
1673-
})
1674-
.chain(self.tcx.fn_sig(assoc_fn.def_id).subst_identity().bound_vars())
1675-
.collect()
1672+
},
1673+
));
1674+
bound_vars
1675+
.extend(self.tcx.fn_sig(assoc_fn.def_id).subst_identity().bound_vars());
1676+
bound_vars
16761677
} else {
16771678
self.tcx.sess.delay_span_bug(
16781679
binding.ident.span,
@@ -1689,8 +1690,13 @@ impl<'a, 'tcx> BoundVarContext<'a, 'tcx> {
16891690
});
16901691
});
16911692
} else if let Some(type_def_id) = type_def_id {
1692-
let bound_vars =
1693-
BoundVarContext::supertrait_hrtb_vars(self.tcx, type_def_id, binding.ident);
1693+
let bound_vars = BoundVarContext::supertrait_hrtb_vars(
1694+
self.tcx,
1695+
type_def_id,
1696+
binding.ident,
1697+
ty::AssocKind::Type,
1698+
)
1699+
.map(|(bound_vars, _)| bound_vars);
16941700
self.with(scope, |this| {
16951701
let scope = Scope::Supertrait {
16961702
bound_vars: bound_vars.unwrap_or_default(),
@@ -1720,11 +1726,15 @@ impl<'a, 'tcx> BoundVarContext<'a, 'tcx> {
17201726
tcx: TyCtxt<'tcx>,
17211727
def_id: DefId,
17221728
assoc_name: Ident,
1723-
) -> Option<Vec<ty::BoundVariableKind>> {
1724-
let trait_defines_associated_type_named = |trait_def_id: DefId| {
1725-
tcx.associated_items(trait_def_id)
1726-
.find_by_name_and_kind(tcx, assoc_name, ty::AssocKind::Type, trait_def_id)
1727-
.is_some()
1729+
assoc_kind: ty::AssocKind,
1730+
) -> Option<(Vec<ty::BoundVariableKind>, &'tcx ty::AssocItem)> {
1731+
let trait_defines_associated_item_named = |trait_def_id: DefId| {
1732+
tcx.associated_items(trait_def_id).find_by_name_and_kind(
1733+
tcx,
1734+
assoc_name,
1735+
assoc_kind,
1736+
trait_def_id,
1737+
)
17281738
};
17291739

17301740
use smallvec::{smallvec, SmallVec};
@@ -1742,10 +1752,10 @@ impl<'a, 'tcx> BoundVarContext<'a, 'tcx> {
17421752
_ => break None,
17431753
}
17441754

1745-
if trait_defines_associated_type_named(def_id) {
1746-
break Some(bound_vars.into_iter().collect());
1755+
if let Some(assoc_item) = trait_defines_associated_item_named(def_id) {
1756+
break Some((bound_vars.into_iter().collect(), assoc_item));
17471757
}
1748-
let predicates = tcx.super_predicates_that_define_assoc_type((def_id, assoc_name));
1758+
let predicates = tcx.super_predicates_that_define_assoc_item((def_id, assoc_name));
17491759
let obligations = predicates.predicates.iter().filter_map(|&(pred, _)| {
17501760
let bound_predicate = pred.kind();
17511761
match bound_predicate.skip_binder() {

compiler/rustc_hir_analysis/src/errors.rs

+14-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use rustc_errors::{
66
MultiSpan,
77
};
88
use rustc_macros::{Diagnostic, Subdiagnostic};
9-
use rustc_middle::ty::Ty;
9+
use rustc_middle::ty::{self, print::TraitRefPrintOnlyTraitPath, Ty};
1010
use rustc_span::{symbol::Ident, Span, Symbol};
1111

1212
#[derive(Diagnostic)]
@@ -512,10 +512,22 @@ pub(crate) struct ReturnTypeNotationEqualityBound {
512512
pub(crate) struct ReturnTypeNotationMissingMethod {
513513
#[primary_span]
514514
pub span: Span,
515-
pub trait_name: Symbol,
515+
pub ty_name: String,
516516
pub assoc_name: Symbol,
517517
}
518518

519+
#[derive(Diagnostic)]
520+
#[diag(hir_analysis_return_type_notation_conflicting_bound)]
521+
#[note]
522+
pub(crate) struct ReturnTypeNotationConflictingBound<'tcx> {
523+
#[primary_span]
524+
pub span: Span,
525+
pub ty_name: String,
526+
pub assoc_name: Symbol,
527+
pub first_bound: ty::Binder<'tcx, TraitRefPrintOnlyTraitPath<'tcx>>,
528+
pub second_bound: ty::Binder<'tcx, TraitRefPrintOnlyTraitPath<'tcx>>,
529+
}
530+
519531
#[derive(Diagnostic)]
520532
#[diag(hir_analysis_placeholder_not_allowed_item_signatures, code = "E0121")]
521533
pub(crate) struct PlaceholderNotAllowedItemSignatures {

compiler/rustc_infer/src/traits/util.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -380,11 +380,11 @@ pub fn transitive_bounds<'tcx>(
380380
}
381381

382382
/// A specialized variant of `elaborate` that only elaborates trait references that may
383-
/// define the given associated type `assoc_name`. It uses the
384-
/// `super_predicates_that_define_assoc_type` query to avoid enumerating super-predicates that
383+
/// define the given associated item with the name `assoc_name`. It uses the
384+
/// `super_predicates_that_define_assoc_item` query to avoid enumerating super-predicates that
385385
/// aren't related to `assoc_item`. This is used when resolving types like `Self::Item` or
386386
/// `T::Item` and helps to avoid cycle errors (see e.g. #35237).
387-
pub fn transitive_bounds_that_define_assoc_type<'tcx>(
387+
pub fn transitive_bounds_that_define_assoc_item<'tcx>(
388388
tcx: TyCtxt<'tcx>,
389389
bounds: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
390390
assoc_name: Ident,
@@ -397,7 +397,7 @@ pub fn transitive_bounds_that_define_assoc_type<'tcx>(
397397
let anon_trait_ref = tcx.anonymize_bound_vars(trait_ref);
398398
if visited.insert(anon_trait_ref) {
399399
let super_predicates =
400-
tcx.super_predicates_that_define_assoc_type((trait_ref.def_id(), assoc_name));
400+
tcx.super_predicates_that_define_assoc_item((trait_ref.def_id(), assoc_name));
401401
for (super_predicate, _) in super_predicates.predicates {
402402
let subst_predicate = super_predicate.subst_supertrait(tcx, &trait_ref);
403403
if let Some(binder) = subst_predicate.to_opt_poly_trait_pred() {

compiler/rustc_middle/src/query/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ rustc_queries! {
569569
/// returns the full set of predicates. If `Some<Ident>`, then the query returns only the
570570
/// subset of super-predicates that reference traits that define the given associated type.
571571
/// This is used to avoid cycles in resolving types like `T::Item`.
572-
query super_predicates_that_define_assoc_type(key: (DefId, rustc_span::symbol::Ident)) -> ty::GenericPredicates<'tcx> {
572+
query super_predicates_that_define_assoc_item(key: (DefId, rustc_span::symbol::Ident)) -> ty::GenericPredicates<'tcx> {
573573
desc { |tcx| "computing the super traits of `{}` with associated type name `{}`",
574574
tcx.def_path_str(key.0),
575575
key.1

compiler/rustc_middle/src/ty/context.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -1567,11 +1567,11 @@ impl<'tcx> TyCtxt<'tcx> {
15671567

15681568
/// Given the def_id of a Trait `trait_def_id` and the name of an associated item `assoc_name`
15691569
/// returns true if the `trait_def_id` defines an associated item of name `assoc_name`.
1570-
pub fn trait_may_define_assoc_type(self, trait_def_id: DefId, assoc_name: Ident) -> bool {
1570+
pub fn trait_may_define_assoc_item(self, trait_def_id: DefId, assoc_name: Ident) -> bool {
15711571
self.super_traits_of(trait_def_id).any(|trait_did| {
15721572
self.associated_items(trait_did)
1573-
.find_by_name_and_kind(self, assoc_name, ty::AssocKind::Type, trait_did)
1574-
.is_some()
1573+
.filter_by_name_unhygienic(assoc_name.name)
1574+
.any(|item| self.hygienic_eq(assoc_name, item.ident(self), trait_did))
15751575
})
15761576
}
15771577

compiler/rustc_middle/src/ty/print/pretty.rs

+6
Original file line numberDiff line numberDiff line change
@@ -2633,6 +2633,12 @@ macro_rules! define_print_and_forward_display {
26332633
#[derive(Copy, Clone, TypeFoldable, TypeVisitable, Lift)]
26342634
pub struct TraitRefPrintOnlyTraitPath<'tcx>(ty::TraitRef<'tcx>);
26352635

2636+
impl<'tcx> rustc_errors::IntoDiagnosticArg for TraitRefPrintOnlyTraitPath<'tcx> {
2637+
fn into_diagnostic_arg(self) -> rustc_errors::DiagnosticArgValue<'static> {
2638+
self.to_string().into_diagnostic_arg()
2639+
}
2640+
}
2641+
26362642
impl<'tcx> fmt::Debug for TraitRefPrintOnlyTraitPath<'tcx> {
26372643
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26382644
fmt::Display::fmt(self, f)

compiler/rustc_trait_selection/src/traits/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ pub use self::util::elaborate;
6262
pub use self::util::{expand_trait_aliases, TraitAliasExpander};
6363
pub use self::util::{get_vtable_index_of_object_method, impl_item_is_final, upcast_choices};
6464
pub use self::util::{
65-
supertrait_def_ids, supertraits, transitive_bounds, transitive_bounds_that_define_assoc_type,
65+
supertrait_def_ids, supertraits, transitive_bounds, transitive_bounds_that_define_assoc_item,
6666
SupertraitDefIds,
6767
};
6868

tests/ui/associated-type-bounds/return-type-notation/missing.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@ trait Trait {
88
}
99

1010
fn bar<T: Trait<methid(): Send>>() {}
11-
//~^ ERROR cannot find associated function `methid` in trait `Trait`
11+
//~^ ERROR cannot find associated function `methid` for `Trait`
1212

1313
fn main() {}

tests/ui/associated-type-bounds/return-type-notation/missing.stderr

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ LL | #![feature(return_type_notation, async_fn_in_trait)]
77
= note: see issue #109417 <https://github.com/rust-lang/rust/issues/109417> for more information
88
= note: `#[warn(incomplete_features)]` on by default
99

10-
error: cannot find associated function `methid` in trait `Trait`
10+
error: cannot find associated function `methid` for `Trait`
1111
--> $DIR/missing.rs:10:17
1212
|
1313
LL | fn bar<T: Trait<methid(): Send>>() {}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// edition:2021
2+
3+
#![feature(async_fn_in_trait, return_type_notation)]
4+
//~^ WARN the feature `return_type_notation` is incomplete
5+
6+
trait Super1<'a> {
7+
async fn test();
8+
}
9+
impl Super1<'_> for () {
10+
async fn test() {}
11+
}
12+
13+
trait Super2 {
14+
async fn test();
15+
}
16+
impl Super2 for () {
17+
async fn test() {}
18+
}
19+
20+
trait Foo: for<'a> Super1<'a> + Super2 {}
21+
impl Foo for () {}
22+
23+
fn test<T>()
24+
where
25+
T: Foo<test(): Send>,
26+
//~^ ERROR ambiguous associated function `test` for `Foo`
27+
{
28+
}
29+
30+
fn main() {
31+
test::<()>();
32+
}

0 commit comments

Comments
 (0)