diff --git a/crates/ty_python_semantic/resources/mdtest/call/overloads.md b/crates/ty_python_semantic/resources/mdtest/call/overloads.md index 2258035524679..d49ee12639484 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/overloads.md +++ b/crates/ty_python_semantic/resources/mdtest/call/overloads.md @@ -399,3 +399,475 @@ def _(x: SomeEnum): # TODO: This should be `A | B | C` once enums are supported and are expanded reveal_type(f(x)) # revealed: A ``` + +## Filtering overloads with variadic arguments and parameters + +TODO + +## Filtering based on `Any` / `Unknown` + +This is the step 5 of the overload call evaluation algorithm which specifies that: + +> For all arguments, determine whether all possible materializations of the argument’s type are +> assignable to the corresponding parameter type for each of the remaining overloads. If so, +> eliminate all of the subsequent remaining overloads. + +This is only performed if the previous step resulted in more than one matching overload. + +### Single list argument + +`overloaded.pyi`: + +```pyi +from typing import Any, overload + +@overload +def f(x: list[int]) -> int: ... +@overload +def f(x: list[Any]) -> int: ... +@overload +def f(x: Any) -> str: ... +``` + +For the above definition, anything other than `list` should match the last overload: + +```py +from typing import Any + +from overloaded import f + +# Anything other than `list` should match the last overload +reveal_type(f(1)) # revealed: str + +def _(list_int: list[int], list_any: list[Any]): + reveal_type(f(list_int)) # revealed: int + reveal_type(f(list_any)) # revealed: int +``` + +### Single list argument (ambiguous) + +The overload definition is the same as above, but the return type of the second overload is changed +to `str` to make the overload matching ambiguous if the argument is a `list[Any]`. + +`overloaded.pyi`: + +```pyi +from typing import Any, overload + +@overload +def f(x: list[int]) -> int: ... +@overload +def f(x: list[Any]) -> str: ... +@overload +def f(x: Any) -> str: ... +``` + +```py +from typing import Any + +from overloaded import f + +# Anything other than `list` should match the last overload +reveal_type(f(1)) # revealed: str + +def _(list_int: list[int], list_any: list[Any]): + # All materializations of `list[int]` are assignable to `list[int]`, so it matches the first + # overload. + reveal_type(f(list_int)) # revealed: int + + # All materializations of `list[Any]` are assignable to `list[int]` and `list[Any]`, but the + # return type of first and second overloads are not equivalent, so the overload matching + # is ambiguous. + reveal_type(f(list_any)) # revealed: Unknown +``` + +### Single tuple argument + +`overloaded.pyi`: + +```pyi +from typing import Any, overload + +@overload +def f(x: tuple[int, str]) -> int: ... +@overload +def f(x: tuple[int, Any]) -> int: ... +@overload +def f(x: Any) -> str: ... +``` + +```py +from typing import Any + +from overloaded import f + +reveal_type(f("a")) # revealed: str +reveal_type(f((1, "b"))) # revealed: int +reveal_type(f((1, 2))) # revealed: int + +def _(int_str: tuple[int, str], int_any: tuple[int, Any], any_any: tuple[Any, Any]): + # All materializations are assignable to first overload, so second and third overloads are + # eliminated + reveal_type(f(int_str)) # revealed: int + + # All materializations are assignable to second overload, so the third overload is eliminated; + # the return type of first and second overload is equivalent + reveal_type(f(int_any)) # revealed: int + + # All materializations of `tuple[Any, Any]` are assignable to the parameters of all the + # overloads, but the return types aren't equivalent, so the overload matching is ambiguous + reveal_type(f(any_any)) # revealed: Unknown +``` + +### Multiple arguments + +`overloaded.pyi`: + +```pyi +from typing import Any, overload + +class A: ... +class B: ... + +@overload +def f(x: list[int], y: tuple[int, str]) -> A: ... +@overload +def f(x: list[Any], y: tuple[int, Any]) -> A: ... +@overload +def f(x: list[Any], y: tuple[Any, Any]) -> B: ... +``` + +```py +from typing import Any + +from overloaded import A, f + +def _(list_int: list[int], list_any: list[Any], int_str: tuple[int, str], int_any: tuple[int, Any], any_any: tuple[Any, Any]): + # All materializations of both argument types are assignable to the first overload, so the + # second and third overloads are filtered out + reveal_type(f(list_int, int_str)) # revealed: A + + # All materialization of first argument is assignable to first overload and for the second + # argument, they're assignable to the second overload, so the third overload is filtered out + reveal_type(f(list_int, int_any)) # revealed: A + + # All materialization of first argument is assignable to second overload and for the second + # argument, they're assignable to the first overload, so the third overload is filtered out + reveal_type(f(list_any, int_str)) # revealed: A + + # All materializations of both arguments are assignable to the second overload, so the third + # overload is filtered out + reveal_type(f(list_any, int_any)) # revealed: A + + # All materializations of first argument is assignable to the second overload and for the second + # argument, they're assignable to the third overload, so no overloads are filtered out; the + # return types of the remaining overloads are not equivalent, so overload matching is ambiguous + reveal_type(f(list_int, any_any)) # revealed: Unknown +``` + +### `LiteralString` and `str` + +`overloaded.pyi`: + +```pyi +from typing import overload +from typing_extensions import LiteralString + +@overload +def f(x: LiteralString) -> LiteralString: ... +@overload +def f(x: str) -> str: ... +``` + +```py +from typing import Any +from typing_extensions import LiteralString + +from overloaded import f + +def _(literal: LiteralString, string: str, any: Any): + reveal_type(f(literal)) # revealed: LiteralString + reveal_type(f(string)) # revealed: str + + # `Any` matches both overloads, but the return types are not equivalent. + # Pyright and mypy both reveal `str` here, contrary to the spec. + reveal_type(f(any)) # revealed: Unknown +``` + +### Generics + +`overloaded.pyi`: + +```pyi +from typing import Any, TypeVar, overload + +_T = TypeVar("_T") + +class A: ... +class B: ... + +@overload +def f(x: list[int]) -> A: ... +@overload +def f(x: list[_T]) -> _T: ... +@overload +def f(x: Any) -> B: ... +``` + +```py +from typing import Any + +from overloaded import f + +def _(list_int: list[int], list_str: list[str], list_any: list[Any], any: Any): + reveal_type(f(list_int)) # revealed: A + # TODO: Should be `str` + reveal_type(f(list_str)) # revealed: Unknown + reveal_type(f(list_any)) # revealed: Unknown + reveal_type(f(any)) # revealed: Unknown +``` + +### Generics (multiple arguments) + +`overloaded.pyi`: + +```pyi +from typing import Any, TypeVar, overload + +_T = TypeVar("_T") + +@overload +def f(x: int, y: Any) -> int: ... +@overload +def f(x: str, y: _T) -> _T: ... +``` + +```py +from typing import Any + +from overloaded import f + +def _(integer: int, string: str, any: Any, list_any: list[Any]): + reveal_type(f(integer, string)) # revealed: int + reveal_type(f(string, integer)) # revealed: int + + # This matches the second overload and is _not_ the case of ambiguous overload matching. + reveal_type(f(string, any)) # revealed: Any + + reveal_type(f(string, list_any)) # revealed: list[Any] +``` + +### Generic `self` + +`overloaded.pyi`: + +```pyi +from typing import Any, overload, TypeVar, Generic + +_T = TypeVar("_T") + +class A(Generic[_T]): + @overload + def method(self: "A[int]") -> int: ... + @overload + def method(self: "A[Any]") -> int: ... + +class B(Generic[_T]): + @overload + def method(self: "B[int]") -> int: ... + @overload + def method(self: "B[Any]") -> str: ... +``` + +```py +from typing import Any + +from overloaded import A, B + +def _(a_int: A[int], a_str: A[str], a_any: A[Any]): + reveal_type(a_int.method()) # revealed: int + reveal_type(a_str.method()) # revealed: int + reveal_type(a_any.method()) # revealed: int + +def _(b_int: B[int], b_str: B[str], b_any: B[Any]): + reveal_type(b_int.method()) # revealed: int + reveal_type(b_str.method()) # revealed: str + reveal_type(b_any.method()) # revealed: Unknown +``` + +### Variadic argument + +TODO: A variadic parameter is being assigned to a number of parameters of the same type + +### Non-participating fully-static parameter + +Ref: + +A non-participating parameter would be the one where the set of materializations of the argument +type, that are assignable to the parameter type at the same index, is same for the overloads for +which step 5 needs to be performed. + +`overloaded.pyi`: + +```pyi +from typing import Literal, overload + +@overload +def f(x: str, *, flag: Literal[True]) -> int: ... +@overload +def f(x: str, *, flag: Literal[False] = ...) -> str: ... +@overload +def f(x: str, *, flag: bool = ...) -> int | str: ... +``` + +In the following example, for the `f(any, flag=True)` call, the materializations of first argument +type `Any` that are assignable to `str` is same for overloads 1 and 3 (at the time of step 5), so +for the purposes of overload matching that parameter can be ignored. If `Any` materializes to +anything that's not assignable to `str`, all of the overloads would already be filtered out which +will raise a `no-matching-overload` error. + +```py +from typing import Any + +from overloaded import f + +def _(any: Any): + reveal_type(f(any, flag=True)) # revealed: int + reveal_type(f(any, flag=False)) # revealed: str +``` + +### Non-participating gradual parameter + +`overloaded.pyi`: + +```pyi +from typing import Any, Literal, overload + +@overload +def f(x: tuple[str, Any], *, flag: Literal[True]) -> int: ... +@overload +def f(x: tuple[str, Any], *, flag: Literal[False] = ...) -> str: ... +@overload +def f(x: tuple[str, Any], *, flag: bool = ...) -> int | str: ... +``` + +```py +from typing import Any + +from overloaded import f + +def _(any: Any): + reveal_type(f(any, flag=True)) # revealed: int + reveal_type(f(any, flag=False)) # revealed: str +``` + +### Argument type expansion + +This filtering can also happen for each of the expanded argument lists. + +#### No ambiguity + +`overloaded.pyi`: + +```pyi +from typing import Any, overload + +class A: ... +class B: ... + +@overload +def f(x: tuple[A, B]) -> A: ... +@overload +def f(x: tuple[B, A]) -> B: ... +@overload +def f(x: tuple[A, Any]) -> A: ... +@overload +def f(x: tuple[B, Any]) -> B: ... +``` + +Here, the argument `tuple[A | B, Any]` doesn't match any of the overloads, so we perform argument +type expansion which results in two argument lists: + +1. `tuple[A, Any]` +1. `tuple[B, Any]` + +The first argument list matches overload 1 and 3 via `Any` materialization for which the return +types are equivalent (`A`). Similarly, the second argument list matches overload 2 and 4 via `Any` +materialization for which the return types are equivalent (`B`). The final return type for the call +will be the union of the return types. + +```py +from typing import Any + +from overloaded import A, B, f + +def _(arg: tuple[A | B, Any]): + reveal_type(f(arg)) # revealed: A | B +``` + +#### One argument list ambiguous + +The example used here is same as the previous one, but the return type of the last overload is +changed so that it's not equivalent to the return type of the second overload, creating an ambiguous +matching for the second argument list. + +`overloaded.pyi`: + +```pyi +from typing import Any, overload + +class A: ... +class B: ... +class C: ... + +@overload +def f(x: tuple[A, B]) -> A: ... +@overload +def f(x: tuple[B, A]) -> B: ... +@overload +def f(x: tuple[A, Any]) -> A: ... +@overload +def f(x: tuple[B, Any]) -> C: ... +``` + +```py +from typing import Any + +from overloaded import A, B, C, f + +def _(arg: tuple[A | B, Any]): + reveal_type(f(arg)) # revealed: A | Unknown +``` + +#### Both argument lists ambiguous + +Here, both argument lists created by expanding the argument type are ambiguous, so the final return +type is `Any`. + +`overloaded.pyi`: + +```pyi +from typing import Any, overload + +class A: ... +class B: ... +class C: ... + +@overload +def f(x: tuple[A, B]) -> A: ... +@overload +def f(x: tuple[B, A]) -> B: ... +@overload +def f(x: tuple[A, Any]) -> C: ... +@overload +def f(x: tuple[B, Any]) -> C: ... +``` + +```py +from typing import Any + +from overloaded import A, B, C, f + +def _(arg: tuple[A | B, Any]): + reveal_type(f(arg)) # revealed: Unknown +``` diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 08deb8327cd12..59afe93cca86c 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -5831,9 +5831,9 @@ impl<'db> KnownInstanceType<'db> { #[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)] pub enum DynamicType { - // An explicitly annotated `typing.Any` + /// An explicitly annotated `typing.Any` Any, - // An unannotated value, or a dynamic type resulting from an error + /// An unannotated value, or a dynamic type resulting from an error Unknown, /// Temporary type for symbols that can't be inferred yet because of missing implementations. /// diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 45f7d5694da00..fb41f559af469 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -3,6 +3,8 @@ //! [signatures][crate::types::signatures], we have to handle the fact that the callable might be a //! union of types, each of which might contain multiple overloads. +use std::collections::HashSet; + use itertools::Itertools; use ruff_db::parsed::parsed_module; use smallvec::{SmallVec, smallvec}; @@ -1029,7 +1031,7 @@ impl<'db> From> for Bindings<'db> { signature_type, dunder_call_is_possibly_unbound: false, bound_type: None, - return_type: None, + overload_call_return_type: None, overloads: smallvec![from], }; Bindings { @@ -1068,13 +1070,21 @@ pub(crate) struct CallableBinding<'db> { /// The type of the bound `self` or `cls` parameter if this signature is for a bound method. pub(crate) bound_type: Option>, - /// The return type of this callable. + /// The return type of this overloaded callable. + /// + /// This is [`Some`] only in the following cases: + /// 1. Argument type expansion was performed and one of the expansions evaluated successfully + /// for all of the argument lists, or + /// 2. Overload call evaluation was ambiguous, meaning that multiple overloads matched the + /// argument lists, but they all had different return types + /// + /// For (1), the final return type is the union of all the return types of the matched + /// overloads for the expanded argument lists. /// - /// This is only `Some` if it's an overloaded callable, "argument type expansion" was - /// performed, and one of the expansion evaluated successfully for all of the argument lists. - /// This type is then the union of all the return types of the matched overloads for the - /// expanded argument lists. - return_type: Option>, + /// For (2), the final return type is [`Unknown`]. + /// + /// [`Unknown`]: crate::types::DynamicType::Unknown + overload_call_return_type: Option>, /// The bindings of each overload of this callable. Will be empty if the type is not callable. /// @@ -1097,7 +1107,7 @@ impl<'db> CallableBinding<'db> { signature_type, dunder_call_is_possibly_unbound: false, bound_type: None, - return_type: None, + overload_call_return_type: None, overloads, } } @@ -1108,7 +1118,7 @@ impl<'db> CallableBinding<'db> { signature_type, dunder_call_is_possibly_unbound: false, bound_type: None, - return_type: None, + overload_call_return_type: None, overloads: smallvec![], } } @@ -1176,7 +1186,7 @@ impl<'db> CallableBinding<'db> { } }; - let snapshotter = MatchingOverloadsSnapshotter::new(matching_overload_indexes); + let snapshotter = CallableBindingSnapshotter::new(matching_overload_indexes); // State of the bindings _before_ evaluating (type checking) the matching overloads using // the non-expanded argument types. @@ -1196,9 +1206,13 @@ impl<'db> CallableBinding<'db> { // If only one overload evaluates without error, it is the winning match. return; } - MatchingOverloadIndex::Multiple(_) => { + MatchingOverloadIndex::Multiple(indexes) => { // If two or more candidate overloads remain, proceed to step 4. - // TODO: Step 4 and Step 5 goes here... + // TODO: Step 4 + + // Step 5 + self.filter_overloads_using_any_or_unknown(db, argument_types.types(), &indexes); + // We're returning here because this shouldn't lead to argument type expansion. return; } @@ -1225,7 +1239,7 @@ impl<'db> CallableBinding<'db> { // This is the merged state of the bindings after evaluating all of the expanded // argument lists. This will be the final state to restore the bindings to if all of // the expanded argument lists evaluated successfully. - let mut merged_evaluation_state: Option> = None; + let mut merged_evaluation_state: Option> = None; let mut return_types = Vec::new(); @@ -1241,10 +1255,16 @@ impl<'db> CallableBinding<'db> { MatchingOverloadIndex::Single(index) => { Some(self.overloads[index].return_type()) } - MatchingOverloadIndex::Multiple(index) => { - // TODO: Step 4 and Step 5 goes here... but for now we just use the return - // type of the first matched overload. - Some(self.overloads[index[0]].return_type()) + MatchingOverloadIndex::Multiple(matching_overload_indexes) => { + // TODO: Step 4 + + self.filter_overloads_using_any_or_unknown( + db, + expanded_argument_types, + &matching_overload_indexes, + ); + + Some(self.return_type()) } }; @@ -1274,17 +1294,23 @@ impl<'db> CallableBinding<'db> { } if return_types.len() == expanded_argument_lists.len() { - // If the number of return types is equal to the number of expanded argument lists, - // they all evaluated successfully. So, we need to combine their return types by - // union to determine the final return type. - self.return_type = Some(UnionType::from_elements(db, return_types)); - // Restore the bindings state to the one that merges the bindings state evaluating // each of the expanded argument list. + // + // Note that this needs to happen *before* setting the return type, because this + // will restore the return type to the one before argument type expansion. if let Some(merged_evaluation_state) = merged_evaluation_state { snapshotter.restore(self, merged_evaluation_state); } + // If the number of return types is equal to the number of expanded argument lists, + // they all evaluated successfully. So, we need to combine their return types by + // union to determine the final return type. + self.overload_call_return_type = + Some(OverloadCallReturnType::ArgumentTypeExpansion( + UnionType::from_elements(db, return_types), + )); + return; } } @@ -1296,6 +1322,137 @@ impl<'db> CallableBinding<'db> { snapshotter.restore(self, post_evaluation_snapshot); } + /// Filter overloads based on [`Any`] or [`Unknown`] argument types. + /// + /// This is the step 5 of the [overload call evaluation algorithm][1]. + /// + /// The filtering works on the remaining overloads that are present at the + /// `matching_overload_indexes` and are filtered out by marking them as unmatched overloads + /// using the [`mark_as_unmatched_overload`] method. + /// + /// [`Any`]: crate::types::DynamicType::Any + /// [`Unknown`]: crate::types::DynamicType::Unknown + /// [`mark_as_unmatched_overload`]: Binding::mark_as_unmatched_overload + /// [1]: https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation + fn filter_overloads_using_any_or_unknown( + &mut self, + db: &'db dyn Db, + argument_types: &[Type<'db>], + matching_overload_indexes: &[usize], + ) { + // These are the parameter indexes that matches the arguments that participate in the + // filtering process. + // + // The parameter types at these indexes have at least one overload where the type isn't + // gradual equivalent to the parameter types at the same index for other overloads. + let mut participating_parameter_indexes = HashSet::new(); + + // These only contain the top materialized argument types for the corresponding + // participating parameter indexes. + let mut top_materialized_argument_types = vec![]; + + for (argument_index, argument_type) in argument_types.iter().enumerate() { + let mut first_parameter_type: Option> = None; + let mut participating_parameter_index = None; + + for overload_index in matching_overload_indexes { + let overload = &self.overloads[*overload_index]; + let Some(parameter_index) = overload.argument_parameters[argument_index] else { + // There is no parameter for this argument in this overload. + break; + }; + // TODO: For an unannotated `self` / `cls` parameter, the type should be + // `typing.Self` / `type[typing.Self]` + let current_parameter_type = overload.signature.parameters()[parameter_index] + .annotated_type() + .unwrap_or(Type::unknown()); + if let Some(first_parameter_type) = first_parameter_type { + if !first_parameter_type.is_gradual_equivalent_to(db, current_parameter_type) { + participating_parameter_index = Some(parameter_index); + break; + } + } else { + first_parameter_type = Some(current_parameter_type); + } + } + + if let Some(parameter_index) = participating_parameter_index { + participating_parameter_indexes.insert(parameter_index); + top_materialized_argument_types.push(argument_type.top_materialization(db)); + } + } + + let top_materialized_argument_type = + TupleType::from_elements(db, top_materialized_argument_types); + + // A flag to indicate whether we've found the overload that makes the remaining overloads + // unmatched for the given argument types. + let mut filter_remaining_overloads = false; + + for (upto, current_index) in matching_overload_indexes.iter().enumerate() { + if filter_remaining_overloads { + self.overloads[*current_index].mark_as_unmatched_overload(); + continue; + } + let mut parameter_types = Vec::with_capacity(argument_types.len()); + for argument_index in 0..argument_types.len() { + // The parameter types at the current argument index. + let mut current_parameter_types = vec![]; + for overload_index in &matching_overload_indexes[..=upto] { + let overload = &self.overloads[*overload_index]; + let Some(parameter_index) = overload.argument_parameters[argument_index] else { + // There is no parameter for this argument in this overload. + continue; + }; + if !participating_parameter_indexes.contains(¶meter_index) { + // This parameter doesn't participate in the filtering process. + continue; + } + // TODO: For an unannotated `self` / `cls` parameter, the type should be + // `typing.Self` / `type[typing.Self]` + let parameter_type = overload.signature.parameters()[parameter_index] + .annotated_type() + .unwrap_or(Type::unknown()); + current_parameter_types.push(parameter_type); + } + if current_parameter_types.is_empty() { + continue; + } + parameter_types.push(UnionType::from_elements(db, current_parameter_types)); + } + if top_materialized_argument_type + .is_assignable_to(db, TupleType::from_elements(db, parameter_types)) + { + filter_remaining_overloads = true; + } + } + + // Once this filtering process is applied for all arguments, examine the return types of + // the remaining overloads. If the resulting return types for all remaining overloads are + // equivalent, proceed to step 6. + let are_return_types_equivalent_for_all_matching_overloads = { + let mut matching_overloads = self.matching_overloads(); + if let Some(first_overload_return_type) = matching_overloads + .next() + .map(|(_, overload)| overload.return_type()) + { + matching_overloads.all(|(_, overload)| { + overload + .return_type() + .is_equivalent_to(db, first_overload_return_type) + }) + } else { + // No matching overload + true + } + }; + + if !are_return_types_equivalent_for_all_matching_overloads { + // Overload matching is ambiguous. + self.overload_call_return_type = Some(OverloadCallReturnType::Ambiguous); + } + } + fn as_result(&self) -> Result<(), CallErrorKind> { if !self.is_callable() { return Err(CallErrorKind::NotCallable); @@ -1370,8 +1527,11 @@ impl<'db> CallableBinding<'db> { /// For an invalid call to an overloaded function, we return `Type::unknown`, since we cannot /// make any useful conclusions about which overload was intended to be called. pub(crate) fn return_type(&self) -> Type<'db> { - if let Some(return_type) = self.return_type { - return return_type; + if let Some(overload_call_return_type) = self.overload_call_return_type { + return match overload_call_return_type { + OverloadCallReturnType::ArgumentTypeExpansion(return_type) => return_type, + OverloadCallReturnType::Ambiguous => Type::unknown(), + }; } if let Some((_, first_overload)) = self.matching_overloads().next() { return first_overload.return_type(); @@ -1521,6 +1681,12 @@ impl<'a, 'db> IntoIterator for &'a CallableBinding<'db> { } } +#[derive(Debug, Copy, Clone)] +enum OverloadCallReturnType<'db> { + ArgumentTypeExpansion(Type<'db>), + Ambiguous, +} + #[derive(Debug)] enum MatchingOverloadIndex { /// No matching overloads found. @@ -1855,6 +2021,11 @@ impl<'db> Binding<'db> { .map(|(arg_and_type, _)| arg_and_type) } + /// Mark this overload binding as an unmatched overload. + fn mark_as_unmatched_overload(&mut self) { + self.errors.push(BindingError::UnmatchedOverload); + } + fn report_diagnostics( &self, context: &InferContext<'db, '_>, @@ -1915,23 +2086,27 @@ struct BindingSnapshot<'db> { errors: Vec>, } -/// Represents the snapshot of the matched overload bindings. -/// -/// The reason that this only contains the matched overloads are: -/// 1. Avoid creating snapshots for the overloads that have been filtered by the arity check -/// 2. Avoid duplicating errors when merging the snapshots on a successful evaluation of all the -/// expanded argument lists #[derive(Clone, Debug)] -struct MatchingOverloadsSnapshot<'db>(Vec<(usize, BindingSnapshot<'db>)>); +struct CallableBindingSnapshot<'db> { + overload_return_type: Option>, -impl<'db> MatchingOverloadsSnapshot<'db> { + /// Represents the snapshot of the matched overload bindings. + /// + /// The reason that this only contains the matched overloads are: + /// 1. Avoid creating snapshots for the overloads that have been filtered by the arity check + /// 2. Avoid duplicating errors when merging the snapshots on a successful evaluation of all + /// the expanded argument lists + matching_overloads: Vec<(usize, BindingSnapshot<'db>)>, +} + +impl<'db> CallableBindingSnapshot<'db> { /// Update the state of the matched overload bindings in this snapshot with the current /// state in the given `binding`. fn update(&mut self, binding: &CallableBinding<'db>) { // Here, the `snapshot` is the state of this binding for the previous argument list and // `binding` would contain the state after evaluating the current argument list. for (snapshot, binding) in self - .0 + .matching_overloads .iter_mut() .map(|(index, snapshot)| (snapshot, &binding.overloads[*index])) { @@ -1967,13 +2142,13 @@ impl<'db> MatchingOverloadsSnapshot<'db> { /// A helper to take snapshots of the matched overload bindings for the current state of the /// bindings. -struct MatchingOverloadsSnapshotter(Vec); +struct CallableBindingSnapshotter(Vec); -impl MatchingOverloadsSnapshotter { +impl CallableBindingSnapshotter { /// Creates a new snapshotter for the given indexes of the matched overloads. fn new(indexes: Vec) -> Self { debug_assert!(indexes.len() > 1); - MatchingOverloadsSnapshotter(indexes) + CallableBindingSnapshotter(indexes) } /// Takes a snapshot of the current state of the matched overload bindings. @@ -1981,23 +2156,26 @@ impl MatchingOverloadsSnapshotter { /// # Panics /// /// Panics if the indexes of the matched overloads are not valid for the given binding. - fn take<'db>(&self, binding: &CallableBinding<'db>) -> MatchingOverloadsSnapshot<'db> { - MatchingOverloadsSnapshot( - self.0 + fn take<'db>(&self, binding: &CallableBinding<'db>) -> CallableBindingSnapshot<'db> { + CallableBindingSnapshot { + overload_return_type: binding.overload_call_return_type, + matching_overloads: self + .0 .iter() .map(|index| (*index, binding.overloads[*index].snapshot())) .collect(), - ) + } } /// Restores the state of the matched overload bindings from the given snapshot. fn restore<'db>( &self, binding: &mut CallableBinding<'db>, - snapshot: MatchingOverloadsSnapshot<'db>, + snapshot: CallableBindingSnapshot<'db>, ) { - debug_assert_eq!(self.0.len(), snapshot.0.len()); - for (index, snapshot) in snapshot.0 { + debug_assert_eq!(self.0.len(), snapshot.matching_overloads.len()); + binding.overload_call_return_type = snapshot.overload_return_type; + for (index, snapshot) in snapshot.matching_overloads { binding.overloads[index].restore(snapshot); } } @@ -2140,6 +2318,9 @@ pub(crate) enum BindingError<'db> { /// We use this variant to report errors in `property.__get__` and `property.__set__`, which /// can occur when the call to the underlying getter/setter fails. InternalCallError(&'static str), + /// This overload binding of the callable does not match the arguments. + // TODO: We could expand this with an enum to specify why the overload is unmatched. + UnmatchedOverload, } impl<'db> BindingError<'db> { @@ -2332,6 +2513,8 @@ impl<'db> BindingError<'db> { } } } + + Self::UnmatchedOverload => {} } }