Skip to content

Commit 4bed748

Browse files
committed
Suggest impl Trait return type
Address #85991 Suggest the `impl Trait` return type syntax if the user tried to return a generic parameter and we get a type mismatch The suggestion is not emitted if the param appears in the function parameters, and only get the bounds that actually involve `T: ` directly It also checks whether the generic param is contained in any where bound (where it isn't the self type), and if one is found (like `Option<T>: Send`), it is not suggested. This also adds `TyS::contains`, which recursively vistits the type and looks if the other type is contained anywhere
1 parent b8c56fa commit 4bed748

File tree

7 files changed

+321
-3
lines changed

7 files changed

+321
-3
lines changed

compiler/rustc_middle/src/ty/sty.rs

+22-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ use crate::infer::canonical::Canonical;
88
use crate::ty::fold::ValidateBoundVars;
99
use crate::ty::subst::{GenericArg, InternalSubsts, Subst, SubstsRef};
1010
use crate::ty::InferTy::{self, *};
11-
use crate::ty::{self, AdtDef, DefIdTree, Discr, Term, Ty, TyCtxt, TypeFlags, TypeFoldable};
11+
use crate::ty::{
12+
self, AdtDef, DefIdTree, Discr, Term, Ty, TyCtxt, TypeFlags, TypeFoldable, TypeVisitor,
13+
};
1214
use crate::ty::{DelaySpanBugEmitted, List, ParamEnv};
1315
use polonius_engine::Atom;
1416
use rustc_data_structures::captures::Captures;
@@ -24,7 +26,7 @@ use std::borrow::Cow;
2426
use std::cmp::Ordering;
2527
use std::fmt;
2628
use std::marker::PhantomData;
27-
use std::ops::{Deref, Range};
29+
use std::ops::{ControlFlow, Deref, Range};
2830
use ty::util::IntTypeExt;
2931

3032
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, TyEncodable, TyDecodable)]
@@ -2072,6 +2074,24 @@ impl<'tcx> Ty<'tcx> {
20722074
!matches!(self.kind(), Param(_) | Infer(_) | Error(_))
20732075
}
20742076

2077+
/// Checks whether a type recursively contains another type
2078+
///
2079+
/// Example: `Option<()>` contains `()`
2080+
pub fn contains(self, other: Ty<'tcx>) -> bool {
2081+
struct ContainsTyVisitor<'tcx>(Ty<'tcx>);
2082+
2083+
impl<'tcx> TypeVisitor<'tcx> for ContainsTyVisitor<'tcx> {
2084+
type BreakTy = ();
2085+
2086+
fn visit_ty(&mut self, t: Ty<'tcx>) -> ControlFlow<Self::BreakTy> {
2087+
if self.0 == t { ControlFlow::BREAK } else { t.super_visit_with(self) }
2088+
}
2089+
}
2090+
2091+
let cf = self.visit_with(&mut ContainsTyVisitor(other));
2092+
cf.is_break()
2093+
}
2094+
20752095
/// Returns the type and mutability of `*ty`.
20762096
///
20772097
/// The parameter `explicit` indicates if this is an *explicit* dereference.

compiler/rustc_typeck/src/check/fn_ctxt/suggestions.rs

+115-1
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@ use rustc_errors::{Applicability, DiagnosticBuilder};
88
use rustc_hir as hir;
99
use rustc_hir::def::{CtorOf, DefKind};
1010
use rustc_hir::lang_items::LangItem;
11-
use rustc_hir::{Expr, ExprKind, ItemKind, Node, Path, QPath, Stmt, StmtKind, TyKind};
11+
use rustc_hir::{
12+
Expr, ExprKind, GenericBound, ItemKind, Node, Path, QPath, Stmt, StmtKind, TyKind,
13+
WherePredicate,
14+
};
1215
use rustc_infer::infer::{self, TyCtxtInferExt};
16+
1317
use rustc_middle::lint::in_external_macro;
1418
use rustc_middle::ty::{self, Binder, Ty};
1519
use rustc_span::symbol::{kw, sym};
@@ -559,13 +563,123 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
559563
let ty = self.tcx.erase_late_bound_regions(ty);
560564
if self.can_coerce(expected, ty) {
561565
err.span_label(sp, format!("expected `{}` because of return type", expected));
566+
self.try_suggest_return_impl_trait(err, expected, ty, fn_id);
562567
return true;
563568
}
564569
false
565570
}
566571
}
567572
}
568573

