Skip to content

Commit c414849

Browse files
committed
[ty] Filter overloads based on Any / Unknown
1 parent 4a9739f commit c414849

File tree

2 files changed

+137
-11
lines changed

2 files changed

+137
-11
lines changed

crates/ty_python_semantic/src/types.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5781,9 +5781,9 @@ impl<'db> KnownInstanceType<'db> {
57815781

57825782
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
57835783
pub enum DynamicType {
5784-
// An explicitly annotated `typing.Any`
5784+
/// An explicitly annotated `typing.Any`
57855785
Any,
5786-
// An unannotated value, or a dynamic type resulting from an error
5786+
/// An unannotated value, or a dynamic type resulting from an error
57875787
Unknown,
57885788
/// Temporary type for symbols that can't be inferred yet because of missing implementations.
57895789
///

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 135 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,7 +1029,7 @@ impl<'db> From<Binding<'db>> for Bindings<'db> {
10291029
signature_type,
10301030
dunder_call_is_possibly_unbound: false,
10311031
bound_type: None,
1032-
return_type: None,
1032+
overload_call_return_type: None,
10331033
overloads: smallvec![from],
10341034
};
10351035
Bindings {
@@ -1074,7 +1074,7 @@ pub(crate) struct CallableBinding<'db> {
10741074
/// performed, and one of the expansion evaluated successfully for all of the argument lists.
10751075
/// This type is then the union of all the return types of the matched overloads for the
10761076
/// expanded argument lists.
1077-
return_type: Option<Type<'db>>,
1077+
overload_call_return_type: Option<OverloadCallReturnType<'db>>,
10781078

10791079
/// The bindings of each overload of this callable. Will be empty if the type is not callable.
10801080
///
@@ -1097,7 +1097,7 @@ impl<'db> CallableBinding<'db> {
10971097
signature_type,
10981098
dunder_call_is_possibly_unbound: false,
10991099
bound_type: None,
1100-
return_type: None,
1100+
overload_call_return_type: None,
11011101
overloads,
11021102
}
11031103
}
@@ -1108,7 +1108,7 @@ impl<'db> CallableBinding<'db> {
11081108
signature_type,
11091109
dunder_call_is_possibly_unbound: false,
11101110
bound_type: None,
1111-
return_type: None,
1111+
overload_call_return_type: None,
11121112
overloads: smallvec![],
11131113
}
11141114
}
@@ -1196,9 +1196,18 @@ impl<'db> CallableBinding<'db> {
11961196
// If only one overload evaluates without error, it is the winning match.
11971197
return;
11981198
}
1199-
MatchingOverloadIndex::Multiple(_) => {
1199+
MatchingOverloadIndex::Multiple(indexes) => {
12001200
// If two or more candidate overloads remain, proceed to step 4.
1201-
// TODO: Step 4 and Step 5 goes here...
1201+
tracing::info!(
1202+
"Multiple overloads match: {:?}, filtering based on Any",
1203+
indexes
1204+
);
1205+
1206+
// TODO: Step 4
1207+
1208+
// Step 5
1209+
self.filter_overloads_using_any_or_unknown(db, argument_types.types(), &indexes);
1210+
12021211
// We're returning here because this shouldn't lead to argument type expansion.
12031212
return;
12041213
}
@@ -1277,7 +1286,10 @@ impl<'db> CallableBinding<'db> {
12771286
// If the number of return types is equal to the number of expanded argument lists,
12781287
// they all evaluated successfully. So, we need to combine their return types by
12791288
// union to determine the final return type.
1280-
self.return_type = Some(UnionType::from_elements(db, return_types));
1289+
self.overload_call_return_type =
1290+
Some(OverloadCallReturnType::ArgumentTypeExpansion(
1291+
UnionType::from_elements(db, return_types),
1292+
));
12811293

12821294
// Restore the bindings state to the one that merges the bindings state evaluating
12831295
// each of the expanded argument list.
@@ -1296,6 +1308,99 @@ impl<'db> CallableBinding<'db> {
12961308
snapshotter.restore(self, post_evaluation_snapshot);
12971309
}
12981310

1311+
/// Filter overloads based on [`Any`] or [`Unknown`] argument types.
1312+
///
1313+
/// This is the step 5 of the [overload call evaluation algorithm][1].
1314+
///
1315+
/// The filtering works on the remaining overloads that are present at the
1316+
/// `matching_overload_indexes` and are filtered out by marking them as unmatched overloads
1317+
/// using the [`mark_as_unmatched_overload`] method.
1318+
///
1319+
/// [`Any`]: crate::types::DynamicType::Any
1320+
/// [`Unknown`]: crate::types::DynamicType::Unknown
1321+
/// [`mark_as_unmatched_overload`]: Binding::mark_as_unmatched_overload
1322+
/// [1]: https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation
1323+
fn filter_overloads_using_any_or_unknown(
1324+
&mut self,
1325+
db: &'db dyn Db,
1326+
argument_types: &[Type<'db>],
1327+
matching_overload_indexes: &[usize],
1328+
) {
1329+
let top_materialized_argument_type = TupleType::from_elements(
1330+
db,
1331+
argument_types.iter().map(|argument_type| {
1332+
argument_type.top_materialization(db, TypeVarVariance::Covariant)
1333+
}),
1334+
);
1335+
1336+
// A flag to indicate whether we've found the overload that makes the remaining overloads
1337+
// unmatched for the given argument types.
1338+
let mut filter_remaining_overloads = false;
1339+
1340+
for (upto, current_index) in matching_overload_indexes.iter().enumerate() {
1341+
if filter_remaining_overloads {
1342+
self.overloads[*current_index].mark_as_unmatched_overload();
1343+
continue;
1344+
}
1345+
let mut unions = Vec::with_capacity(argument_types.len());
1346+
for argument_index in 0..argument_types.len() {
1347+
let mut union = vec![];
1348+
for overload_index in &matching_overload_indexes[..=upto] {
1349+
let overload = &self.overloads[*overload_index];
1350+
let Some(parameter_index) = overload.argument_parameters[argument_index] else {
1351+
// There is no parameter for this argument in this overload.
1352+
continue;
1353+
};
1354+
union.push(
1355+
overload.signature.parameters()[parameter_index]
1356+
.annotated_type()
1357+
.unwrap_or(Type::unknown()),
1358+
);
1359+
}
1360+
if union.is_empty() {
1361+
continue;
1362+
}
1363+
unions.push(UnionType::from_elements(db, union));
1364+
}
1365+
if unions.len() != argument_types.len() {
1366+
continue;
1367+
}
1368+
if top_materialized_argument_type
1369+
.is_assignable_to(db, TupleType::from_elements(db, unions))
1370+
{
1371+
filter_remaining_overloads = true;
1372+
}
1373+
}
1374+
1375+
// Once this filtering process is applied for all arguments, examine the return types of
1376+
// the remaining overloads. If the resulting return types for all remaining overloads are
1377+
// equivalent, proceed to step 6.
1378+
let are_return_types_equivalent_for_all_matching_overloads = {
1379+
let mut matching_overloads = self.matching_overloads();
1380+
if let Some(first_overload_return_type) = matching_overloads
1381+
.next()
1382+
.map(|(_, overload)| overload.return_type())
1383+
{
1384+
matching_overloads.all(|(_, overload)| {
1385+
overload
1386+
.return_type()
1387+
.is_equivalent_to(db, first_overload_return_type)
1388+
})
1389+
} else {
1390+
// No matching overload
1391+
true
1392+
}
1393+
};
1394+
1395+
if !are_return_types_equivalent_for_all_matching_overloads {
1396+
// Overload matching is ambiguous.
1397+
for (_, overload) in self.matching_overloads_mut() {
1398+
overload.mark_as_unmatched_overload();
1399+
}
1400+
self.overload_call_return_type = Some(OverloadCallReturnType::Ambiguous);
1401+
}
1402+
}
1403+
12991404
fn as_result(&self) -> Result<(), CallErrorKind> {
13001405
if !self.is_callable() {
13011406
return Err(CallErrorKind::NotCallable);
@@ -1370,8 +1475,11 @@ impl<'db> CallableBinding<'db> {
13701475
/// For an invalid call to an overloaded function, we return `Type::unknown`, since we cannot
13711476
/// make any useful conclusions about which overload was intended to be called.
13721477
pub(crate) fn return_type(&self) -> Type<'db> {
1373-
if let Some(return_type) = self.return_type {
1374-
return return_type;
1478+
if let Some(overload_call_return_type) = self.overload_call_return_type {
1479+
return match overload_call_return_type {
1480+
OverloadCallReturnType::ArgumentTypeExpansion(return_type) => return_type,
1481+
OverloadCallReturnType::Ambiguous => Type::any(),
1482+
};
13751483
}
13761484
if let Some((_, first_overload)) = self.matching_overloads().next() {
13771485
return first_overload.return_type();
@@ -1414,6 +1522,10 @@ impl<'db> CallableBinding<'db> {
14141522
return;
14151523
}
14161524

1525+
if self.overload_call_return_type.is_some() {
1526+
return;
1527+
}
1528+
14171529
match self.overloads.as_slice() {
14181530
[] => {}
14191531
[overload] => {
@@ -1521,6 +1633,12 @@ impl<'a, 'db> IntoIterator for &'a CallableBinding<'db> {
15211633
}
15221634
}
15231635

1636+
#[derive(Debug, Copy, Clone)]
1637+
enum OverloadCallReturnType<'db> {
1638+
ArgumentTypeExpansion(Type<'db>),
1639+
Ambiguous,
1640+
}
1641+
15241642
#[derive(Debug)]
15251643
enum MatchingOverloadIndex {
15261644
/// No matching overloads found.
@@ -1855,6 +1973,10 @@ impl<'db> Binding<'db> {
18551973
.map(|(arg_and_type, _)| arg_and_type)
18561974
}
18571975

1976+
fn mark_as_unmatched_overload(&mut self) {
1977+
self.errors.push(BindingError::UnmatchedOverload);
1978+
}
1979+
18581980
fn report_diagnostics(
18591981
&self,
18601982
context: &InferContext<'db, '_>,
@@ -2140,6 +2262,8 @@ pub(crate) enum BindingError<'db> {
21402262
/// We use this variant to report errors in `property.__get__` and `property.__set__`, which
21412263
/// can occur when the call to the underlying getter/setter fails.
21422264
InternalCallError(&'static str),
2265+
/// This overload of the callable does not match the arguments.
2266+
UnmatchedOverload,
21432267
}
21442268

21452269
impl<'db> BindingError<'db> {
@@ -2332,6 +2456,8 @@ impl<'db> BindingError<'db> {
23322456
}
23332457
}
23342458
}
2459+
2460+
Self::UnmatchedOverload => {}
23352461
}
23362462
}
23372463

0 commit comments

Comments
 (0)