@@ -33,10 +33,10 @@ use crate::types::{
3333 BoundMethodType , ClassLiteral , DataclassParams , FieldInstance , KnownBoundMethodType ,
3434 KnownClass , KnownInstanceType , MemberLookupPolicy , PropertyInstanceType , SpecialFormType ,
3535 TrackedConstraintSet , TypeAliasType , TypeContext , UnionBuilder , UnionType ,
36- WrapperDescriptorKind , enums, ide_support, todo_type,
36+ WrapperDescriptorKind , enums, ide_support, infer_isolated_expression , todo_type,
3737} ;
3838use ruff_db:: diagnostic:: { Annotation , Diagnostic , SubDiagnostic , SubDiagnosticSeverity } ;
39- use ruff_python_ast:: { self as ast, PythonVersion } ;
39+ use ruff_python_ast:: { self as ast, ArgOrKeyword , PythonVersion } ;
4040
4141/// Binding information for a possible union of callables. At a call site, the arguments must be
4242/// compatible with _all_ of the types in the union for the call to be valid.
@@ -1776,7 +1776,7 @@ impl<'db> CallableBinding<'db> {
17761776 }
17771777
17781778 /// Returns the index of the matching overload in the form of [`MatchingOverloadIndex`].
1779- fn matching_overload_index ( & self ) -> MatchingOverloadIndex {
1779+ pub ( crate ) fn matching_overload_index ( & self ) -> MatchingOverloadIndex {
17801780 let mut matching_overloads = self . matching_overloads ( ) ;
17811781 match matching_overloads. next ( ) {
17821782 None => MatchingOverloadIndex :: None ,
@@ -1794,8 +1794,15 @@ impl<'db> CallableBinding<'db> {
17941794 }
17951795 }
17961796
1797+ /// Returns all overloads for this call binding, including overloads that did not match.
1798+ pub ( crate ) fn overloads ( & self ) -> & [ Binding < ' db > ] {
1799+ self . overloads . as_slice ( )
1800+ }
1801+
17971802 /// Returns an iterator over all the overloads that matched for this call binding.
1798- pub ( crate ) fn matching_overloads ( & self ) -> impl Iterator < Item = ( usize , & Binding < ' db > ) > {
1803+ pub ( crate ) fn matching_overloads (
1804+ & self ,
1805+ ) -> impl Iterator < Item = ( usize , & Binding < ' db > ) > + Clone {
17991806 self . overloads
18001807 . iter ( )
18011808 . enumerate ( )
@@ -2026,7 +2033,7 @@ enum OverloadCallReturnType<'db> {
20262033}
20272034
20282035#[ derive( Debug ) ]
2029- enum MatchingOverloadIndex {
2036+ pub ( crate ) enum MatchingOverloadIndex {
20302037 /// No matching overloads found.
20312038 None ,
20322039
@@ -2504,9 +2511,17 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25042511 if let Some ( return_ty) = self . signature . return_ty
25052512 && let Some ( call_expression_tcx) = self . call_expression_tcx . annotation
25062513 {
2507- // Ignore any specialization errors here, because the type context is only used to
2508- // optionally widen the return type.
2509- let _ = builder. infer ( return_ty, call_expression_tcx) ;
2514+ match call_expression_tcx {
2515+ // A type variable is not a useful type-context for expression inference, and applying it
2516+ // to the return type can lead to confusing unions in nested generic calls.
2517+ Type :: TypeVar ( _) => { }
2518+
2519+ _ => {
2520+ // Ignore any specialization errors here, because the type context is only used as a hint
2521+ // to infer a more assignable return type.
2522+ let _ = builder. infer ( return_ty, call_expression_tcx) ;
2523+ }
2524+ }
25102525 }
25112526
25122527 let parameters = self . signature . parameters ( ) ;
@@ -3289,6 +3304,23 @@ impl<'db> BindingError<'db> {
32893304 return ;
32903305 } ;
32913306
3307+ // Re-infer the argument type of call expressions, ignoring the type context for more
3308+ // precise error messages.
3309+ let provided_ty = match Self :: get_argument_node ( node, * argument_index) {
3310+ None => * provided_ty,
3311+
3312+ // Ignore starred arguments, as those are difficult to re-infer.
3313+ Some (
3314+ ast:: ArgOrKeyword :: Arg ( ast:: Expr :: Starred ( _) )
3315+ | ast:: ArgOrKeyword :: Keyword ( ast:: Keyword { arg : None , .. } ) ,
3316+ ) => * provided_ty,
3317+
3318+ Some (
3319+ ast:: ArgOrKeyword :: Arg ( value)
3320+ | ast:: ArgOrKeyword :: Keyword ( ast:: Keyword { value, .. } ) ,
3321+ ) => infer_isolated_expression ( context. db ( ) , context. scope ( ) , value) ,
3322+ } ;
3323+
32923324 let provided_ty_display = provided_ty. display ( context. db ( ) ) ;
32933325 let expected_ty_display = expected_ty. display ( context. db ( ) ) ;
32943326
@@ -3624,22 +3656,29 @@ impl<'db> BindingError<'db> {
36243656 }
36253657 }
36263658
3627- fn get_node ( node : ast:: AnyNodeRef , argument_index : Option < usize > ) -> ast:: AnyNodeRef {
3659+ fn get_node ( node : ast:: AnyNodeRef < ' _ > , argument_index : Option < usize > ) -> ast:: AnyNodeRef < ' _ > {
36283660 // If we have a Call node and an argument index, report the diagnostic on the correct
36293661 // argument node; otherwise, report it on the entire provided node.
3662+ match Self :: get_argument_node ( node, argument_index) {
3663+ Some ( ast:: ArgOrKeyword :: Arg ( expr) ) => expr. into ( ) ,
3664+ Some ( ast:: ArgOrKeyword :: Keyword ( expr) ) => expr. into ( ) ,
3665+ None => node,
3666+ }
3667+ }
3668+
3669+ fn get_argument_node (
3670+ node : ast:: AnyNodeRef < ' _ > ,
3671+ argument_index : Option < usize > ,
3672+ ) -> Option < ArgOrKeyword < ' _ > > {
36303673 match ( node, argument_index) {
3631- ( ast:: AnyNodeRef :: ExprCall ( call_node) , Some ( argument_index) ) => {
3632- match call_node
3674+ ( ast:: AnyNodeRef :: ExprCall ( call_node) , Some ( argument_index) ) => Some (
3675+ call_node
36333676 . arguments
36343677 . arguments_source_order ( )
36353678 . nth ( argument_index)
3636- . expect ( "argument index should not be out of range" )
3637- {
3638- ast:: ArgOrKeyword :: Arg ( expr) => expr. into ( ) ,
3639- ast:: ArgOrKeyword :: Keyword ( keyword) => keyword. into ( ) ,
3640- }
3641- }
3642- _ => node,
3679+ . expect ( "argument index should not be out of range" ) ,
3680+ ) ,
3681+ _ => None ,
36433682 }
36443683 }
36453684}
0 commit comments