@@ -1029,7 +1029,7 @@ impl<'db> From<Binding<'db>> for Bindings<'db> {
10291029 signature_type,
10301030 dunder_call_is_possibly_unbound : false ,
10311031 bound_type : None ,
1032- return_type : None ,
1032+ overload_call_return_type : None ,
10331033 overloads : smallvec ! [ from] ,
10341034 } ;
10351035 Bindings {
@@ -1074,7 +1074,7 @@ pub(crate) struct CallableBinding<'db> {
10741074 /// performed, and one of the expansion evaluated successfully for all of the argument lists.
10751075 /// This type is then the union of all the return types of the matched overloads for the
10761076 /// expanded argument lists.
1077- return_type : Option < Type < ' db > > ,
1077+ overload_call_return_type : Option < OverloadCallReturnType < ' db > > ,
10781078
10791079 /// The bindings of each overload of this callable. Will be empty if the type is not callable.
10801080 ///
@@ -1097,7 +1097,7 @@ impl<'db> CallableBinding<'db> {
10971097 signature_type,
10981098 dunder_call_is_possibly_unbound : false ,
10991099 bound_type : None ,
1100- return_type : None ,
1100+ overload_call_return_type : None ,
11011101 overloads,
11021102 }
11031103 }
@@ -1108,7 +1108,7 @@ impl<'db> CallableBinding<'db> {
11081108 signature_type,
11091109 dunder_call_is_possibly_unbound : false ,
11101110 bound_type : None ,
1111- return_type : None ,
1111+ overload_call_return_type : None ,
11121112 overloads : smallvec ! [ ] ,
11131113 }
11141114 }
@@ -1196,9 +1196,18 @@ impl<'db> CallableBinding<'db> {
11961196 // If only one overload evaluates without error, it is the winning match.
11971197 return ;
11981198 }
1199- MatchingOverloadIndex :: Multiple ( _ ) => {
1199+ MatchingOverloadIndex :: Multiple ( indexes ) => {
12001200 // If two or more candidate overloads remain, proceed to step 4.
1201- // TODO: Step 4 and Step 5 goes here...
1201+ tracing:: info!(
1202+ "Multiple overloads match: {:?}, filtering based on Any" ,
1203+ indexes
1204+ ) ;
1205+
1206+ // TODO: Step 4
1207+
1208+ // Step 5
1209+ self . filter_overloads_using_any_or_unknown ( db, argument_types. types ( ) , & indexes) ;
1210+
12021211 // We're returning here because this shouldn't lead to argument type expansion.
12031212 return ;
12041213 }
@@ -1277,7 +1286,10 @@ impl<'db> CallableBinding<'db> {
12771286 // If the number of return types is equal to the number of expanded argument lists,
12781287 // they all evaluated successfully. So, we need to combine their return types by
12791288 // union to determine the final return type.
1280- self . return_type = Some ( UnionType :: from_elements ( db, return_types) ) ;
1289+ self . overload_call_return_type =
1290+ Some ( OverloadCallReturnType :: ArgumentTypeExpansion (
1291+ UnionType :: from_elements ( db, return_types) ,
1292+ ) ) ;
12811293
12821294 // Restore the bindings state to the one that merges the bindings state evaluating
12831295 // each of the expanded argument list.
@@ -1296,6 +1308,99 @@ impl<'db> CallableBinding<'db> {
12961308 snapshotter. restore ( self , post_evaluation_snapshot) ;
12971309 }
12981310
1311+ /// Filter overloads based on [`Any`] or [`Unknown`] argument types.
1312+ ///
1313+ /// This is the step 5 of the [overload call evaluation algorithm][1].
1314+ ///
1315+ /// The filtering works on the remaining overloads that are present at the
1316+ /// `matching_overload_indexes` and are filtered out by marking them as unmatched overloads
1317+ /// using the [`mark_as_unmatched_overload`] method.
1318+ ///
1319+ /// [`Any`]: crate::types::DynamicType::Any
1320+ /// [`Unknown`]: crate::types::DynamicType::Unknown
1321+ /// [`mark_as_unmatched_overload`]: Binding::mark_as_unmatched_overload
1322+ /// [1]: https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation
1323+ fn filter_overloads_using_any_or_unknown (
1324+ & mut self ,
1325+ db : & ' db dyn Db ,
1326+ argument_types : & [ Type < ' db > ] ,
1327+ matching_overload_indexes : & [ usize ] ,
1328+ ) {
1329+ let top_materialized_argument_type = TupleType :: from_elements (
1330+ db,
1331+ argument_types. iter ( ) . map ( |argument_type| {
1332+ argument_type. top_materialization ( db, TypeVarVariance :: Covariant )
1333+ } ) ,
1334+ ) ;
1335+
1336+ // A flag to indicate whether we've found the overload that makes the remaining overloads
1337+ // unmatched for the given argument types.
1338+ let mut filter_remaining_overloads = false ;
1339+
1340+ for ( upto, current_index) in matching_overload_indexes. iter ( ) . enumerate ( ) {
1341+ if filter_remaining_overloads {
1342+ self . overloads [ * current_index] . mark_as_unmatched_overload ( ) ;
1343+ continue ;
1344+ }
1345+ let mut unions = Vec :: with_capacity ( argument_types. len ( ) ) ;
1346+ for argument_index in 0 ..argument_types. len ( ) {
1347+ let mut union = vec ! [ ] ;
1348+ for overload_index in & matching_overload_indexes[ ..=upto] {
1349+ let overload = & self . overloads [ * overload_index] ;
1350+ let Some ( parameter_index) = overload. argument_parameters [ argument_index] else {
1351+ // There is no parameter for this argument in this overload.
1352+ continue ;
1353+ } ;
1354+ union. push (
1355+ overload. signature . parameters ( ) [ parameter_index]
1356+ . annotated_type ( )
1357+ . unwrap_or ( Type :: unknown ( ) ) ,
1358+ ) ;
1359+ }
1360+ if union. is_empty ( ) {
1361+ continue ;
1362+ }
1363+ unions. push ( UnionType :: from_elements ( db, union) ) ;
1364+ }
1365+ if unions. len ( ) != argument_types. len ( ) {
1366+ continue ;
1367+ }
1368+ if top_materialized_argument_type
1369+ . is_assignable_to ( db, TupleType :: from_elements ( db, unions) )
1370+ {
1371+ filter_remaining_overloads = true ;
1372+ }
1373+ }
1374+
1375+ // Once this filtering process is applied for all arguments, examine the return types of
1376+ // the remaining overloads. If the resulting return types for all remaining overloads are
1377+ // equivalent, proceed to step 6.
1378+ let are_return_types_equivalent_for_all_matching_overloads = {
1379+ let mut matching_overloads = self . matching_overloads ( ) ;
1380+ if let Some ( first_overload_return_type) = matching_overloads
1381+ . next ( )
1382+ . map ( |( _, overload) | overload. return_type ( ) )
1383+ {
1384+ matching_overloads. all ( |( _, overload) | {
1385+ overload
1386+ . return_type ( )
1387+ . is_equivalent_to ( db, first_overload_return_type)
1388+ } )
1389+ } else {
1390+ // No matching overload
1391+ true
1392+ }
1393+ } ;
1394+
1395+ if !are_return_types_equivalent_for_all_matching_overloads {
1396+ // Overload matching is ambiguous.
1397+ for ( _, overload) in self . matching_overloads_mut ( ) {
1398+ overload. mark_as_unmatched_overload ( ) ;
1399+ }
1400+ self . overload_call_return_type = Some ( OverloadCallReturnType :: Ambiguous ) ;
1401+ }
1402+ }
1403+
12991404 fn as_result ( & self ) -> Result < ( ) , CallErrorKind > {
13001405 if !self . is_callable ( ) {
13011406 return Err ( CallErrorKind :: NotCallable ) ;
@@ -1370,8 +1475,11 @@ impl<'db> CallableBinding<'db> {
13701475 /// For an invalid call to an overloaded function, we return `Type::unknown`, since we cannot
13711476 /// make any useful conclusions about which overload was intended to be called.
13721477 pub ( crate ) fn return_type ( & self ) -> Type < ' db > {
1373- if let Some ( return_type) = self . return_type {
1374- return return_type;
1478+ if let Some ( overload_call_return_type) = self . overload_call_return_type {
1479+ return match overload_call_return_type {
1480+ OverloadCallReturnType :: ArgumentTypeExpansion ( return_type) => return_type,
1481+ OverloadCallReturnType :: Ambiguous => Type :: any ( ) ,
1482+ } ;
13751483 }
13761484 if let Some ( ( _, first_overload) ) = self . matching_overloads ( ) . next ( ) {
13771485 return first_overload. return_type ( ) ;
@@ -1414,6 +1522,10 @@ impl<'db> CallableBinding<'db> {
14141522 return ;
14151523 }
14161524
1525+ if self . overload_call_return_type . is_some ( ) {
1526+ return ;
1527+ }
1528+
14171529 match self . overloads . as_slice ( ) {
14181530 [ ] => { }
14191531 [ overload] => {
@@ -1521,6 +1633,12 @@ impl<'a, 'db> IntoIterator for &'a CallableBinding<'db> {
15211633 }
15221634}
15231635
1636+ #[ derive( Debug , Copy , Clone ) ]
1637+ enum OverloadCallReturnType < ' db > {
1638+ ArgumentTypeExpansion ( Type < ' db > ) ,
1639+ Ambiguous ,
1640+ }
1641+
15241642#[ derive( Debug ) ]
15251643enum MatchingOverloadIndex {
15261644 /// No matching overloads found.
@@ -1855,6 +1973,10 @@ impl<'db> Binding<'db> {
18551973 . map ( |( arg_and_type, _) | arg_and_type)
18561974 }
18571975
1976+ fn mark_as_unmatched_overload ( & mut self ) {
1977+ self . errors . push ( BindingError :: UnmatchedOverload ) ;
1978+ }
1979+
18581980 fn report_diagnostics (
18591981 & self ,
18601982 context : & InferContext < ' db , ' _ > ,
@@ -2140,6 +2262,8 @@ pub(crate) enum BindingError<'db> {
21402262 /// We use this variant to report errors in `property.__get__` and `property.__set__`, which
21412263 /// can occur when the call to the underlying getter/setter fails.
21422264 InternalCallError ( & ' static str ) ,
2265+ /// This overload of the callable does not match the arguments.
2266+ UnmatchedOverload ,
21432267}
21442268
21452269impl < ' db > BindingError < ' db > {
@@ -2332,6 +2456,8 @@ impl<'db> BindingError<'db> {
23322456 }
23332457 }
23342458 }
2459+
2460+ Self :: UnmatchedOverload => { }
23352461 }
23362462 }
23372463
0 commit comments