diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index ceb588d7fe5b8d..d97092720baca0 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -417,6 +417,8 @@ reveal_type(x) # revealed: Literal[1] python-version = "3.12" ``` +`generic_list.py`: + ```py from typing import Literal @@ -427,14 +429,13 @@ a = f("a") reveal_type(a) # revealed: list[Literal["a"]] b: list[int | Literal["a"]] = f("a") -reveal_type(b) # revealed: list[Literal["a"] | int] +reveal_type(b) # revealed: list[int | Literal["a"]] c: list[int | str] = f("a") -reveal_type(c) # revealed: list[str | int] +reveal_type(c) # revealed: list[int | str] d: list[int | tuple[int, int]] = f((1, 2)) -# TODO: We could avoid reordering the union elements here. -reveal_type(d) # revealed: list[tuple[int, int] | int] +reveal_type(d) # revealed: list[int | tuple[int, int]] e: list[int] = f(True) reveal_type(e) # revealed: list[int] @@ -455,10 +456,218 @@ j: int | str = f2(True) reveal_type(j) # revealed: Literal[True] ``` -Types are not widened unnecessarily: +A function's arguments are also inferred using the type context: + +`typed_dict.py`: + +```py +from typing import TypedDict + +class TD(TypedDict): + x: int + +def f[T](x: list[T]) -> T: + return x[0] + +a: TD = f([{"x": 0}, {"x": 1}]) +reveal_type(a) # revealed: TD + +b: TD | None = f([{"x": 0}, {"x": 1}]) +reveal_type(b) # revealed: TD + +# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `TD` constructor" +# error: [invalid-key] "Invalid key for TypedDict `TD`: Unknown key "y"" +# error: [invalid-assignment] "Object of type `Unknown | dict[Unknown | str, Unknown | int]` is not assignable to `TD`" +c: TD = f([{"y": 0}, {"x": 1}]) + +# error: [missing-typed-dict-key] "Missing required key 'x' in TypedDict `TD` constructor" +# error: [invalid-key] "Invalid key for TypedDict `TD`: Unknown key "y"" +# error: [invalid-assignment] "Object of type `Unknown | dict[Unknown | str, Unknown | int]` is not assignable to `TD | None`" +c: TD | None = f([{"y": 0}, {"x": 1}]) +``` + +But not in a way that leads to assignability errors: + +`dict_any.py`: ```py -def id[T](x: T) -> T: +from typing import TypedDict, Any + +class TD(TypedDict, total=False): + x: str + +class TD2(TypedDict): + x: str + +def f(self, dt: dict[str, Any], key: str): + # TODO: This should not error once typed dict assignability is implemented. + # error: [invalid-assignment] + x1: TD = dt.get(key, {}) + reveal_type(x1) # revealed: TD + + x2: TD = dt.get(key, {"x": 0}) + reveal_type(x2) # revealed: Any + + x3: TD | None = dt.get(key, {}) + # TODO: This should reveal `Any` once typed dict assignability is implemented. + reveal_type(x3) # revealed: Any | None + + x4: TD | None = dt.get(key, {"x": 0}) + reveal_type(x4) # revealed: Any + + x5: TD2 = dt.get(key, {}) + reveal_type(x5) # revealed: Any + + x6: TD2 = dt.get(key, {"x": 0}) + reveal_type(x6) # revealed: Any + + x7: TD2 | None = dt.get(key, {}) + reveal_type(x7) # revealed: Any + + x8: TD2 | None = dt.get(key, {"x": 0}) + reveal_type(x8) # revealed: Any +``` + +## Prefer the declared type of generic classes + +```toml +[environment] +python-version = "3.12" +``` + +```py +from typing import Any + +def f[T](x: T) -> list[T]: + return [x] + +def f2[T](x: T) -> list[T] | None: + return [x] + +def f3[T](x: T) -> list[T] | dict[T, T]: + return [x] + +a = f(1) +reveal_type(a) # revealed: list[Literal[1]] + +b: list[Any] = f(1) +reveal_type(b) # revealed: list[Any] + +c: list[Any] = [1] +reveal_type(c) # revealed: list[Any] + +d: list[Any] | None = f(1) +reveal_type(d) # revealed: list[Any] + +e: list[Any] | None = [1] +reveal_type(e) # revealed: list[Any] + +f: list[Any] | None = f2(1) +# TODO: Better constraint solver. +reveal_type(f) # revealed: list[Literal[1]] | None + +g: list[Any] | dict[Any, Any] = f3(1) +# TODO: Better constraint solver. +reveal_type(g) # revealed: list[Literal[1]] | dict[Literal[1], Literal[1]] +``` + +We currently prefer the generic declared type regardless of its variance: + +```py +class Bivariant[T]: + pass + +class Covariant[T]: + def pop(self) -> T: + raise NotImplementedError + +class Contravariant[T]: + def push(self, value: T) -> None: + pass + +class Invariant[T]: + x: T + +def bivariant[T](x: T) -> Bivariant[T]: + return Bivariant() + +def covariant[T](x: T) -> Covariant[T]: + return Covariant() + +def contravariant[T](x: T) -> Contravariant[T]: + return Contravariant() + +def invariant[T](x: T) -> Invariant[T]: + return Invariant() + +x1 = bivariant(1) +x2 = covariant(1) +x3 = contravariant(1) +x4 = invariant(1) + +reveal_type(x1) # revealed: Bivariant[Literal[1]] +reveal_type(x2) # revealed: Covariant[Literal[1]] +reveal_type(x3) # revealed: Contravariant[Literal[1]] +reveal_type(x4) # revealed: Invariant[Literal[1]] + +x5: Bivariant[Any] = bivariant(1) +x6: Covariant[Any] = covariant(1) +x7: Contravariant[Any] = contravariant(1) +x8: Invariant[Any] = invariant(1) + +# TODO: This could reveal `Bivariant[Any]`. +reveal_type(x5) # revealed: Bivariant[Literal[1]] +reveal_type(x6) # revealed: Covariant[Any] +reveal_type(x7) # revealed: Contravariant[Any] +reveal_type(x8) # revealed: Invariant[Any] +``` + +## Narrow generic unions + +```toml +[environment] +python-version = "3.12" +``` + +```py +from typing import reveal_type, TypedDict + +def identity[T](x: T) -> T: + return x + +def _(narrow: dict[str, str], target: list[str] | dict[str, str] | None): + target = identity(narrow) + reveal_type(target) # revealed: dict[str, str] + +def _(narrow: list[str], target: list[str] | dict[str, str] | None): + target = identity(narrow) + reveal_type(target) # revealed: list[str] + +def _(narrow: list[str] | dict[str, str], target: list[str] | dict[str, str] | None): + target = identity(narrow) + reveal_type(target) # revealed: list[str] | dict[str, str] + +class TD(TypedDict): + x: int + +def _(target: list[TD] | dict[str, TD] | None): + target = identity([{"x": 1}]) + reveal_type(target) # revealed: list[TD] + +def _(target: list[TD] | dict[str, TD] | None): + target = identity({"x": {"x": 1}}) + reveal_type(target) # revealed: dict[str, TD] +``` + +## Prefer the inferred type of non-generic classes + +```toml +[environment] +python-version = "3.12" +``` + +```py +def identity[T](x: T) -> T: return x def lst[T](x: T) -> list[T]: @@ -466,20 +675,18 @@ def lst[T](x: T) -> list[T]: def _(i: int): a: int | None = i - b: int | None = id(i) - c: int | str | None = id(i) + b: int | None = identity(i) + c: int | str | None = identity(i) reveal_type(a) # revealed: int reveal_type(b) # revealed: int reveal_type(c) # revealed: int a: list[int | None] | None = [i] - b: list[int | None] | None = id([i]) - c: list[int | None] | int | None = id([i]) + b: list[int | None] | None = identity([i]) + c: list[int | None] | int | None = identity([i]) reveal_type(a) # revealed: list[int | None] - # TODO: these should reveal `list[int | None]` - # we currently do not use the call expression annotation as type context for argument inference - reveal_type(b) # revealed: list[Unknown | int] - reveal_type(c) # revealed: list[Unknown | int] + reveal_type(b) # revealed: list[int | None] + reveal_type(c) # revealed: list[int | None] a: list[int | None] | None = [i] b: list[int | None] | None = lst(i) @@ -489,9 +696,44 @@ def _(i: int): reveal_type(c) # revealed: list[int | None] a: list | None = [] - b: list | None = id([]) - c: list | int | None = id([]) + b: list | None = identity([]) + c: list | int | None = identity([]) reveal_type(a) # revealed: list[Unknown] reveal_type(b) # revealed: list[Unknown] reveal_type(c) # revealed: list[Unknown] + +def f[T](x: list[T]) -> T: + return x[0] + +def _(a: int, b: str, c: int | str): + x1: int = f(lst(a)) + reveal_type(x1) # revealed: int + + x2: int | str = f(lst(a)) + reveal_type(x2) # revealed: int + + x3: int | None = f(lst(a)) + reveal_type(x3) # revealed: int + + x4: str = f(lst(b)) + reveal_type(x4) # revealed: str + + x5: int | str = f(lst(b)) + reveal_type(x5) # revealed: str + + x6: str | None = f(lst(b)) + reveal_type(x6) # revealed: str + + x7: int | str = f(lst(c)) + reveal_type(x7) # revealed: int | str + + x8: int | str = f(lst(c)) + reveal_type(x8) # revealed: int | str + + # TODO: Ideally this would reveal `int | str`. This is a known limitation of our + # call inference solver, and would # require an extra inference attempt without type + # context, or with type context # of subsets of the union, both of which are impractical + # for performance reasons. + x9: int | str | None = f(lst(c)) + reveal_type(x9) # revealed: int | str | None ``` diff --git a/crates/ty_python_semantic/resources/mdtest/bidirectional.md b/crates/ty_python_semantic/resources/mdtest/bidirectional.md index 6b908737282bb1..1211f92fe577a8 100644 --- a/crates/ty_python_semantic/resources/mdtest/bidirectional.md +++ b/crates/ty_python_semantic/resources/mdtest/bidirectional.md @@ -50,8 +50,8 @@ def _(l: list[int] | None = None): def f[T](x: T, cond: bool) -> T | list[T]: return x if cond else [x] -# TODO: no error -# error: [invalid-assignment] "Object of type `Literal[1] | list[Literal[1]]` is not assignable to `int | list[int]`" +# TODO: Better constraint solver. +# error: [invalid-assignment] l5: int | list[int] = f(1, True) ``` diff --git a/crates/ty_python_semantic/resources/mdtest/dataclasses/fields.md b/crates/ty_python_semantic/resources/mdtest/dataclasses/fields.md index f091a1c9910b4b..28a69081e57cf0 100644 --- a/crates/ty_python_semantic/resources/mdtest/dataclasses/fields.md +++ b/crates/ty_python_semantic/resources/mdtest/dataclasses/fields.md @@ -37,7 +37,7 @@ class Data: content: list[int] = field(default_factory=list) timestamp: datetime = field(default_factory=datetime.now, init=False) -# revealed: (self: Data, content: list[int] = Unknown) -> None +# revealed: (self: Data, content: list[int] = list[int]) -> None reveal_type(Data.__init__) data = Data([1, 2, 3]) @@ -63,7 +63,6 @@ class Person: age: int | None = field(default=None, kw_only=True) role: str = field(default="user", kw_only=True) -# TODO: this would ideally show a default value of `None` for `age` # revealed: (self: Person, name: str, *, age: int | None = None, role: str = Literal["user"]) -> None reveal_type(Person.__init__) diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 6c9cdefa20b06a..15df7b4449a75b 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -885,20 +885,31 @@ impl<'db> Type<'db> { } } - // If the type is a specialized instance of the given `KnownClass`, returns the specialization. + /// If the type is a specialized instance of the given `KnownClass`, returns the specialization. pub(crate) fn known_specialization( &self, db: &'db dyn Db, known_class: KnownClass, ) -> Option> { let class_literal = known_class.try_to_class_literal(db)?; - self.specialization_of(db, Some(class_literal)) + self.specialization_of(db, class_literal) } - // If the type is a specialized instance of the given class, returns the specialization. - // - // If no class is provided, returns the specialization of any class instance. + /// If this type is a class instance, returns its specialization. + pub(crate) fn class_specialization(self, db: &'db dyn Db) -> Option> { + self.specialization_of_optional(db, None) + } + + /// If the type is a specialized instance of the given class, returns the specialization. pub(crate) fn specialization_of( + self, + db: &'db dyn Db, + expected_class: ClassLiteral<'_>, + ) -> Option> { + self.specialization_of_optional(db, Some(expected_class)) + } + + fn specialization_of_optional( self, db: &'db dyn Db, expected_class: Option>, @@ -5578,7 +5589,7 @@ impl<'db> Type<'db> { ) -> Result, CallError<'db>> { self.bindings(db) .match_parameters(db, argument_types) - .check_types(db, argument_types, &TypeContext::default(), &[]) + .check_types(db, argument_types, TypeContext::default(), &[]) } /// Look up a dunder method on the meta-type of `self` and call it. @@ -5630,7 +5641,8 @@ impl<'db> Type<'db> { let bindings = dunder_callable .bindings(db) .match_parameters(db, argument_types) - .check_types(db, argument_types, &tcx, &[])?; + .check_types(db, argument_types, tcx, &[])?; + if boundness == Definedness::PossiblyUndefined { return Err(CallDunderError::PossiblyUnbound(Box::new(bindings))); } diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 60868cfc3f33f0..db1d06c32ef3f8 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -35,11 +35,11 @@ use crate::types::generics::{ use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters}; use crate::types::tuple::{TupleLength, TupleType}; use crate::types::{ - BoundMethodType, ClassLiteral, DataclassFlags, DataclassParams, FieldInstance, - KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, NominalInstanceType, - PropertyInstanceType, SpecialFormType, TrackedConstraintSet, TypeAliasType, TypeContext, - UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support, infer_isolated_expression, - todo_type, + BoundMethodType, BoundTypeVarIdentity, ClassLiteral, DataclassFlags, DataclassParams, + FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, + NominalInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet, + TypeAliasType, TypeContext, UnionBuilder, UnionType, WrapperDescriptorKind, enums, ide_support, + infer_isolated_expression, todo_type, }; use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity}; use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion}; @@ -48,7 +48,7 @@ use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion}; /// compatible with _all_ of the types in the union for the call to be valid. /// /// It's guaranteed that the wrapped bindings have no errors. -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) struct Bindings<'db> { /// The type that is (hopefully) callable. callable_type: Type<'db>, @@ -150,9 +150,27 @@ impl<'db> Bindings<'db> { mut self, db: &'db dyn Db, argument_types: &CallArguments<'_, 'db>, - call_expression_tcx: &TypeContext<'db>, + call_expression_tcx: TypeContext<'db>, dataclass_field_specifiers: &[Type<'db>], ) -> Result> { + match self.check_types_impl( + db, + argument_types, + call_expression_tcx, + dataclass_field_specifiers, + ) { + Ok(()) => Ok(self), + Err(err) => Err(CallError(err, Box::new(self))), + } + } + + pub(crate) fn check_types_impl( + &mut self, + db: &'db dyn Db, + argument_types: &CallArguments<'_, 'db>, + call_expression_tcx: TypeContext<'db>, + dataclass_field_specifiers: &[Type<'db>], + ) -> Result<(), CallErrorKind> { for element in &mut self.elements { if let Some(mut updated_argument_forms) = element.check_types(db, argument_types, call_expression_tcx) @@ -197,16 +215,13 @@ impl<'db> Bindings<'db> { } if all_ok { - Ok(self) + Ok(()) } else if any_binding_error { - Err(CallError(CallErrorKind::BindingError, Box::new(self))) + Err(CallErrorKind::BindingError) } else if all_not_callable { - Err(CallError(CallErrorKind::NotCallable, Box::new(self))) + Err(CallErrorKind::NotCallable) } else { - Err(CallError( - CallErrorKind::PossiblyNotCallable, - Box::new(self), - )) + Err(CallErrorKind::PossiblyNotCallable) } } @@ -1365,7 +1380,7 @@ impl<'db> From> for Bindings<'db> { /// If the arguments cannot be matched to formal parameters, we store information about the /// specific errors that occurred when trying to match them up. If the callable has multiple /// overloads, we store this error information for each overload. -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) struct CallableBinding<'db> { /// The type that is (hopefully) callable. pub(crate) callable_type: Type<'db>, @@ -1486,7 +1501,7 @@ impl<'db> CallableBinding<'db> { &mut self, db: &'db dyn Db, argument_types: &CallArguments<'_, 'db>, - call_expression_tcx: &TypeContext<'db>, + call_expression_tcx: TypeContext<'db>, ) -> Option { // If this callable is a bound method, prepend the self instance onto the arguments list // before checking. @@ -2267,7 +2282,7 @@ pub(crate) enum MatchingOverloadIndex { Multiple(Vec), } -#[derive(Default, Debug)] +#[derive(Default, Debug, Clone)] struct ArgumentForms { values: Vec>, conflicting: Vec, @@ -2672,7 +2687,7 @@ struct ArgumentTypeChecker<'a, 'db> { arguments: &'a CallArguments<'a, 'db>, argument_matches: &'a [MatchedArgument<'db>], parameter_tys: &'a mut [Option>], - call_expression_tcx: &'a TypeContext<'db>, + call_expression_tcx: TypeContext<'db>, return_ty: Type<'db>, errors: &'a mut Vec>, @@ -2688,7 +2703,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { arguments: &'a CallArguments<'a, 'db>, argument_matches: &'a [MatchedArgument<'db>], parameter_tys: &'a mut [Option>], - call_expression_tcx: &'a TypeContext<'db>, + call_expression_tcx: TypeContext<'db>, return_ty: Type<'db>, errors: &'a mut Vec>, ) -> Self { @@ -2738,9 +2753,21 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { return; }; + let return_with_tcx = self + .signature + .return_ty + .zip(self.call_expression_tcx.annotation); + self.inferable_typevars = generic_context.inferable_typevars(self.db); let mut builder = SpecializationBuilder::new(self.db, self.inferable_typevars); + // Prefer the declared type of generic classes. + let preferred_type_mappings = return_with_tcx.and_then(|(return_ty, tcx)| { + tcx.class_specialization(self.db)?; + builder.infer(return_ty, tcx).ok()?; + Some(builder.type_mappings().clone()) + }); + let parameters = self.signature.parameters(); for (argument_index, adjusted_argument_index, _, argument_type) in self.enumerate_argument_types() @@ -2753,9 +2780,21 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { continue; }; - if let Err(error) = builder.infer( + let filter = |declared_ty: BoundTypeVarIdentity<'_>, inferred_ty: Type<'_>| { + // Avoid widening the inferred type if it is already assignable to the + // preferred declared type. + preferred_type_mappings + .as_ref() + .and_then(|types| types.get(&declared_ty)) + .is_none_or(|preferred_ty| { + !inferred_ty.is_assignable_to(self.db, *preferred_ty) + }) + }; + + if let Err(error) = builder.infer_filter( expected_type, variadic_argument_type.unwrap_or(argument_type), + filter, ) { self.errors.push(BindingError::SpecializationError { error, @@ -2765,15 +2804,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { } } - // Build the specialization first without inferring the type context. - let isolated_specialization = builder.build(generic_context, *self.call_expression_tcx); + // Build the specialization first without inferring the complete type context. + let isolated_specialization = builder.build(generic_context, self.call_expression_tcx); let isolated_return_ty = self .return_ty .apply_specialization(self.db, isolated_specialization); let mut try_infer_tcx = || { - let return_ty = self.signature.return_ty?; - let call_expression_tcx = self.call_expression_tcx.annotation?; + let (return_ty, call_expression_tcx) = return_with_tcx?; // A type variable is not a useful type-context for expression inference, and applying it // to the return type can lead to confusing unions in nested generic calls. @@ -2781,8 +2819,8 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { return None; } - // If the return type is already assignable to the annotated type, we can ignore the - // type context and prefer the narrower inferred type. + // If the return type is already assignable to the annotated type, we ignore the rest of + // the type context and prefer the narrower inferred type. if isolated_return_ty.is_assignable_to(self.db, call_expression_tcx) { return None; } @@ -2791,8 +2829,8 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> { // annotated assignment, to closer match the order of any unions written in the type annotation. builder.infer(return_ty, call_expression_tcx).ok()?; - // Otherwise, build the specialization again after inferring the type context. - let specialization = builder.build(generic_context, *self.call_expression_tcx); + // Otherwise, build the specialization again after inferring the complete type context. + let specialization = builder.build(generic_context, self.call_expression_tcx); let return_ty = return_ty.apply_specialization(self.db, specialization); Some((Some(specialization), return_ty)) @@ -3051,7 +3089,7 @@ impl<'db> MatchedArgument<'db> { pub(crate) struct UnknownParameterNameError; /// Binding information for one of the overloads of a callable. -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) struct Binding<'db> { pub(crate) signature: Signature<'db>, @@ -3150,7 +3188,7 @@ impl<'db> Binding<'db> { &mut self, db: &'db dyn Db, arguments: &CallArguments<'_, 'db>, - call_expression_tcx: &TypeContext<'db>, + call_expression_tcx: TypeContext<'db>, ) { let mut checker = ArgumentTypeChecker::new( db, diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index de0b4c180ca3be..9ed36b76b27347 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -258,7 +258,7 @@ impl<'db> GenericAlias<'db> { ) -> Self { let tcx = tcx .annotation - .and_then(|ty| ty.specialization_of(db, Some(self.origin(db)))) + .and_then(|ty| ty.specialization_of(db, self.origin(db))) .map(|specialization| specialization.types(db)) .unwrap_or(&[]); diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 992e6644017626..cc7360541cb948 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -1,4 +1,5 @@ use std::cell::RefCell; +use std::collections::hash_map::Entry; use std::fmt::Display; use itertools::Itertools; @@ -1324,6 +1325,11 @@ impl<'db> SpecializationBuilder<'db> { } } + /// Returns the current set of type mappings for this specialization. + pub(crate) fn type_mappings(&self) -> &FxHashMap, Type<'db>> { + &self.types + } + pub(crate) fn build( &mut self, generic_context: GenericContext<'db>, @@ -1331,7 +1337,7 @@ impl<'db> SpecializationBuilder<'db> { ) -> Specialization<'db> { let tcx_specialization = tcx .annotation - .and_then(|annotation| annotation.specialization_of(self.db, None)); + .and_then(|annotation| annotation.class_specialization(self.db)); let types = (generic_context.variables_inner(self.db).iter()).map(|(identity, variable)| { @@ -1354,19 +1360,43 @@ impl<'db> SpecializationBuilder<'db> { generic_context.specialize_partial(self.db, types) } - fn add_type_mapping(&mut self, bound_typevar: BoundTypeVarInstance<'db>, ty: Type<'db>) { - self.types - .entry(bound_typevar.identity(self.db)) - .and_modify(|existing| { - *existing = UnionType::from_elements(self.db, [*existing, ty]); - }) - .or_insert(ty); + fn add_type_mapping( + &mut self, + bound_typevar: BoundTypeVarInstance<'db>, + ty: Type<'db>, + filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool, + ) { + let identity = bound_typevar.identity(self.db); + match self.types.entry(identity) { + Entry::Occupied(mut entry) => { + if filter(identity, ty) { + *entry.get_mut() = UnionType::from_elements(self.db, [*entry.get(), ty]); + } + } + Entry::Vacant(entry) => { + entry.insert(ty); + } + } } + /// Infer type mappings for the specialization based on a given type and its declared type. pub(crate) fn infer( &mut self, formal: Type<'db>, actual: Type<'db>, + ) -> Result<(), SpecializationError<'db>> { + self.infer_filter(formal, actual, |_, _| true) + } + + /// Infer type mappings for the specialization based on a given type and its declared type. + /// + /// The filter predicate is provided with a type variable and the type being mapped to it. Type + /// mappings to which the predicate returns `false` will be ignored. + pub(crate) fn infer_filter( + &mut self, + formal: Type<'db>, + actual: Type<'db>, + filter: impl Fn(BoundTypeVarIdentity<'db>, Type<'db>) -> bool, ) -> Result<(), SpecializationError<'db>> { if formal == actual { return Ok(()); @@ -1400,8 +1430,8 @@ impl<'db> SpecializationBuilder<'db> { // Remove the union elements from `actual` that are not related to `formal`, and vice // versa. // - // For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` - // to `int`, and so ignore the `None`. + // For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to + // specialize `T` to `int`, and so ignore the `None`. let actual = actual.filter_disjoint_elements(self.db, formal, self.inferable); let formal = formal.filter_disjoint_elements(self.db, actual, self.inferable); @@ -1449,7 +1479,7 @@ impl<'db> SpecializationBuilder<'db> { if remaining_actual.is_never() { return Ok(()); } - self.add_type_mapping(*formal_bound_typevar, remaining_actual); + self.add_type_mapping(*formal_bound_typevar, remaining_actual, filter); } (Type::Union(formal), _) => { // Second, if the formal is a union, and precisely one union element _is_ a typevar (not @@ -1459,7 +1489,7 @@ impl<'db> SpecializationBuilder<'db> { let bound_typevars = (formal.elements(self.db).iter()).filter_map(|ty| ty.as_typevar()); if let Ok(bound_typevar) = bound_typevars.exactly_one() { - self.add_type_mapping(bound_typevar, actual); + self.add_type_mapping(bound_typevar, actual, filter); } } @@ -1487,13 +1517,13 @@ impl<'db> SpecializationBuilder<'db> { argument: ty, }); } - self.add_type_mapping(bound_typevar, ty); + self.add_type_mapping(bound_typevar, ty, filter); } Some(TypeVarBoundOrConstraints::Constraints(constraints)) => { // Prefer an exact match first. for constraint in constraints.elements(self.db) { if ty == *constraint { - self.add_type_mapping(bound_typevar, ty); + self.add_type_mapping(bound_typevar, ty, filter); return Ok(()); } } @@ -1503,7 +1533,7 @@ impl<'db> SpecializationBuilder<'db> { .when_assignable_to(self.db, *constraint, self.inferable) .is_always_satisfied(self.db) { - self.add_type_mapping(bound_typevar, *constraint); + self.add_type_mapping(bound_typevar, *constraint, filter); return Ok(()); } } @@ -1513,7 +1543,7 @@ impl<'db> SpecializationBuilder<'db> { }); } _ => { - self.add_type_mapping(bound_typevar, ty); + self.add_type_mapping(bound_typevar, ty, filter); } } } diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 3b1142b89bbdcb..0bf922a6d43e9f 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -1,4 +1,4 @@ -use std::{iter, mem}; +use std::iter; use itertools::{Either, Itertools}; use ruff_db::diagnostic::{Annotation, DiagnosticId, Severity}; @@ -211,6 +211,7 @@ const NUM_FIELD_SPECIFIERS_INLINE: usize = 1; /// don't infer its types more than once. pub(super) struct TypeInferenceBuilder<'db, 'ast> { context: InferContext<'db, 'ast>, + index: &'db SemanticIndex<'db>, region: InferenceRegion<'db>, @@ -349,16 +350,19 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { assert_eq!(self.scope, inference.scope); self.expressions.extend(inference.expressions.iter()); - self.declarations.extend(inference.declarations()); + self.declarations + .extend(inference.declarations(), self.multi_inference_state); if !matches!(self.region, InferenceRegion::Scope(..)) { - self.bindings.extend(inference.bindings()); + self.bindings + .extend(inference.bindings(), self.multi_inference_state); } if let Some(extra) = &inference.extra { self.extend_cycle_recovery(extra.cycle_recovery); self.context.extend(&extra.diagnostics); - self.deferred.extend(extra.deferred.iter().copied()); + self.deferred + .extend(extra.deferred.iter().copied(), self.multi_inference_state); } } @@ -377,7 +381,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.extend_cycle_recovery(extra.cycle_recovery); if !matches!(self.region, InferenceRegion::Scope(..)) { - self.bindings.extend(extra.bindings.iter().copied()); + self.bindings + .extend(extra.bindings.iter().copied(), self.multi_inference_state); } } } @@ -398,6 +403,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.scope } + /// Set the multi-inference state, returning the previous value. + fn set_multi_inference_state(&mut self, state: MultiInferenceState) -> MultiInferenceState { + std::mem::replace(&mut self.multi_inference_state, state) + } + /// Are we currently inferring types in file with deferred types? /// This is true for stub files, for files with `__future__.annotations`, and /// by default for all source files in Python 3.14 and later. @@ -1637,7 +1647,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - self.bindings.insert(binding, bound_ty); + self.bindings + .insert(binding, bound_ty, self.multi_inference_state); inferred_ty } @@ -1704,7 +1715,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } TypeAndQualifiers::declared(Type::unknown()) }; - self.declarations.insert(declaration, ty); + self.declarations + .insert(declaration, ty, self.multi_inference_state); } fn add_declaration_with_binding( @@ -1778,8 +1790,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } }; - self.declarations.insert(definition, declared_ty); - self.bindings.insert(definition, inferred_ty); + self.declarations + .insert(definition, declared_ty, self.multi_inference_state); + self.bindings + .insert(definition, inferred_ty, self.multi_inference_state); } fn add_unknown_declaration_with_binding( @@ -2198,7 +2212,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // `infer_function_type_params`, rather than here. if type_params.is_none() { if self.defer_annotations() { - self.deferred.insert(definition); + self.deferred.insert(definition, self.multi_inference_state); } else { let previous_typevar_binding_context = self.typevar_binding_context.replace(definition); @@ -2756,7 +2770,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Inference of bases deferred in stubs, or if any are string literals. if self.in_stub() || class_node.bases().iter().any(contains_string_literal) { - self.deferred.insert(definition); + self.deferred.insert(definition, self.multi_inference_state); } else { let previous_typevar_binding_context = self.typevar_binding_context.replace(definition); @@ -3126,7 +3140,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { None => None, }; if bound_or_constraint.is_some() || default.is_some() { - self.deferred.insert(definition); + self.deferred.insert(definition, self.multi_inference_state); } let identity = TypeVarIdentity::new(self.db(), &name.id, Some(definition), TypeVarKind::Pep695); @@ -3190,7 +3204,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { default, } = node; if default.is_some() { - self.deferred.insert(definition); + self.deferred.insert(definition, self.multi_inference_state); } let identity = TypeVarIdentity::new( self.db(), @@ -3680,10 +3694,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Overwrite the previously inferred value, preferring later inferences, which are // likely more precise. Note that we still ensure each inference is assignable to // its declared type, so this mainly affects the IDE hover type. - let prev_multi_inference_state = mem::replace( - &mut builder.multi_inference_state, - MultiInferenceState::Overwrite, - ); + let prev_multi_inference_state = + builder.set_multi_inference_state(MultiInferenceState::Overwrite); // If we are inferring the argument multiple times, silence diagnostics to avoid duplicated warnings. let was_in_multi_inference = if let Some(first_tcx) = first_tcx { @@ -4625,7 +4637,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } if default.is_some() { - self.deferred.insert(definition); + self.deferred.insert(definition, self.multi_inference_state); } let identity = @@ -4867,7 +4879,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }; if bound_or_constraints.is_some() || default.is_some() { - self.deferred.insert(definition); + self.deferred.insert(definition, self.multi_inference_state); } let identity = TypeVarIdentity::new(db, target_name, Some(definition), TypeVarKind::Legacy); @@ -5961,27 +5973,156 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - /// Infer the argument types for multiple potential bindings and overloads. - fn infer_all_argument_types<'a>( + fn infer_and_check_argument_types( &mut self, ast_arguments: &ast::Arguments, - arguments: &mut CallArguments<'a, 'db>, + argument_types: &mut CallArguments<'_, 'db>, + bindings: &mut Bindings<'db>, + call_expression_tcx: TypeContext<'db>, + ) -> Result<(), CallErrorKind> { + let db = self.db(); + + // If the type context is a union, attempt to narrow to a specific element. + let narrow_targets: &[_] = match call_expression_tcx.annotation { + // TODO: We could theoretically attempt to narrow to every element of + // the power set of this union. However, this leads to an exponential + // explosion of inference attempts, and is rarely needed in practice. + Some(Type::Union(union)) => union.elements(db), + _ => &[], + }; + + // We silence diagnostics until we successfully narrow to a specific type. + let mut speculated_bindings = bindings.clone(); + let was_in_multi_inference = self.context.set_multi_inference(true); + + let mut try_narrow = |narrowed_ty| { + let narrowed_tcx = TypeContext::new(Some(narrowed_ty)); + + // Attempt to infer the argument types using the narrowed type context. + self.infer_all_argument_types( + ast_arguments, + argument_types, + bindings, + narrowed_tcx, + MultiInferenceState::Ignore, + ); + + // Ensure the argument types match their annotated types. + if speculated_bindings + .check_types_impl( + db, + argument_types, + narrowed_tcx, + &self.dataclass_field_specifiers, + ) + .is_err() + { + return None; + } + + // Ensure the inferred return type is assignable to the (narrowed) declared type. + // + // TODO: Checking assignability against the full declared type could help avoid + // cases where the constraint solver is not smart enough to solve complex unions. + // We should see revisit this after the new constraint solver is implemented. + if !speculated_bindings + .return_type(db) + .is_assignable_to(db, narrowed_ty) + { + return None; + } + + // Successfully narrowed to an element of the union. + // + // If necessary, infer the argument types again with diagnostics enabled. + if !was_in_multi_inference { + self.context.set_multi_inference(was_in_multi_inference); + + self.infer_all_argument_types( + ast_arguments, + argument_types, + bindings, + narrowed_tcx, + MultiInferenceState::Intersect, + ); + } + + Some(bindings.check_types_impl( + db, + argument_types, + narrowed_tcx, + &self.dataclass_field_specifiers, + )) + }; + + // Prefer the declared type of generic classes. + for narrowed_ty in narrow_targets + .iter() + .filter(|ty| ty.class_specialization(db).is_some()) + { + if let Some(result) = try_narrow(*narrowed_ty) { + return result; + } + } + + // Try the remaining elements of the union. + // + // TODO: We could also attempt an inference without type context, but this + // leads to similar performance issues. + for narrowed_ty in narrow_targets + .iter() + .filter(|ty| ty.class_specialization(db).is_none()) + { + if let Some(result) = try_narrow(*narrowed_ty) { + return result; + } + } + + // Re-enable diagnostics, and infer against the entire union as a fallback. + self.context.set_multi_inference(was_in_multi_inference); + + self.infer_all_argument_types( + ast_arguments, + argument_types, + bindings, + call_expression_tcx, + MultiInferenceState::Intersect, + ); + + bindings.check_types_impl( + db, + argument_types, + call_expression_tcx, + &self.dataclass_field_specifiers, + ) + } + + /// Infer the argument types for all bindings. + /// + /// Note that this method may infer the type of a given argument expression multiple times with + /// distinct type context. The provided `MultiInferenceState` can be used to dictate multi-inference + /// behavior. + fn infer_all_argument_types( + &mut self, + ast_arguments: &ast::Arguments, + arguments_types: &mut CallArguments<'_, 'db>, bindings: &Bindings<'db>, + call_expression_tcx: TypeContext<'db>, + multi_inference_state: MultiInferenceState, ) { - debug_assert!( - ast_arguments.len() == arguments.len() - && arguments.len() == bindings.argument_forms().len() - ); + debug_assert_eq!(ast_arguments.len(), arguments_types.len()); + debug_assert_eq!(arguments_types.len(), bindings.argument_forms().len()); + let db = self.db(); let iter = itertools::izip!( 0.., - arguments.iter_mut(), + arguments_types.iter_mut(), bindings.argument_forms().iter().copied(), ast_arguments.arguments_source_order() ); let overloads_with_binding = bindings - .into_iter() + .iter() .filter_map(|binding| { match binding.matching_overload_index() { MatchingOverloadIndex::Single(_) | MatchingOverloadIndex::Multiple(_) => { @@ -6000,7 +6141,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }, } }) - .flatten(); + .flatten() + .collect::>(); + + let old_multi_inference_state = self.set_multi_inference_state(multi_inference_state); for (argument_index, (_, argument_type), argument_form, ast_argument) in iter { let ast_argument = match ast_argument { @@ -6022,7 +6166,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } // Retrieve the parameter type for the current argument in a given overload and its binding. - let db = self.db(); let parameter_type = |overload: &Binding<'db>, binding: &CallableBinding<'db>| { let argument_index = if binding.bound_type.is_some() { argument_index + 1 @@ -6035,10 +6178,25 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { return None; }; - let parameter_type = + let mut parameter_type = overload.signature.parameters()[*parameter_index].annotated_type()?; - // TODO: For now, skip any parameter annotations that mention any typevars. There + // If this is a generic call, attempt to specialize the parameter type using the + // declared type context, if provided. + if let Some(generic_context) = overload.signature.generic_context + && let Some(return_ty) = overload.signature.return_ty + && let Some(declared_return_ty) = call_expression_tcx.annotation + { + let mut builder = + SpecializationBuilder::new(db, generic_context.inferable_typevars(db)); + + let _ = builder.infer(return_ty, declared_return_ty); + let specialization = builder.build(generic_context, call_expression_tcx); + + parameter_type = parameter_type.apply_specialization(db, specialization); + } + + // TODO: For now, skip any parameter annotations that still mention any typevars. There // are two issues: // // First, if we include those typevars in the type context that we use to infer the @@ -6069,26 +6227,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // If there is only a single binding and overload, we can infer the argument directly with // the unique parameter type annotation. - if let Ok((overload, binding)) = overloads_with_binding.clone().exactly_one() { - self.infer_expression_impl( + if let Ok((overload, binding)) = overloads_with_binding.iter().exactly_one() { + *argument_type = Some(self.infer_expression( ast_argument, TypeContext::new(parameter_type(overload, binding)), - ); + )); } else { - // Otherwise, each type is a valid independent inference of the given argument, and we may - // require different permutations of argument types to correctly perform argument expansion - // during overload evaluation, so we take the intersection of all the types we inferred for - // each argument. - // - // Note that this applies to all nested expressions within each argument. - let old_multi_inference_state = mem::replace( - &mut self.multi_inference_state, - MultiInferenceState::Intersect, - ); - // We perform inference once without any type context, emitting any diagnostics that are unrelated // to bidirectional type inference. - self.infer_expression_impl(ast_argument, TypeContext::default()); + *argument_type = Some(self.infer_expression(ast_argument, TypeContext::default())); // We then silence any diagnostics emitted during multi-inference, as the type context is only // used as a hint to infer a more assignable argument type, and should not lead to diagnostics @@ -6097,24 +6244,28 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Infer the type of each argument once with each distinct parameter type as type context. let parameter_types = overloads_with_binding - .clone() + .iter() .filter_map(|(overload, binding)| parameter_type(overload, binding)) .collect::>(); for parameter_type in parameter_types { - self.infer_expression_impl( - ast_argument, - TypeContext::new(Some(parameter_type)), - ); + let inferred_ty = + self.infer_expression(ast_argument, TypeContext::new(Some(parameter_type))); + + // Each type is a valid independent inference of the given argument, and we may require different + // permutations of argument types to correctly perform argument expansion during overload evaluation, + // so we take the intersection of all the types we inferred for each argument. + *argument_type = argument_type + .map(|current| IntersectionType::from_elements(db, [inferred_ty, current])) + .or(Some(inferred_ty)); } - // Restore the multi-inference state. - self.multi_inference_state = old_multi_inference_state; + // Re-enable diagnostics. self.context.set_multi_inference(was_in_multi_inference); } - - *argument_type = self.try_expression_type(ast_argument); } + + self.set_multi_inference_state(old_multi_inference_state); } fn infer_argument_type( @@ -6275,6 +6426,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let db = self.db(); match self.multi_inference_state { + MultiInferenceState::Ignore => {} + MultiInferenceState::Panic => { let previous = self.expressions.insert(expression.into(), ty); assert_eq!(previous, None); @@ -6593,7 +6746,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } validate_typed_dict_dict_literal(&self.context, typed_dict, dict, dict.into(), |expr| { - self.expression_type(expr) + item_types + .get(&expr.node_index().load()) + .copied() + .unwrap_or(Type::unknown()) }) .ok() .map(|_| Type::TypedDict(typed_dict)) @@ -7356,7 +7512,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let infer_call_arguments = |bindings: Option>| { if let Some(bindings) = bindings { let bindings = bindings.match_parameters(self.db(), &call_arguments); - self.infer_all_argument_types(arguments, &mut call_arguments, &bindings); + self.infer_all_argument_types( + arguments, + &mut call_arguments, + &bindings, + tcx, + MultiInferenceState::Intersect, + ); } else { let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()]; self.infer_argument_types(arguments, &mut call_arguments, &argument_forms); @@ -7374,10 +7536,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - let bindings = callable_type + let mut bindings = callable_type .bindings(self.db()) .match_parameters(self.db(), &call_arguments); - self.infer_all_argument_types(arguments, &mut call_arguments, &bindings); + + let bindings_result = + self.infer_and_check_argument_types(arguments, &mut call_arguments, &mut bindings, tcx); // Validate `TypedDict` constructor calls after argument type inference if let Some(class_literal) = callable_type.as_class_literal() { @@ -7395,14 +7559,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - let mut bindings = match bindings.check_types( - self.db(), - &call_arguments, - &tcx, - &self.dataclass_field_specifiers[..], - ) { - Ok(bindings) => bindings, - Err(CallError(_, bindings)) => { + let mut bindings = match bindings_result { + Ok(()) => bindings, + Err(_) => { bindings.report_diagnostics(&self.context, call_expression.into()); return bindings.return_type(self.db()); } @@ -10030,8 +10189,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .check_types( self.db(), &call_argument_types, - &TypeContext::default(), - &self.dataclass_field_specifiers[..], + TypeContext::default(), + &self.dataclass_field_specifiers, ) { Ok(bindings) => bindings, Err(CallError(_, bindings)) => { @@ -10763,8 +10922,14 @@ enum MultiInferenceState { Panic, /// Overwrite the previously inferred value. + /// + /// Note that `Overwrite` does not interact well with nested inferences: + /// it overwrites values that were written with `MultiInferenceState::Intersect`. Overwrite, + /// Ignore the newly inferred value. + Ignore, + /// Store the intersection of all types inferred for the expression. Intersect, } @@ -11008,7 +11173,11 @@ where self.0.iter().map(|(k, v)| (k, v)) } - fn insert(&mut self, key: K, value: V) { + fn insert(&mut self, key: K, value: V, multi_inference_state: MultiInferenceState) { + if matches!(multi_inference_state, MultiInferenceState::Ignore) { + return; + } + debug_assert!( !self.0.iter().any(|(existing, _)| existing == &key), "An existing entry already exists for key {key:?}", @@ -11022,17 +11191,21 @@ where } } -impl Extend<(K, V)> for VecMap +impl VecMap where K: Eq, K: std::fmt::Debug, V: std::fmt::Debug, { #[inline] - fn extend>(&mut self, iter: T) { + fn extend>( + &mut self, + iter: T, + multi_inference_state: MultiInferenceState, + ) { if cfg!(debug_assertions) { for (key, value) in iter { - self.insert(key, value); + self.insert(key, value, multi_inference_state); } } else { self.0.extend(iter); @@ -11070,7 +11243,11 @@ where V: Eq, V: std::fmt::Debug, { - fn insert(&mut self, value: V) { + fn insert(&mut self, value: V, multi_inference_state: MultiInferenceState) { + if matches!(multi_inference_state, MultiInferenceState::Ignore) { + return; + } + debug_assert!( !self.0.iter().any(|existing| existing == &value), "An existing entry already exists for {value:?}", @@ -11080,16 +11257,20 @@ where } } -impl Extend for VecSet +impl VecSet where V: Eq, V: std::fmt::Debug, { #[inline] - fn extend>(&mut self, iter: T) { + fn extend>( + &mut self, + iter: T, + multi_inference_state: MultiInferenceState, + ) { if cfg!(debug_assertions) { for value in iter { - self.insert(value); + self.insert(value, multi_inference_state); } } else { self.0.extend(iter);