Skip to content

Commit d25abdc

Browse files
Point out custom Fn-family trait impl
1 parent ddb7003 commit d25abdc

File tree

3 files changed

+114
-53
lines changed

3 files changed

+114
-53
lines changed

Diff for: compiler/rustc_middle/src/ty/closure.rs

+8
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,14 @@ impl<'tcx> ClosureKind {
128128
None
129129
}
130130
}
131+
132+
pub fn to_def_id(&self, tcx: TyCtxt<'_>) -> DefId {
133+
match self {
134+
ClosureKind::Fn => tcx.lang_items().fn_once_trait().unwrap(),
135+
ClosureKind::FnMut => tcx.lang_items().fn_mut_trait().unwrap(),
136+
ClosureKind::FnOnce => tcx.lang_items().fn_trait().unwrap(),
137+
}
138+
}
131139
}
132140

133141
/// A composite describing a `Place` that is captured by a closure.

Diff for: compiler/rustc_typeck/src/check/fn_ctxt/checks.rs

+90-53
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use rustc_hir::def_id::DefId;
2121
use rustc_hir::{ExprKind, Node, QPath};
2222
use rustc_index::vec::IndexVec;
2323
use rustc_infer::infer::error_reporting::{FailureCode, ObligationCauseExt};
24+
use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
2425
use rustc_infer::infer::InferOk;
2526
use rustc_infer::infer::TypeTrace;
2627
use rustc_middle::ty::adjustment::AllowTwoPhase;
@@ -29,7 +30,9 @@ use rustc_middle::ty::{self, DefIdTree, IsSuggestable, Ty};
2930
use rustc_session::Session;
3031
use rustc_span::symbol::Ident;
3132
use rustc_span::{self, Span};
32-
use rustc_trait_selection::traits::{self, ObligationCauseCode, StatementAsExpression};
33+
use rustc_trait_selection::traits::{
34+
self, ObligationCauseCode, SelectionContext, StatementAsExpression,
35+
};
3336

3437
use std::iter;
3538
use std::slice;
@@ -393,41 +396,6 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
393396
}
394397