574+
/// check whether the return type is a generic type with a trait bound
575+
/// only suggest this if the generic param is not present in the arguments
576+
/// if this is true, hint them towards changing the return type to `impl Trait`
577+
/// ```
578+
/// fn cant_name_it<T: Fn() -> u32>() -> T {
579+
/// || 3
580+
/// }
581+
/// ```
582+
fn try_suggest_return_impl_trait(
583+
&self,
584+
err: &mut DiagnosticBuilder<'_>,
585+
expected: Ty<'tcx>,
586+
found: Ty<'tcx>,
587+
fn_id: hir::HirId,
588+
) {
589+
// Only apply the suggestion if:
590+
// - the return type is a generic parameter
591+
// - the generic param is not used as a fn param
592+
// - the generic param has at least one bound
593+
// - the generic param doesn't appear in any other bounds where it's not the Self type
594+
// Suggest:
595+
// - Changing the return type to be `impl <all bounds>`
596+
597+
debug!("try_suggest_return_impl_trait, expected = {:?}, found = {:?}", expected, found);
598+
599+
let ty::Param(expected_ty_as_param) = expected.kind() else { return };
600+
601+
let fn_node = self.tcx.hir().find(fn_id);
602+
603+
let Some(hir::Node::Item(hir::Item {
604+
kind:
605+
hir::ItemKind::Fn(
606+
hir::FnSig { decl: hir::FnDecl { inputs: fn_parameters, output: fn_return, .. }, .. },
607+
hir::Generics { params, where_clause, .. },
608+
_body_id,
609+
),
610+
..
611+
})) = fn_node else { return };
612+
613+
let Some(expected_generic_param) = params.get(expected_ty_as_param.index as usize) else { return };
614+
615+
// get all where BoundPredicates here, because they are used in to cases below
616+
let where_predicates = where_clause
617+
.predicates
618+
.iter()
619+
.filter_map(|p| match p {
620+
WherePredicate::BoundPredicate(hir::WhereBoundPredicate {
621+
bounds,
622+
bounded_ty,
623+
..
624+
}) => {
625+
// FIXME: Maybe these calls to `ast_ty_to_ty` can be removed (and the ones below)
626+
let ty = <dyn AstConv<'_>>::ast_ty_to_ty(self, bounded_ty);
627+
Some((ty, bounds))
628+
}
629+
_ => None,
630+
})
631+
.map(|(ty, bounds)| match ty.kind() {
632+
ty::Param(param_ty) if param_ty == expected_ty_as_param => Ok(Some(bounds)),
633+
// check whether there is any predicate that contains our `T`, like `Option<T>: Send`
634+
_ => match ty.contains(expected) {
635+
true => Err(()),
636+
false => Ok(None),
637+
},
638+
})
639+
.collect::<Result<Vec<_>, _>>();
640+
641+
let Ok(where_predicates) = where_predicates else { return };
642+
643+
// now get all predicates in the same types as the where bounds, so we can chain them
644+
let predicates_from_where =
645+
where_predicates.iter().flatten().map(|bounds| bounds.iter()).flatten();
646+
647+
// extract all bounds from the source code using their spans
648+
let all_matching_bounds_strs = expected_generic_param
649+
.bounds
650+
.iter()
651+
.chain(predicates_from_where)
652+
.filter_map(|bound| match bound {
653+
GenericBound::Trait(_, _) => {
654+
self.tcx.sess.source_map().span_to_snippet(bound.span()).ok()
655+
}
656+
_ => None,
657+
})
658+
.collect::<Vec<String>>();
659+
660+
if all_matching_bounds_strs.len() == 0 {
661+
return;
662+
}
663+
664+
let all_bounds_str = all_matching_bounds_strs.join(" + ");
665+
666+
let ty_param_used_in_fn_params = fn_parameters.iter().any(|param| {
667+
let ty = <dyn AstConv<'_>>::ast_ty_to_ty(self, param);
668+
matches!(ty.kind(), ty::Param(fn_param_ty_param) if expected_ty_as_param == fn_param_ty_param)
669+
});
670+
671+
if ty_param_used_in_fn_params {
672+
return;
673+
}
674+
675+
err.span_suggestion(
676+
fn_return.span(),
677+
"consider using an impl return type",
678+
format!("impl {}", all_bounds_str),
679+
Applicability::MaybeIncorrect,
680+
);
681+
}
682+
569683
pub(in super::super) fn suggest_missing_break_or_return_expr(
570684
&self,
571685
err: &mut DiagnosticBuilder<'_>,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
trait Trait {}
2+
impl Trait for () {}
3+
4+
fn bad_echo<T>(_t: T) -> T {
5+
"this should not suggest impl Trait" //~ ERROR mismatched types
6+
}
7+
8+
fn bad_echo_2<T: Trait>(_t: T) -> T {
9+
"this will not suggest it, because that would probably be wrong" //~ ERROR mismatched types
10+
}
11+
12+
fn other_bounds_bad<T>() -> T
13+
where
14+
T: Send,
15+
Option<T>: Send,
16+
{
17+
"don't suggest this, because Option<T> places additional constraints" //~ ERROR mismatched types
18+
}
19+
20+
// FIXME: implement this check
21+
trait GenericTrait<T> {}
22+
23+
fn used_in_trait<T>() -> T
24+
where
25+
T: Send,
26+
(): GenericTrait<T>,
27+
{
28+
"don't suggest this, because the generic param is used in the bound." //~ ERROR mismatched types
29+
}
30+
31+
fn main() {}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
error[E0308]: mismatched types
2+
--> $DIR/return-impl-trait-bad.rs:5:5
3+
|
4+
LL | fn bad_echo<T>(_t: T) -> T {
5+
| - - expected `T` because of return type
6+
| |
7+
| this type parameter
8+
LL | "this should not suggest impl Trait"
9+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected type parameter `T`, found `&str`
10+
|
11+
= note: expected type parameter `T`
12+
found reference `&'static str`
13+
14+
error[E0308]: mismatched types
15+
--> $DIR/return-impl-trait-bad.rs:9:5
16+
|
17+
LL | fn bad_echo_2<T: Trait>(_t: T) -> T {
18+
| - - expected `T` because of return type
19+
| |
20+
| this type parameter
21+
LL | "this will not suggest it, because that would probably be wrong"
22+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected type parameter `T`, found `&str`
23+
|
24+
= note: expected type parameter `T`
25+
found reference `&'static str`
26+
27+
error[E0308]: mismatched types
28+
--> $DIR/return-impl-trait-bad.rs:17:5
29+
|
30+
LL | fn other_bounds_bad<T>() -> T
31+
| - - expected `T` because of return type
32+
| |
33+
| this type parameter
34+
...
35+
LL | "don't suggest this, because Option<T> places additional constraints"
36+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected type parameter `T`, found `&str`
37+
|
38+
= note: expected type parameter `T`
39+
found reference `&'static str`
40+
41+
error[E0308]: mismatched types
42+
--> $DIR/return-impl-trait-bad.rs:28:5
43+
|
44+
LL | fn used_in_trait<T>() -> T
45+
| - -
46+
| | |
47+
| | expected `T` because of return type
48+
| | help: consider using an impl return type: `impl Send`
49+
| this type parameter
50+
...
51+
LL | "don't suggest this, because the generic param is used in the bound."
52+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected type parameter `T`, found `&str`
53+
|
54+
= note: expected type parameter `T`
55+
found reference `&'static str`
56+
57+
error: aborting due to 4 previous errors
58+
59+
For more information about this error, try `rustc --explain E0308`.
+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// run-rustfix
2+
3+
trait Trait {}
4+
impl Trait for () {}
5+
6+
// this works
7+
fn foo() -> impl Trait {
8+
()
9+
}
10+
11+
fn bar<T: Trait + std::marker::Sync>() -> impl Trait + std::marker::Sync + Send
12+
where
13+
T: Send,
14+
{
15+
() //~ ERROR mismatched types
16+
}
17+
18+
fn other_bounds<T>() -> impl Trait
19+
where
20+
T: Trait,
21+
Vec<usize>: Clone,
22+
{
23+
() //~ ERROR mismatched types
24+
}
25+
26+
fn main() {
27+
foo();
28+
bar::<()>();
29+
other_bounds::<()>();
30+
}
+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// run-rustfix
2+
3+
trait Trait {}
4+
impl Trait for () {}
5+
6+
// this works
7+
fn foo() -> impl Trait {
8+
()
9+
}
10+
11+
fn bar<T: Trait + std::marker::Sync>() -> T
12+
where
13+
T: Send,
14+
{
15+
() //~ ERROR mismatched types
16+
}
17+
18+
fn other_bounds<T>() -> T
19+
where
20+
T: Trait,
21+
Vec<usize>: Clone,
22+
{
23+
() //~ ERROR mismatched types
24+
}
25+
26+
fn main() {
27+
foo();
28+
bar::<()>();
29+
other_bounds::<()>();
30+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
error[E0308]: mismatched types
2+
--> $DIR/return-impl-trait.rs:15:5
3+
|
4+
LL | fn bar<T: Trait + std::marker::Sync>() -> T
5+
| - -
6+
| | |
7+
| | expected `T` because of return type
8+
| this type parameter help: consider using an impl return type: `impl Trait + std::marker::Sync + Send`
9+
...
10+
LL | ()
11+
| ^^ expected type parameter `T`, found `()`
12+
|
13+
= note: expected type parameter `T`
14+
found unit type `()`
15+
16+
error[E0308]: mismatched types
17+
--> $DIR/return-impl-trait.rs:23:5
18+
|
19+
LL | fn other_bounds<T>() -> T
20+
| - -
21+
| | |
22+
| | expected `T` because of return type
23+
| | help: consider using an impl return type: `impl Trait`
24+
| this type parameter
25+
...
26+
LL | ()
27+
| ^^ expected type parameter `T`, found `()`
28+
|
29+
= note: expected type parameter `T`
30+
found unit type `()`
31+
32+
error: aborting due to 2 previous errors
33+
34+
For more information about this error, try `rustc --explain E0308`.

0 commit comments

Comments
 (0)