395398
if !call_appears_satisfied {
396-
// Next, let's construct the error
397-
let (error_span, full_call_span, ctor_of) = match &call_expr.kind {
398-
hir::ExprKind::Call(
399-
hir::Expr {
400-
span,
401-
kind:
402-
hir::ExprKind::Path(hir::QPath::Resolved(
403-
_,
404-
hir::Path { res: Res::Def(DefKind::Ctor(of, _), _), .. },
405-
)),
406-
..
407-
},
408-
_,
409-
) => (call_span, *span, Some(of)),
410-
hir::ExprKind::Call(hir::Expr { span, .. }, _) => (call_span, *span, None),
411-
hir::ExprKind::MethodCall(path_segment, _, span) => {
412-
let ident_span = path_segment.ident.span;
413-
let ident_span = if let Some(args) = path_segment.args {
414-
ident_span.with_hi(args.span_ext.hi())
415-
} else {
416-
ident_span
417-
};
418-
(
419-
*span, ident_span, None, // methods are never ctors
420-
)
421-
}
422-
k => span_bug!(call_span, "checking argument types on a non-call: `{:?}`", k),
423-
};
424-
let args_span = error_span.trim_start(full_call_span).unwrap_or(error_span);
425-
let call_name = match ctor_of {
426-
Some(CtorOf::Struct) => "struct",
427-
Some(CtorOf::Variant) => "enum variant",
428-
None => "function",
429-
};
430-
431399
let compatibility_diagonal = IndexVec::from_raw(compatibility_diagonal);
432400
let provided_args = IndexVec::from_iter(provided_args.iter().take(if c_variadic {
433401
minimum_input_count
@@ -451,13 +419,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
451419
compatibility_diagonal,
452420
formal_and_expected_inputs,
453421
provided_args,
454-
full_call_span,
455-
error_span,
456-
args_span,
457-
call_name,
458422
c_variadic,
459423
err_code,
460424
fn_def_id,
425+
call_span,
461426
call_expr,
462427
);
463428
}
@@ -468,15 +433,47 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
468433
compatibility_diagonal: IndexVec<ProvidedIdx, Compatibility<'tcx>>,
469434
formal_and_expected_inputs: IndexVec<ExpectedIdx, (Ty<'tcx>, Ty<'tcx>)>,
470435
provided_args: IndexVec<ProvidedIdx, &'tcx hir::Expr<'tcx>>,
471-
full_call_span: Span,
472-
error_span: Span,
473-
args_span: Span,
474-
call_name: &str,
475436
c_variadic: bool,
476437
err_code: &str,
477438
fn_def_id: Option<DefId>,
439+
call_span: Span,
478440
call_expr: &hir::Expr<'tcx>,
479441
) {
442+
// Next, let's construct the error
443+
let (error_span, full_call_span, ctor_of) = match &call_expr.kind {
444+
hir::ExprKind::Call(
445+
hir::Expr {
446+
span,
447+
kind:
448+
hir::ExprKind::Path(hir::QPath::Resolved(
449+
_,
450+
hir::Path { res: Res::Def(DefKind::Ctor(of, _), _), .. },
451+
)),
452+
..
453+
},
454+
_,
455+
) => (call_span, *span, Some(of)),
456+
hir::ExprKind::Call(hir::Expr { span, .. }, _) => (call_span, *span, None),
457+
hir::ExprKind::MethodCall(path_segment, _, span) => {
458+
let ident_span = path_segment.ident.span;
459+
let ident_span = if let Some(args) = path_segment.args {
460+
ident_span.with_hi(args.span_ext.hi())
461+
} else {
462+
ident_span
463+
};
464+
(
465+
*span, ident_span, None, // methods are never ctors
466+
)
467+
}
468+
k => span_bug!(call_span, "checking argument types on a non-call: `{:?}`", k),
469+
};
470+
let args_span = error_span.trim_start(full_call_span).unwrap_or(error_span);
471+
let call_name = match ctor_of {
472+
Some(CtorOf::Struct) => "struct",
473+
Some(CtorOf::Variant) => "enum variant",
474+
None => "function",
475+
};
476+
480477
// Don't print if it has error types or is just plain `_`
481478
fn has_error_or_infer<'tcx>(tys: impl IntoIterator<Item = Ty<'tcx>>) -> bool {
482479
tys.into_iter().any(|ty| ty.references_error() || ty.is_ty_var())
@@ -1818,17 +1815,22 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
18181815
fn label_fn_like(
18191816
&self,
18201817
err: &mut rustc_errors::DiagnosticBuilder<'tcx, rustc_errors::ErrorGuaranteed>,
1821-
def_id: Option<DefId>,
1818+
callable_def_id: Option<DefId>,
18221819
callee_ty: Option<Ty<'tcx>>,
18231820
) {
1824-
let Some(mut def_id) = def_id else {
1821+
let Some(mut def_id) = callable_def_id else {
18251822
return;
18261823
};
18271824

18281825
if let Some(assoc_item) = self.tcx.opt_associated_item(def_id)
1829-
&& let trait_def_id = assoc_item.trait_item_def_id.unwrap_or_else(|| self.tcx.parent(def_id))
1826+
// Possibly points at either impl or trait item, so try to get it
1827+
// to point to trait item, then get the parent.
1828+
// This parent might be an impl in the case of an inherent function,
1829+
// but the next check will fail.
1830+
&& let maybe_trait_item_def_id = assoc_item.trait_item_def_id.unwrap_or(def_id)
1831+
&& let maybe_trait_def_id = self.tcx.parent(maybe_trait_item_def_id)
18301832
// Just an easy way to check "trait_def_id == Fn/FnMut/FnOnce"
1831-
&& ty::ClosureKind::from_def_id(self.tcx, trait_def_id).is_some()
1833+
&& let Some(call_kind) = ty::ClosureKind::from_def_id(self.tcx, maybe_trait_def_id)
18321834
&& let Some(callee_ty) = callee_ty
18331835
{
18341836
let callee_ty = callee_ty.peel_refs();
@@ -1853,7 +1855,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
18531855
std::iter::zip(instantiated.predicates, instantiated.spans)
18541856
{
18551857
if let ty::PredicateKind::Trait(pred) = predicate.kind().skip_binder()
1856-
&& pred.self_ty() == callee_ty
1858+
&& pred.self_ty().peel_refs() == callee_ty
18571859
&& ty::ClosureKind::from_def_id(self.tcx, pred.def_id()).is_some()
18581860
{
18591861
err.span_note(span, "callable defined here");
@@ -1862,11 +1864,46 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
18621864
}
18631865
}
18641866
}
1865-
ty::Opaque(new_def_id, _) | ty::Closure(new_def_id, _) | ty::FnDef(new_def_id, _) => {
1867+
ty::Opaque(new_def_id, _)
1868+
| ty::Closure(new_def_id, _)
1869+
| ty::FnDef(new_def_id, _) => {
18661870
def_id = new_def_id;
18671871
}
18681872
_ => {
1869-
return;
1873+
// Look for a user-provided impl of a `Fn` trait, and point to it.
1874+
let new_def_id = self.probe(|_| {
1875+
let trait_ref = ty::TraitRef::new(
1876+
call_kind.to_def_id(self.tcx),
1877+
self.tcx.mk_substs([
1878+
ty::GenericArg::from(callee_ty),
1879+
self.next_ty_var(TypeVariableOrigin {
1880+
kind: TypeVariableOriginKind::MiscVariable,
1881+
span: rustc_span::DUMMY_SP,
1882+
})
1883+
.into(),
1884+
].into_iter()),
1885+
);
1886+
let obligation = traits::Obligation::new(
1887+
traits::ObligationCause::dummy(),
1888+
self.param_env,
1889+
ty::Binder::dummy(ty::TraitPredicate {
1890+
trait_ref,
1891+
constness: ty::BoundConstness::NotConst,
1892+
polarity: ty::ImplPolarity::Positive,
1893+
}),
1894+
);
1895+
match SelectionContext::new(&self).select(&obligation) {
1896+
Ok(Some(traits::ImplSource::UserDefined(impl_source))) => {
1897+
Some(impl_source.impl_def_id)
1898+
}
1899+
_ => None
1900+
}
1901+
});
1902+
if let Some(new_def_id) = new_def_id {
1903+
def_id = new_def_id;
1904+
} else {
1905+
return;
1906+
}
18701907
}
18711908
}
18721909
}
@@ -1888,8 +1925,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
18881925

18891926
let def_kind = self.tcx.def_kind(def_id);
18901927
err.span_note(spans, &format!("{} defined here", def_kind.descr(def_id)));
1891-
} else if let def_kind @ (DefKind::Closure | DefKind::OpaqueTy) = self.tcx.def_kind(def_id)
1892-
{
1928+
} else {
1929+
let def_kind = self.tcx.def_kind(def_id);
18931930
err.span_note(
18941931
self.tcx.def_span(def_id),
18951932
&format!("{} defined here", def_kind.descr(def_id)),

Diff for: src/test/ui/mismatched_types/overloaded-calls-bad.stderr

+16
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,24 @@ LL | let ans = s("what");
55
| - ^^^^^^ expected `isize`, found `&str`
66
| |
77
| arguments to this function are incorrect
8+
|
9+
note: implementation defined here
10+
--> $DIR/overloaded-calls-bad.rs:10:1
11+
|
12+
LL | impl FnMut<(isize,)> for S {
13+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^
814

915
error[E0057]: this function takes 1 argument but 0 arguments were supplied
1016
--> $DIR/overloaded-calls-bad.rs:29:15
1117
|
1218
LL | let ans = s();
1319
| ^-- an argument of type `isize` is missing
1420
|
21+
note: implementation defined here
22+
--> $DIR/overloaded-calls-bad.rs:10:1
23+
|
24+
LL | impl FnMut<(isize,)> for S {
25+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^
1526
help: provide the argument
1627
|
1728
LL | let ans = s(/* isize */);
@@ -25,6 +36,11 @@ LL | let ans = s("burma", "shave");
2536
| |
2637
| expected `isize`, found `&str`
2738
|
39+
note: implementation defined here
40+
--> $DIR/overloaded-calls-bad.rs:10:1
41+
|
42+
LL | impl FnMut<(isize,)> for S {
43+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^
2844
help: remove the extra argument
2945
|
3046
LL | let ans = s(/* isize */);

0 commit comments

Comments
 (0)