diff --git a/crates/ruff_benchmark/benches/ty_walltime.rs b/crates/ruff_benchmark/benches/ty_walltime.rs index ef7cd9333b63e..bf0ce69d5f870 100644 --- a/crates/ruff_benchmark/benches/ty_walltime.rs +++ b/crates/ruff_benchmark/benches/ty_walltime.rs @@ -164,7 +164,7 @@ static PANDAS: std::sync::LazyLock> = std::sync::LazyLock::ne max_dep_date: "2025-06-17", python_version: PythonVersion::PY312, }, - 3000, + 3300, ) }); @@ -199,7 +199,8 @@ static SYMPY: std::sync::LazyLock> = std::sync::LazyLock::new max_dep_date: "2025-06-17", python_version: PythonVersion::PY312, }, - 13000, + // TODO: With better decorator support, `__slots__` support, etc., it should be possible to reduce the number of errors considerably. + 70000, ) }); diff --git a/crates/ty_python_semantic/resources/corpus/cycle_into_callable.py b/crates/ty_python_semantic/resources/corpus/cycle_into_callable.py new file mode 100644 index 0000000000000..ce4cd6a795d02 --- /dev/null +++ b/crates/ty_python_semantic/resources/corpus/cycle_into_callable.py @@ -0,0 +1,17 @@ +# Regression test for https://github.com/astral-sh/ruff/issues/17371 +# panicked in commit d1088545a08aeb57b67ec1e3a7f5141159efefa5 +# error message: +# dependency graph cycle when querying ClassType < 'db >::into_callable_(Id(1c00)) + +try: + class foo[T: bar](object): + pass + bar = foo +except Exception: + bar = lambda: 0 +def bar(): + pass + +@bar() +class bar: + pass diff --git a/crates/ty_python_semantic/resources/corpus/divergent.py b/crates/ty_python_semantic/resources/corpus/divergent.py new file mode 100644 index 0000000000000..1ef6726cf2563 --- /dev/null +++ b/crates/ty_python_semantic/resources/corpus/divergent.py @@ -0,0 +1,72 @@ +def f(cond: bool): + if cond: + result = () + result += (f(cond),) + return result + + return None + +reveal_type(f(True)) + +def f(cond: bool): + if cond: + result = () + result += (f(cond),) + return result + + return None + +def f(cond: bool): + result = None + if cond: + result = () + result += 
(f(cond),) + + return result + +reveal_type(f(True)) + +def f(cond: bool): + result = None + if cond: + result = [f(cond) for _ in range(1)] + + return result + +reveal_type(f(True)) + +class Foo: + def value(self): + return 1 + +def unwrap(value): + if isinstance(value, Foo): + foo = value + return foo.value() + elif type(value) is tuple: + length = len(value) + if length == 0: + return () + elif length == 1: + return (unwrap(value[0]),) + else: + result = [] + for item in value: + result.append(unwrap(item)) + return tuple(result) + else: + raise TypeError() + +def descent(x: int, y: int): + if x > y: + y, x = descent(y, x) + return x, y + if x == 1: + return (1, 0) + if y == 1: + return (0, 1) + else: + return descent(x-1, y-1) + +def count_set_bits(n): + return 1 + count_set_bits(n & n - 1) if n else 0 diff --git a/crates/ty_python_semantic/resources/mdtest/annotations/invalid.md b/crates/ty_python_semantic/resources/mdtest/annotations/invalid.md index 5202b0921d8eb..53e81032e8e39 100644 --- a/crates/ty_python_semantic/resources/mdtest/annotations/invalid.md +++ b/crates/ty_python_semantic/resources/mdtest/annotations/invalid.md @@ -79,7 +79,9 @@ def outer_sync(): # `yield` from is only valid syntax inside a synchronous func a: (yield from [1]), # error: [invalid-type-form] "`yield from` expressions are not allowed in type expressions" ): ... -async def baz(): ... 
+async def baz(): + yield + async def outer_async(): # avoid unrelated syntax errors on `yield` and `await` def _( a: 1, # error: [invalid-type-form] "Int literals are not allowed in this context in a type expression" diff --git a/crates/ty_python_semantic/resources/mdtest/call/union.md b/crates/ty_python_semantic/resources/mdtest/call/union.md index b7df4043834f4..7a1736dd39051 100644 --- a/crates/ty_python_semantic/resources/mdtest/call/union.md +++ b/crates/ty_python_semantic/resources/mdtest/call/union.md @@ -111,7 +111,7 @@ def _(flag: bool): # error: [call-non-callable] "Object of type `Literal["This is a string literal"]` is not callable" x = f(3) - reveal_type(x) # revealed: Unknown + reveal_type(x) # revealed: None | Unknown ``` ## Union of binding errors @@ -128,7 +128,7 @@ def _(flag: bool): # error: [too-many-positional-arguments] "Too many positional arguments to function `f1`: expected 0, got 1" # error: [too-many-positional-arguments] "Too many positional arguments to function `f2`: expected 0, got 1" x = f(3) - reveal_type(x) # revealed: Unknown + reveal_type(x) # revealed: None ``` ## One not-callable, one wrong argument @@ -146,7 +146,7 @@ def _(flag: bool): # error: [too-many-positional-arguments] "Too many positional arguments to function `f1`: expected 0, got 1" # error: [call-non-callable] "Object of type `C` is not callable" x = f(3) - reveal_type(x) # revealed: Unknown + reveal_type(x) # revealed: None | Unknown ``` ## Union including a special-cased function diff --git a/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md b/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md index 5e77445b079bd..efed211d92187 100644 --- a/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md +++ b/crates/ty_python_semantic/resources/mdtest/diagnostics/semantic_syntax_errors.md @@ -125,7 +125,8 @@ match obj: ```py class C: - def __await__(self): ... 
+ def __await__(self): + yield # error: [invalid-syntax] "`return` statement outside of a function" return diff --git a/crates/ty_python_semantic/resources/mdtest/function/return_type.md b/crates/ty_python_semantic/resources/mdtest/function/return_type.md index 980afbdf46622..aec1d42efcf55 100644 --- a/crates/ty_python_semantic/resources/mdtest/function/return_type.md +++ b/crates/ty_python_semantic/resources/mdtest/function/return_type.md @@ -256,6 +256,331 @@ def f(cond: bool) -> int: return 2 ``` +## Inferred return type + +### Free function + +If a function's return type is not annotated, it is inferred. The inferred type is the union of all +possible return types. + +```py +def f(): + return 1 + +reveal_type(f()) # revealed: Literal[1] +# TODO: should be `def f() -> Literal[1]` +reveal_type(f) # revealed: def f() -> Unknown + +def g(cond: bool): + if cond: + return 1 + else: + return "a" + +reveal_type(g(True)) # revealed: Literal[1, "a"] + +# This function implicitly returns `None`. +def h(x: int, y: str): + if x > 10: + return x + elif x > 5: + return y + +reveal_type(h(1, "a")) # revealed: int | str | None + +lambda_func = lambda: 1 +# TODO: lambda function type inference +# Should be `Literal[1]` +reveal_type(lambda_func()) # revealed: Unknown + +def generator(): + yield 1 + yield 2 + return None + +# TODO: Should be `Generator[Literal[1, 2], Any, None]` +reveal_type(generator()) # revealed: Unknown + +async def async_generator(): + yield + +# TODO: Should be `AsyncGenerator[None, Any]` +reveal_type(async_generator()) # revealed: Unknown + +async def coroutine(): + return + +# TODO: Should be `CoroutineType[Any, Any, None]` +reveal_type(coroutine()) # revealed: Unknown +``` + +The return type of a recursive function is also inferred. When the return type inference would +diverge, it is truncated and replaced with the special dynamic type `Divergent`. 
+ +```toml +[environment] +python-version = "3.12" +``` + +```py +def fibonacci(n: int): + if n == 0: + return 0 + elif n == 1: + return 1 + else: + return fibonacci(n - 1) + fibonacci(n - 2) + +reveal_type(fibonacci(5)) # revealed: int + +def even(n: int): + if n == 0: + return True + else: + return odd(n - 1) + +def odd(n: int): + if n == 0: + return False + else: + return even(n - 1) + +reveal_type(even(1)) # revealed: bool +reveal_type(odd(1)) # revealed: bool + +def repeat_a(n: int): + if n <= 0: + return "" + else: + return repeat_a(n - 1) + "a" + +reveal_type(repeat_a(3)) # revealed: str + +def divergent(value): + if type(value) is tuple: + return (divergent(value[0]),) + else: + return None + +# tuple[tuple[tuple[...] | None] | None] | None => tuple[Divergent] | None +reveal_type(divergent((1,))) # revealed: tuple[Divergent] | None + +def call_divergent(x: int): + return (divergent((1, 2, 3)), x) + +reveal_type(call_divergent(1)) # revealed: tuple[tuple[Divergent] | None, int] + +def list1[T](x: T) -> list[T]: + return [x] + +def divergent2(value): + if type(value) is tuple: + return (divergent2(value[0]),) + elif type(value) is list: + return list1(divergent2(value[0])) + else: + return None + +reveal_type(divergent2((1,))) # revealed: tuple[Divergent] | list[Divergent] | None + +def list_int(x: int): + if x > 0: + return list1(list_int(x - 1)) + else: + return list1(x) + +# TODO: should be `list[int]` +reveal_type(list_int(1)) # revealed: list[Divergent] | list[int] + +def tuple_obj(cond: bool): + if cond: + x = object() + else: + x = tuple_obj(cond) + return (x,) + +reveal_type(tuple_obj(True)) # revealed: tuple[object] + +def get_non_empty(node): + for child in node.children: + node = get_non_empty(child) + if node is not None: + return node + return None + +reveal_type(get_non_empty(None)) # revealed: (Divergent & ~None) | None + +def nested_scope(): + def inner(): + return nested_scope() + return inner() + +reveal_type(nested_scope()) # revealed: 
Never + +def eager_nested_scope(): + class A: + x = eager_nested_scope() + + return A.x + +reveal_type(eager_nested_scope()) # revealed: Unknown + +class C: + def flip(self) -> "D": + return D() + +class D(C): + # TODO invalid override error + def flip(self) -> "C": + return C() + +def c_or_d(n: int): + if n == 0: + return D() + else: + return c_or_d(n - 1).flip() + +# In fixed-point iteration of the return type inference, the return type is monotonically widened. +# For example, once the return type of `c_or_d` is determined to be `C`, +# it will never be determined to be a subtype `D` in the subsequent iterations. +reveal_type(c_or_d(1)) # revealed: C +``` + +### Class method + +If a method's return type is not annotated, it is also inferred, but the inferred type is a union of +all possible return types and `Unknown`. This is because a method of a class may be overridden by +its subtypes. For example, if the return type of a method is inferred to be `int`, the type the +coder really intended might be `int | None`, in which case it would be impossible for the overridden +method to return `None`. + +```py +class C: + def f(self): + return 1 + +class D(C): + def f(self): + return None + +reveal_type(C().f()) # revealed: Literal[1] | Unknown +reveal_type(D().f()) # revealed: None | Literal[1] | Unknown +``` + +However, in the following cases, `Unknown` is not included in the inferred return type because there +is no ambiguity in the subclass. + +- The class or the method is marked as `final`. + +```py +from typing import final + +@final +class C: + def f(self): + return 1 + +class D: + @final + def f(self): + return "a" + +reveal_type(C().f()) # revealed: Literal[1] +reveal_type(D().f()) # revealed: Literal["a"] +``` + +- The method overrides the methods of the base classes, and the return types of the base class + methods are known (In this case, the return type of the method is the intersection of the return + types of the methods in the base classes). 
+ +```toml +[environment] +python-version = "3.12" +``` + +```py +from typing import Literal + +class C: + def f(self) -> int: + return 1 + + def g[T](self, x: T) -> T: + return x + + def h[T: int](self, x: T) -> T: + return x + + def i[T: int](self, x: T) -> list[T]: + return [x] + +class D(C): + def f(self): + return 2 + # TODO: This should be an invalid-override error. + def g(self, x: int): + return 2 + # A strict application of the Liskov Substitution Principle would consider + # this an invalid override because it violates the guarantee that the method returns + # the same type as its input type (any type smaller than int), + # but neither mypy nor pyright will throw an error for this. + def h(self, x: int): + return 2 + + def i(self, x: int): + return [2] + +class E(D): + def f(self): + return 3 + +reveal_type(C().f()) # revealed: int +reveal_type(D().f()) # revealed: int +reveal_type(E().f()) # revealed: int +reveal_type(C().g(1)) # revealed: Literal[1] +reveal_type(D().g(1)) # revealed: Literal[2] | Unknown +reveal_type(C().h(1)) # revealed: Literal[1] +reveal_type(D().h(1)) # revealed: Literal[2] | Unknown +reveal_type(C().h(True)) # revealed: Literal[True] +reveal_type(D().h(True)) # revealed: Literal[2] | Unknown +reveal_type(C().i(1)) # revealed: list[Literal[1]] +# TODO: better type for list elements +reveal_type(D().i(1)) # revealed: list[Unknown | int] | list[Unknown] + +class F: + def f(self) -> Literal[1, 2]: + return 2 + +class G: + def f(self) -> Literal[2, 3]: + return 2 + +class H(F, G): + # TODO: should be an invalid-override error + def f(self): + raise NotImplementedError + +class I(F, G): + # TODO: should be an invalid-override error + @final + def f(self): + raise NotImplementedError + +# We use a return type of `F.f` according to the MRO. 
+reveal_type(H().f()) # revealed: Literal[1, 2] +reveal_type(I().f()) # revealed: Never + +class C2[T]: + def f(self, x: T) -> T: + return x + +class D2(C2[int]): + def f(self, x: int): + return x + +reveal_type(D2().f(1)) # revealed: int +``` + ## Invalid return type diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/type.md b/crates/ty_python_semantic/resources/mdtest/narrow/type.md index 3cf1aa23dbd3c..96a1bcbc847a3 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/type.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/type.md @@ -144,20 +144,25 @@ def _(x: A | B): reveal_type(x) # revealed: A | B ``` -## No narrowing for custom `type` callable +## No special narrowing for custom `type` callable ```py +def type(x: object): + return int + class A: ... class B: ... -def type(x): - return int - def _(x: A | B): + # The custom `type` function always returns `int`, + # so any branch other than `type(...) is int` is unreachable. if type(x) is A: + reveal_type(x) # revealed: Never + # And the condition here is always `True` and has no effect on the narrowing of `x`. 
+ elif type(x) is int: reveal_type(x) # revealed: A | B else: - reveal_type(x) # revealed: A | B + reveal_type(x) # revealed: Never ``` ## No narrowing for multiple arguments diff --git a/crates/ty_python_semantic/src/semantic_index/scope.rs b/crates/ty_python_semantic/src/semantic_index/scope.rs index 62c7ff2a97b14..3bb15e0cbc3d5 100644 --- a/crates/ty_python_semantic/src/semantic_index/scope.rs +++ b/crates/ty_python_semantic/src/semantic_index/scope.rs @@ -1,6 +1,9 @@ use std::ops::Range; -use ruff_db::{files::File, parsed::ParsedModuleRef}; +use ruff_db::{ + files::File, + parsed::{ParsedModuleRef, parsed_module}, +}; use ruff_index::newtype_index; use ruff_python_ast as ast; @@ -26,6 +29,10 @@ pub struct ScopeId<'db> { impl get_size2::GetSize for ScopeId<'_> {} impl<'db> ScopeId<'db> { + pub(crate) fn is_non_lambda_function(self, db: &'db dyn Db) -> bool { + self.node(db).scope_kind().is_non_lambda_function() + } + pub(crate) fn is_annotation(self, db: &'db dyn Db) -> bool { self.node(db).scope_kind().is_annotation() } @@ -63,6 +70,18 @@ impl<'db> ScopeId<'db> { NodeWithScopeKind::GeneratorExpression(_) => "", } } + + pub(crate) fn is_coroutine_function(self, db: &'db dyn Db) -> bool { + let module = parsed_module(db, self.file(db)).load(db); + self.node(db) + .as_function() + .is_some_and(|func| func.node(&module).is_async && !self.is_generator_function(db)) + } + + pub(crate) fn is_generator_function(self, db: &'db dyn Db) -> bool { + let index = semantic_index(db, self.file(db)); + self.file_scope_id(db).is_generator_function(index) + } } /// ID that uniquely identifies a scope inside of a module. 
diff --git a/crates/ty_python_semantic/src/semantic_index/use_def.rs b/crates/ty_python_semantic/src/semantic_index/use_def.rs index 39f3a1a8ecfad..104cf0949be47 100644 --- a/crates/ty_python_semantic/src/semantic_index/use_def.rs +++ b/crates/ty_python_semantic/src/semantic_index/use_def.rs @@ -590,7 +590,6 @@ impl<'db> UseDefMap<'db> { .map(|symbol_id| (symbol_id, self.end_of_scope_symbol_bindings(symbol_id))) } - /// This function is intended to be called only once inside `TypeInferenceBuilder::infer_function_body`. pub(crate) fn can_implicitly_return_none(&self, db: &dyn crate::Db) -> bool { !self .reachability_constraints diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index b2cb2d61f1c00..befc3f1739ffa 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -1,6 +1,5 @@ use infer::nearest_enclosing_class; use itertools::{Either, Itertools}; -use ruff_db::parsed::parsed_module; use std::borrow::Cow; @@ -13,6 +12,7 @@ use diagnostic::{ }; use ruff_db::diagnostic::{Annotation, Diagnostic, Span, SubDiagnostic, SubDiagnosticSeverity}; use ruff_db::files::File; +use ruff_db::parsed::parsed_module; use ruff_python_ast::name::Name; use ruff_python_ast::{self as ast, AnyNodeRef}; use ruff_text_size::{Ranged, TextRange}; @@ -49,7 +49,7 @@ pub use crate::types::display::DisplaySettings; use crate::types::display::TupleSpecialization; use crate::types::enums::{enum_metadata, is_single_member_enum}; use crate::types::function::{ - DataclassTransformerParams, FunctionSpans, FunctionType, KnownFunction, + DataclassTransformerParams, FunctionDecorators, FunctionSpans, FunctionType, KnownFunction, }; use crate::types::generics::{ GenericContext, PartialSpecialization, Specialization, bind_typevar, walk_generic_context, @@ -61,7 +61,6 @@ pub use crate::types::ide_support::{ definitions_for_keyword_argument, definitions_for_name, find_active_signature_from_details, 
inlay_hint_function_argument_details, }; -use crate::types::infer::infer_unpack_types; use crate::types::mro::{Mro, MroError, MroIterator}; pub(crate) use crate::types::narrow::infer_narrowing_constraint; use crate::types::signatures::{ParameterForm, walk_signature}; @@ -110,6 +109,25 @@ mod definition; #[cfg(test)] mod property_tests; +fn return_type_cycle_recover<'db>( + _db: &'db dyn Db, + _value: &Type<'db>, + _count: u32, + _self: BoundMethodType<'db>, +) -> salsa::CycleRecoveryAction> { + salsa::CycleRecoveryAction::Iterate +} + +fn return_type_cycle_initial<'db>(db: &'db dyn Db, method: BoundMethodType<'db>) -> Type<'db> { + Type::divergent( + method + .function(db) + .literal(db) + .last_definition(db) + .body_scope(db), + ) +} + pub fn check_types(db: &dyn Db, file: File) -> Vec { let _span = tracing::trace_span!("check_types", ?file).entered(); @@ -225,6 +243,41 @@ pub(crate) struct TryBool; pub(crate) type NormalizedVisitor<'db> = TypeTransformer<'db, Normalized>; pub(crate) struct Normalized; +/// A [`TypeTransformer`] that is used in `recursive_type_normalized` methods. +/// Calling [`Type::recursive_type_normalized`] will normalize the recursive type. +/// A recursive type here means a type that contains a `Divergent` type. +/// Normalizing recursive types allows recursive type inference for divergent functions to converge. 
+pub(crate) struct RecursiveTypeNormalizedVisitor<'db> { + transformer: TypeTransformer<'db, Normalized>, + div: Type<'db>, +} + +impl<'db> RecursiveTypeNormalizedVisitor<'db> { + fn new(div: Type<'db>) -> Self { + // TODO: Divergent only + debug_assert!(matches!( + div, + Type::Never | Type::Dynamic(DynamicType::Divergent(_)) + )); + Self { + transformer: NormalizedVisitor::default(), + div, + } + } + + fn visit(&self, item: Type<'db>, func: impl FnOnce() -> Type<'db>) -> Type<'db> { + self.transformer.visit(item, func) + } + + fn visit_no_shift(&self, item: Type<'db>, func: impl FnOnce() -> Type<'db>) -> Type<'db> { + self.transformer.visit_no_shift(item, func) + } + + fn level(&self) -> usize { + self.transformer.level() + } +} + /// How a generic type has been specialized. /// /// This matters only if there is at least one invariant type parameter. @@ -538,6 +591,20 @@ impl<'db> PropertyInstanceType<'db> { ) } + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + Self::new( + db, + self.getter(db) + .map(|ty| ty.recursive_type_normalized(db, visitor)), + self.setter(db) + .map(|ty| ty.recursive_type_normalized(db, visitor)), + ) + } + fn find_legacy_typevars_impl( self, db: &'db dyn Db, @@ -1271,6 +1338,99 @@ impl<'db> Type<'db> { } } + #[must_use] + pub(crate) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + if visitor.level() == 0 && self == visitor.div { + // int | Divergent = int | (int | (int | ...)) = int + return Type::Never; + } else if visitor.level() >= 1 && self.has_divergent_type(db, visitor.div) { + // G[G[Divergent]] = G[Divergent] + return visitor.div; + } + match self { + Type::Union(union) => { + // As explained above, `Divergent` in a union type does not mean true divergence, + // so we normalize the type while keeping the nesting level the same. 
+ visitor.visit_no_shift(self, || union.recursive_type_normalized(db, visitor)) + } + Type::Intersection(intersection) => visitor.visit(self, || { + Type::Intersection(intersection.recursive_type_normalized(db, visitor)) + }), + Type::Callable(callable) => visitor.visit(self, || { + Type::Callable(callable.recursive_type_normalized(db, visitor)) + }), + Type::ProtocolInstance(protocol) => visitor.visit(self, || { + Type::ProtocolInstance(protocol.recursive_type_normalized(db, visitor)) + }), + Type::NominalInstance(instance) => visitor.visit(self, || { + Type::NominalInstance(instance.recursive_type_normalized(db, visitor)) + }), + Type::FunctionLiteral(function) => visitor.visit(self, || { + Type::FunctionLiteral(function.recursive_type_normalized(db, visitor)) + }), + Type::PropertyInstance(property) => visitor.visit(self, || { + Type::PropertyInstance(property.recursive_type_normalized(db, visitor)) + }), + Type::KnownBoundMethod(method_kind) => visitor.visit(self, || { + Type::KnownBoundMethod(method_kind.recursive_type_normalized(db, visitor)) + }), + Type::BoundMethod(method) => visitor.visit(self, || { + Type::BoundMethod(method.recursive_type_normalized(db, visitor)) + }), + Type::BoundSuper(bound_super) => visitor.visit(self, || { + Type::BoundSuper(bound_super.recursive_type_normalized(db, visitor)) + }), + Type::GenericAlias(generic) => visitor.visit(self, || { + Type::GenericAlias(generic.recursive_type_normalized(db, visitor)) + }), + Type::SubclassOf(subclass_of) => visitor.visit(self, || { + Type::SubclassOf(subclass_of.recursive_type_normalized(db, visitor)) + }), + Type::TypeVar(bound_typevar) => visitor.visit(self, || { + Type::TypeVar(bound_typevar.recursive_type_normalized(db, visitor)) + }), + Type::NonInferableTypeVar(bound_typevar) => visitor.visit(self, || { + Type::NonInferableTypeVar(bound_typevar.recursive_type_normalized(db, visitor)) + }), + Type::KnownInstance(known_instance) => visitor.visit(self, || { + 
Type::KnownInstance(known_instance.recursive_type_normalized(db, visitor)) + }), + Type::TypeIs(type_is) => visitor.visit(self, || { + type_is.with_type( + db, + type_is + .return_type(db) + .recursive_type_normalized(db, visitor), + ) + }), + Type::Dynamic(dynamic) => Type::Dynamic(dynamic.recursive_type_normalized()), + Type::TypedDict(_) => { + // TODO: Normalize TypedDicts + self + } + Type::TypeAlias(_) => self, + Type::LiteralString + | Type::AlwaysFalsy + | Type::AlwaysTruthy + | Type::BooleanLiteral(_) + | Type::BytesLiteral(_) + | Type::EnumLiteral(_) + | Type::StringLiteral(_) + | Type::Never + | Type::WrapperDescriptor(_) + | Type::DataclassDecorator(_) + | Type::DataclassTransformer(_) + | Type::ModuleLiteral(_) + | Type::ClassLiteral(_) + | Type::SpecialForm(_) + | Type::IntLiteral(_) => self, + } + } + /// Return `true` if subtyping is always reflexive for this type; `T <: T` is always true for /// any `T` of this type. /// @@ -4844,6 +5004,19 @@ impl<'db> Type<'db> { } } + /// Returns the inferred return type of `self` if it is a function literal / bound method. + fn infer_return_type(self, db: &'db dyn Db) -> Option> { + match self { + Type::FunctionLiteral(function_type) if !function_type.file(db).is_stub(db) => { + Some(function_type.infer_return_type(db)) + } + Type::BoundMethod(method_type) if !method_type.function(db).file(db).is_stub(db) => { + Some(method_type.infer_return_type(db)) + } + _ => None, + } + } + /// Calls `self`. Returns a [`CallError`] if `self` is (always or possibly) not callable, or if /// the arguments are not compatible with the formal parameters. 
/// @@ -5006,9 +5179,7 @@ impl<'db> Type<'db> { let special_case = match self { Type::NominalInstance(nominal) => nominal.tuple_spec(db), Type::GenericAlias(alias) if alias.origin(db).is_tuple(db) => { - Some(Cow::Owned(TupleSpec::homogeneous(todo_type!( - "*tuple[] annotations" - )))) + Some(Cow::Owned(TupleSpec::homogeneous(todo_type!("*tuple[] annotations")))) } Type::StringLiteral(string_literal_ty) => { let string_literal = string_literal_ty.value(db); @@ -5941,7 +6112,6 @@ impl<'db> Type<'db> { .unwrap_or(SubclassOfInner::unknown()), ), }, - Type::StringLiteral(_) | Type::LiteralString => KnownClass::Str.to_class_literal(db), Type::Dynamic(dynamic) => SubclassOfType::from(db, SubclassOfInner::Dynamic(dynamic)), // TODO intersections @@ -6853,6 +7023,37 @@ impl<'db> TypeMapping<'_, 'db> { } } + fn recursive_type_normalized( + &self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + match self { + TypeMapping::Specialization(specialization) => { + TypeMapping::Specialization(specialization.recursive_type_normalized(db, visitor)) + } + TypeMapping::PartialSpecialization(partial) => { + TypeMapping::PartialSpecialization(partial.recursive_type_normalized(db, visitor)) + } + TypeMapping::PromoteLiterals => TypeMapping::PromoteLiterals, + TypeMapping::BindLegacyTypevars(binding_context) => { + TypeMapping::BindLegacyTypevars(*binding_context) + } + TypeMapping::BindSelf(self_type) => { + TypeMapping::BindSelf(self_type.recursive_type_normalized(db, visitor)) + } + TypeMapping::ReplaceSelf { new_upper_bound } => TypeMapping::ReplaceSelf { + new_upper_bound: new_upper_bound.recursive_type_normalized(db, visitor), + }, + TypeMapping::MarkTypeVarsInferable(binding_context) => { + TypeMapping::MarkTypeVarsInferable(*binding_context) + } + TypeMapping::Materialize(materialization_kind) => { + TypeMapping::Materialize(*materialization_kind) + } + } + } + /// Update the generic context of a [`Signature`] according to the current type 
mapping pub(crate) fn update_signature_generic_context( &self, @@ -7004,6 +7205,34 @@ impl<'db> KnownInstanceType<'db> { } } + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + match self { + Self::SubscriptedProtocol(context) => { + Self::SubscriptedProtocol(context.recursive_type_normalized(db, visitor)) + } + Self::SubscriptedGeneric(context) => { + Self::SubscriptedGeneric(context.recursive_type_normalized(db, visitor)) + } + Self::TypeVar(typevar) => Self::TypeVar(typevar.recursive_type_normalized(db, visitor)), + Self::TypeAliasType(type_alias) => { + Self::TypeAliasType(type_alias.recursive_type_normalized(db, visitor)) + } + Self::Deprecated(deprecated) => { + // Nothing to normalize + Self::Deprecated(deprecated) + } + Self::Field(field) => Self::Field(field.recursive_type_normalized(db, visitor)), + Self::ConstraintSet(set) => { + // Nothing to normalize + Self::ConstraintSet(set) + } + } + } + fn class(self, db: &'db dyn Db) -> KnownClass { match self { Self::SubscriptedProtocol(_) | Self::SubscriptedGeneric(_) => KnownClass::SpecialForm, @@ -7146,6 +7375,10 @@ impl DynamicType<'_> { Self::Any } } + + fn recursive_type_normalized(self) -> Self { + self + } } impl std::fmt::Display for DynamicType<'_> { @@ -7480,6 +7713,19 @@ impl<'db> FieldInstance<'db> { self.kw_only(db), ) } + + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + FieldInstance::new( + db, + self.default_type(db).recursive_type_normalized(db, visitor), + self.init(db), + self.kw_only(db), + ) + } } /// Whether this typevar was created via the legacy `TypeVar` constructor, using PEP 695 syntax, @@ -7666,6 +7912,14 @@ impl<'db> TypeVarInstance<'db> { ) } + fn recursive_type_normalized( + self, + _db: &'db dyn Db, + _visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + self + } + fn materialize_impl( self, db: &'db dyn Db, @@ -7964,6 
+8218,18 @@ impl<'db> BoundTypeVarInstance<'db> { ) } + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + Self::new( + db, + self.typevar(db).recursive_type_normalized(db, visitor), + self.binding_context(db), + ) + } + fn materialize_impl( self, db: &'db dyn Db, @@ -9245,6 +9511,77 @@ impl<'db> BoundMethodType<'db> { ) } + /// Infers this method scope's types and returns the inferred return type. + #[salsa::tracked(cycle_fn=return_type_cycle_recover, cycle_initial=return_type_cycle_initial, heap_size=get_size2::heap_size)] + pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> { + let scope = self + .function(db) + .literal(db) + .last_definition(db) + .body_scope(db); + let inference = infer_scope_types(db, scope); + inference.infer_return_type(db, Type::BoundMethod(self)) + } + + #[salsa::tracked(heap_size=ruff_memory_usage::heap_size)] + fn class_definition(self, db: &'db dyn Db) -> Option> { + let definition_scope = self.function(db).definition(db).scope(db); + let index = semantic_index(db, definition_scope.file(db)); + Some(index.expect_single_definition(definition_scope.node(db).as_class()?)) + } + + pub(crate) fn is_final(self, db: &'db dyn Db) -> bool { + if self + .function(db) + .has_known_decorator(db, FunctionDecorators::FINAL) + { + return true; + } + let Some(class_ty) = self + .class_definition(db) + .and_then(|class| binding_type(db, class).into_class_literal()) + else { + return false; + }; + class_ty + .known_function_decorators(db) + .any(|deco| deco == KnownFunction::Final) + } + + pub(super) fn base_return_type(self, db: &'db dyn Db) -> Option> { + let class = binding_type(db, self.class_definition(db)?).to_class_type(db)?; + let name = self.function(db).name(db); + + let base = class + .iter_mro(db) + .nth(1) + .and_then(class_base::ClassBase::into_class)?; + let base_member = base.class_member(db, name, MemberLookupPolicy::default()); + if let 
Place::Type(Type::FunctionLiteral(base_func), _) = base_member.place { + if let [signature] = base_func.signature(db).overloads.as_slice() { + let unspecialized_return_ty = signature.return_ty.unwrap_or_else(|| { + let base_method_ty = + base_func.into_bound_method_type(db, Type::instance(db, class)); + base_method_ty.infer_return_type(db) + }); + if let Some(generic_context) = signature.generic_context.as_ref() { + // If the return type of the base method contains a type variable, replace it with `Unknown` to avoid dangling type variables. + Some( + unspecialized_return_ty + .apply_specialization(db, generic_context.unknown_specialization(db)), + ) + } else { + Some(unspecialized_return_ty) + } + } else { + // TODO: Handle overloaded base methods. + None + } + } else { + None + } + } + fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self { Self::new( db, @@ -9253,6 +9590,19 @@ impl<'db> BoundMethodType<'db> { ) } + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + Self::new( + db, + self.function(db).recursive_type_normalized(db, visitor), + self.self_instance(db) + .recursive_type_normalized(db, visitor), + ) + } + fn has_relation_to_impl( self, db: &'db dyn Db, @@ -9380,6 +9730,18 @@ impl<'db> CallableType<'db> { ) } + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + CallableType::new( + db, + self.signatures(db).recursive_type_normalized(db, visitor), + self.is_function_like(db), + ) + } + fn apply_type_mapping_impl<'a>( self, db: &'db dyn Db, @@ -9615,6 +9977,36 @@ impl<'db> KnownBoundMethodType<'db> { } } + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + match self { + KnownBoundMethodType::FunctionTypeDunderGet(function) => { + KnownBoundMethodType::FunctionTypeDunderGet( + function.recursive_type_normalized(db, 
visitor), + ) + } + KnownBoundMethodType::FunctionTypeDunderCall(function) => { + KnownBoundMethodType::FunctionTypeDunderCall( + function.recursive_type_normalized(db, visitor), + ) + } + KnownBoundMethodType::PropertyDunderGet(property) => { + KnownBoundMethodType::PropertyDunderGet( + property.recursive_type_normalized(db, visitor), + ) + } + KnownBoundMethodType::PropertyDunderSet(property) => { + KnownBoundMethodType::PropertyDunderSet( + property.recursive_type_normalized(db, visitor), + ) + } + KnownBoundMethodType::StrStartswith(_) | KnownBoundMethodType::PathOpen => self, + } + } + /// Return the [`KnownClass`] that inhabitants of this type are instances of at runtime fn class(self) -> KnownClass { match self { @@ -10022,6 +10414,14 @@ impl<'db> PEP695TypeAliasType<'db> { fn normalized_impl(self, _db: &'db dyn Db, _visitor: &NormalizedVisitor<'db>) -> Self { self } + + fn recursive_type_normalized( + self, + _db: &'db dyn Db, + _visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + self + } } #[allow(clippy::ref_option, clippy::trivially_copy_pass_by_ref)] @@ -10088,6 +10488,19 @@ impl<'db> ManualPEP695TypeAliasType<'db> { self.value(db).normalized_impl(db, visitor), ) } + + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + Self::new( + db, + self.name(db), + self.definition(db), + self.value(db).recursive_type_normalized(db, visitor), + ) + } } #[derive( @@ -10130,6 +10543,21 @@ impl<'db> TypeAliasType<'db> { } } + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + match self { + TypeAliasType::PEP695(type_alias) => { + TypeAliasType::PEP695(type_alias.recursive_type_normalized(db, visitor)) + } + TypeAliasType::ManualPEP695(type_alias) => { + TypeAliasType::ManualPEP695(type_alias.recursive_type_normalized(db, visitor)) + } + } + } + pub(crate) fn name(self, db: &'db dyn Db) -> &'db str { match self { 
TypeAliasType::PEP695(type_alias) => type_alias.name(db), @@ -10411,6 +10839,23 @@ impl<'db> UnionType<'db> { .build() } + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Type<'db> { + self.elements(db) + .iter() + .map(|ty| ty.recursive_type_normalized(db, visitor)) + .fold( + UnionBuilder::new(db) + .order_elements(false) + .unpack_aliases(false), + UnionBuilder::add, + ) + .build() + } + pub(crate) fn is_equivalent_to_impl( self, db: &'db dyn Db, @@ -10511,6 +10956,29 @@ impl<'db> IntersectionType<'db> { ) } + pub(crate) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + fn normalized_set<'db>( + db: &'db dyn Db, + elements: &FxOrderSet>, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> FxOrderSet> { + elements + .iter() + .map(|ty| ty.recursive_type_normalized(db, visitor)) + .collect() + } + + IntersectionType::new( + db, + normalized_set(db, self.positive(db), visitor), + normalized_set(db, self.negative(db), visitor), + ) + } + /// Return `true` if `self` represents exactly the same set of possible runtime objects as `other` pub(crate) fn is_equivalent_to_impl( self, @@ -10788,6 +11256,24 @@ impl<'db> SuperOwnerKind<'db> { } } + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + match self { + SuperOwnerKind::Dynamic(dynamic) => { + SuperOwnerKind::Dynamic(dynamic.recursive_type_normalized()) + } + SuperOwnerKind::Class(class) => { + SuperOwnerKind::Class(class.recursive_type_normalized(db, visitor)) + } + SuperOwnerKind::Instance(instance) => { + SuperOwnerKind::Instance(instance.recursive_type_normalized(db, visitor)) + } + } + } + fn iter_mro(self, db: &'db dyn Db) -> impl Iterator> { match self { SuperOwnerKind::Dynamic(dynamic) => { @@ -11062,6 +11548,18 @@ impl<'db> BoundSuperType<'db> { self.owner(db).normalized_impl(db, visitor), ) } + + fn 
recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + Self::new( + db, + self.pivot_class(db).recursive_type_normalized(db, visitor), + self.owner(db).recursive_type_normalized(db, visitor), + ) + } } #[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)] @@ -11256,8 +11754,7 @@ pub(crate) mod tests { let file = system_path_to_file(&db, "src/foo.py").unwrap(); let file_scope_id = FileScopeId::global(); let scope = file_scope_id.to_scope_id(&db, file); - - let div = Type::Dynamic(DynamicType::Divergent(DivergentType { scope })); + let div = Type::divergent(scope); // The `Divergent` type must not be eliminated in union with other dynamic types, // as this would prevent detection of divergent type inference using `Divergent`. @@ -11283,6 +11780,27 @@ pub(crate) mod tests { let union = UnionType::from_elements(&db, [KnownClass::Object.to_instance(&db), div]); assert_eq!(union.display(&db).to_string(), "object"); + let visitor = RecursiveTypeNormalizedVisitor::new(div); + let recursice = UnionType::from_elements( + &db, + [ + KnownClass::List.to_specialized_instance(&db, [div]), + Type::none(&db), + ], + ); + let nested_rec = KnownClass::List.to_specialized_instance(&db, [recursice]); + assert_eq!( + nested_rec.display(&db).to_string(), + "list[list[Divergent] | None]" + ); + let normalized = nested_rec.recursive_type_normalized(&db, &visitor); + assert_eq!(normalized.display(&db).to_string(), "list[Divergent]"); + + let union = UnionType::from_elements(&db, [div, KnownClass::Int.to_instance(&db)]); + assert_eq!(union.display(&db).to_string(), "Divergent | int"); + let normalized = union.recursive_type_normalized(&db, &visitor); + assert_eq!(normalized.display(&db).to_string(), "int"); + // The same can be said about intersections for the `Never` type. 
let intersection = IntersectionBuilder::new(&db) .add_positive(Type::Never) diff --git a/crates/ty_python_semantic/src/types/builder.rs b/crates/ty_python_semantic/src/types/builder.rs index cc1c35790a90b..a8e24bc57f08d 100644 --- a/crates/ty_python_semantic/src/types/builder.rs +++ b/crates/ty_python_semantic/src/types/builder.rs @@ -206,7 +206,9 @@ enum ReduceResult<'db> { // // For now (until we solve https://github.com/astral-sh/ty/issues/957), keep this number // below 200, which is the salsa fixpoint iteration limit. -const MAX_UNION_LITERALS: usize = 199; +// +// If we can handle fixed-point iterations properly, we should be able to reset the limit to 199. +const MAX_UNION_LITERALS: usize = 188; pub(crate) struct UnionBuilder<'db> { elements: Vec>, diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index f5b2e4e349306..7ad25d95fbd10 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -2720,10 +2720,15 @@ impl<'db> Binding<'db> { } } } + for (keywords_index, keywords_type) in keywords_arguments { matcher.match_keyword_variadic(db, keywords_index, keywords_type); } - self.return_ty = self.signature.return_ty.unwrap_or(Type::unknown()); + self.return_ty = self.signature.return_ty.unwrap_or_else(|| { + self.callable_type + .infer_return_type(db) + .unwrap_or(Type::unknown()) + }); self.parameter_tys = vec![None; parameters.len()].into_boxed_slice(); self.argument_matches = matcher.finish(); } diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 0afb14480a384..bc107019e2ae1 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -4,7 +4,7 @@ use super::TypeVarVariance; use super::{ BoundTypeVarInstance, IntersectionBuilder, MemberLookupPolicy, Mro, MroError, MroIterator, SpecialFormType, SubclassOfType, Truthiness, Type, 
TypeQualifiers, class_base::ClassBase, - function::FunctionType, infer_expression_type, infer_unpack_types, + function::FunctionType, }; use crate::FxOrderMap; use crate::module_resolver::KnownModule; @@ -20,7 +20,7 @@ use crate::types::diagnostic::{INVALID_LEGACY_TYPE_VARIABLE, INVALID_TYPE_ALIAS_ use crate::types::enums::enum_metadata; use crate::types::function::{DataclassTransformerParams, KnownFunction}; use crate::types::generics::{GenericContext, Specialization, walk_specialization}; -use crate::types::infer::nearest_enclosing_class; +use crate::types::infer::{infer_expression_type, infer_unpack_types, nearest_enclosing_class}; use crate::types::signatures::{CallableSignature, Parameter, Parameters, Signature}; use crate::types::tuple::{TupleSpec, TupleType}; use crate::types::typed_dict::typed_dict_params_from_class_def; @@ -28,10 +28,10 @@ use crate::types::{ ApplyTypeMappingVisitor, Binding, BoundSuperError, BoundSuperType, CallableType, DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, MaterializationKind, - NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeContext, - TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, - TypedDictParams, UnionBuilder, VarianceInferable, declaration_type, determine_upper_bound, - infer_definition_types, + NormalizedVisitor, PropertyInstanceType, RecursiveTypeNormalizedVisitor, StringLiteralType, + TypeAliasType, TypeContext, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, + TypeVarInstance, TypeVarKind, TypedDictParams, UnionBuilder, VarianceInferable, + declaration_type, determine_upper_bound, infer_definition_types, }; use crate::{ Db, FxIndexMap, FxOrderSet, Program, @@ -269,6 +269,19 @@ impl<'db> GenericAlias<'db> { ) } + pub(super) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { 
+ Self::new( + db, + self.origin(db), + self.specialization(db) + .recursive_type_normalized(db, visitor), + ) + } + pub(crate) fn definition(self, db: &'db dyn Db) -> Definition<'db> { self.origin(db).definition(db) } @@ -394,6 +407,17 @@ impl<'db> ClassType<'db> { } } + pub(super) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + match self { + Self::NonGeneric(_) => self, + Self::Generic(generic) => Self::Generic(generic.recursive_type_normalized(db, visitor)), + } + } + pub(super) fn has_pep_695_type_params(self, db: &'db dyn Db) -> bool { match self { Self::NonGeneric(class) => class.has_pep_695_type_params(db), @@ -1208,7 +1232,6 @@ impl<'db> ClassType<'db> { } } -#[allow(clippy::trivially_copy_pass_by_ref)] fn into_callable_cycle_recover<'db>( _db: &'db dyn Db, _value: &Type<'db>, @@ -1218,8 +1241,8 @@ fn into_callable_cycle_recover<'db>( salsa::CycleRecoveryAction::Iterate } -fn into_callable_cycle_initial<'db>(_db: &'db dyn Db, _self: ClassType<'db>) -> Type<'db> { - Type::Never +fn into_callable_cycle_initial<'db>(db: &'db dyn Db, _self: ClassType<'db>) -> Type<'db> { + Type::Callable(CallableType::bottom(db)) } impl<'db> From> for ClassType<'db> { diff --git a/crates/ty_python_semantic/src/types/class_base.rs b/crates/ty_python_semantic/src/types/class_base.rs index 547ace5923150..5f82cae00a4c7 100644 --- a/crates/ty_python_semantic/src/types/class_base.rs +++ b/crates/ty_python_semantic/src/types/class_base.rs @@ -4,8 +4,8 @@ use crate::types::generics::Specialization; use crate::types::tuple::TupleType; use crate::types::{ ApplyTypeMappingVisitor, ClassLiteral, ClassType, DynamicType, KnownClass, KnownInstanceType, - MaterializationKind, MroError, MroIterator, NormalizedVisitor, SpecialFormType, Type, - TypeMapping, todo_type, + MaterializationKind, MroError, MroIterator, NormalizedVisitor, RecursiveTypeNormalizedVisitor, + SpecialFormType, Type, TypeMapping, todo_type, }; /// 
Enumeration of the possible kinds of types we allow in class bases. @@ -43,6 +43,18 @@ impl<'db> ClassBase<'db> { } } + pub(super) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + match self { + Self::Dynamic(dynamic) => Self::Dynamic(dynamic.recursive_type_normalized()), + Self::Class(class) => Self::Class(class.recursive_type_normalized(db, visitor)), + Self::Protocol | Self::Generic | Self::TypedDict => self, + } + } + pub(crate) fn name(self, db: &'db dyn Db) -> &'db str { match self { ClassBase::Class(class) => class.name(db), diff --git a/crates/ty_python_semantic/src/types/cyclic.rs b/crates/ty_python_semantic/src/types/cyclic.rs index e46d30f40f5e6..69f24eca0dcad 100644 --- a/crates/ty_python_semantic/src/types/cyclic.rs +++ b/crates/ty_python_semantic/src/types/cyclic.rs @@ -58,6 +58,10 @@ pub struct CycleDetector { /// sort-of defeat the point of a cache if we did!) cache: RefCell>, + /// The nesting level of the `visit` method. + /// This is necessary for normalizing recursive types. 
+ level: RefCell, + fallback: R, _tag: PhantomData, @@ -68,12 +72,13 @@ impl CycleDetector { CycleDetector { seen: RefCell::new(FxIndexSet::default()), cache: RefCell::new(FxHashMap::default()), + level: RefCell::new(0), fallback, _tag: PhantomData, } } - pub fn visit(&self, item: T, func: impl FnOnce() -> R) -> R { + fn visit_impl(&self, shift_level: bool, item: T, func: impl FnOnce() -> R) -> R { if let Some(val) = self.cache.borrow().get(&item) { return val.clone(); } @@ -83,12 +88,30 @@ impl CycleDetector { return self.fallback.clone(); } + if shift_level { + *self.level.borrow_mut() += 1; + } let ret = func(); + if shift_level { + *self.level.borrow_mut() -= 1; + } self.seen.borrow_mut().pop(); self.cache.borrow_mut().insert(item, ret.clone()); ret } + + pub fn visit(&self, item: T, func: impl FnOnce() -> R) -> R { + self.visit_impl(true, item, func) + } + + pub(crate) fn visit_no_shift(&self, item: T, func: impl FnOnce() -> R) -> R { + self.visit_impl(false, item, func) + } + + pub(crate) fn level(&self) -> usize { + *self.level.borrow() + } } impl Default for CycleDetector { diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index ed10c7b819436..ae2cfa377be8b 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -72,19 +72,34 @@ use crate::types::diagnostic::{ report_bad_argument_to_get_protocol_members, report_bad_argument_to_protocol_interface, report_runtime_check_against_non_runtime_checkable_protocol, }; -use crate::types::generics::{GenericContext, walk_generic_context}; +use crate::types::generics::GenericContext; +use crate::types::infer::infer_scope_types; use crate::types::narrow::ClassInfoConstraintFunction; use crate::types::signatures::{CallableSignature, Signature}; use crate::types::visitor::any_over_type; use crate::types::{ BoundMethodType, BoundTypeVarInstance, CallableType, ClassBase, ClassLiteral, ClassType, 
DeprecatedInstance, DynamicType, FindLegacyTypeVarsVisitor, HasRelationToVisitor, - IsEquivalentVisitor, KnownClass, KnownInstanceType, NormalizedVisitor, SpecialFormType, - TrackedConstraintSet, Truthiness, Type, TypeMapping, TypeRelation, UnionBuilder, all_members, - binding_type, todo_type, walk_type_mapping, + IsEquivalentVisitor, KnownClass, KnownInstanceType, NormalizedVisitor, + RecursiveTypeNormalizedVisitor, SpecialFormType, TrackedConstraintSet, Truthiness, Type, + TypeMapping, TypeRelation, UnionBuilder, all_members, binding_type, todo_type, + walk_generic_context, walk_type_mapping, }; use crate::{Db, FxOrderSet, ModuleName, resolve_module}; +fn return_type_cycle_recover<'db>( + _db: &'db dyn Db, + _value: &Type<'db>, + _count: u32, + _self: FunctionType<'db>, +) -> salsa::CycleRecoveryAction> { + salsa::CycleRecoveryAction::Iterate +} + +fn return_type_cycle_initial<'db>(db: &'db dyn Db, function: FunctionType<'db>) -> Type<'db> { + Type::divergent(function.literal(db).last_definition(db).body_scope(db)) +} + /// A collection of useful spans for annotating functions. 
/// /// This can be retrieved via `FunctionType::spans` or @@ -576,7 +591,7 @@ impl<'db> FunctionLiteral<'db> { self.last_definition(db).spans(db) } - #[salsa::tracked(returns(ref), heap_size=ruff_memory_usage::heap_size)] + #[salsa::tracked(returns(ref), cycle_fn=overloads_and_implementation_cycle_recover, cycle_initial=overloads_and_implementation_cycle_initial, heap_size=ruff_memory_usage::heap_size)] fn overloads_and_implementation( self, db: &'db dyn Db, @@ -682,6 +697,33 @@ impl<'db> FunctionLiteral<'db> { .map(|ctx| ctx.normalized_impl(db, visitor)); Self::new(db, self.last_definition(db), context) } + + fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + let context = self + .inherited_generic_context(db) + .map(|ctx| ctx.recursive_type_normalized(db, visitor)); + Self::new(db, self.last_definition(db), context) + } +} + +fn overloads_and_implementation_cycle_recover<'db>( + _db: &'db dyn Db, + _value: &(Box<[OverloadLiteral<'db>]>, Option>), + _count: u32, + _function: FunctionLiteral<'db>, +) -> salsa::CycleRecoveryAction<(Box<[OverloadLiteral<'db>]>, Option>)> { + salsa::CycleRecoveryAction::Iterate +} + +fn overloads_and_implementation_cycle_initial<'db>( + _db: &'db dyn Db, + _function: FunctionLiteral<'db>, +) -> (Box<[OverloadLiteral<'db>]>, Option>) { + (Box::new([]), None) } /// Represents a function type, which might be a non-generic function, or a specialization of a @@ -1024,6 +1066,31 @@ impl<'db> FunctionType<'db> { .collect(); Self::new(db, self.literal(db).normalized_impl(db, visitor), mappings) } + + pub(super) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + let mappings: Box<_> = self + .type_mappings(db) + .iter() + .map(|mapping| mapping.recursive_type_normalized(db, visitor)) + .collect(); + Self::new( + db, + self.literal(db).recursive_type_normalized(db, visitor), + mappings, + ) + } + + /// 
Infers this function scope's types and returns the inferred return type. + #[salsa::tracked(cycle_fn=return_type_cycle_recover, cycle_initial=return_type_cycle_initial, heap_size=get_size2::heap_size)] + pub(crate) fn infer_return_type(self, db: &'db dyn Db) -> Type<'db> { + let scope = self.literal(db).last_definition(db).body_scope(db); + let inference = infer_scope_types(db, scope); + inference.infer_return_type(db, Type::FunctionLiteral(self)) + } } /// Evaluate an `isinstance` call. Return `Truthiness::AlwaysTrue` if we can definitely infer that diff --git a/crates/ty_python_semantic/src/types/generics.rs b/crates/ty_python_semantic/src/types/generics.rs index 7650a4eac3e38..918ac2c758030 100644 --- a/crates/ty_python_semantic/src/types/generics.rs +++ b/crates/ty_python_semantic/src/types/generics.rs @@ -12,15 +12,15 @@ use crate::semantic_index::definition::Definition; use crate::semantic_index::scope::{FileScopeId, NodeWithScopeKind}; use crate::types::class::ClassType; use crate::types::class_base::ClassBase; -use crate::types::infer::infer_definition_types; use crate::types::instance::{Protocol, ProtocolInstanceType}; use crate::types::signatures::{Parameter, Parameters, Signature}; use crate::types::tuple::{TupleSpec, TupleType, walk_tuple_type}; use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, KnownInstanceType, MaterializationKind, NormalizedVisitor, - Type, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, - TypeVarVariance, UnionType, binding_type, declaration_type, + RecursiveTypeNormalizedVisitor, Type, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, + TypeVarInstance, TypeVarKind, TypeVarVariance, UnionType, binding_type, declaration_type, + infer_definition_types, }; use crate::{Db, FxOrderSet}; @@ -411,6 +411,19 @@ impl<'db> GenericContext<'db> { Self::from_typevar_instances(db, variables) } + pub(super) fn 
recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + let variables = self + .variables(db) + .iter() + .map(|bound_typevar| bound_typevar.recursive_type_normalized(db, visitor)); + + Self::from_typevar_instances(db, variables) + } + fn heap_size((variables,): &(FxOrderSet>,)) -> usize { ruff_memory_usage::order_set_heap_size(variables) } @@ -753,6 +766,31 @@ impl<'db> Specialization<'db> { ) } + pub(super) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + let types: Box<[_]> = self + .types(db) + .iter() + .map(|ty| ty.recursive_type_normalized(db, visitor)) + .collect(); + let tuple_inner = self + .tuple_inner(db) + .map(|tuple| tuple.recursive_type_normalized(db, visitor)); + let context = self + .generic_context(db) + .recursive_type_normalized(db, visitor); + Self::new( + db, + context, + types, + self.materialization_kind(db), + tuple_inner, + ) + } + pub(super) fn materialize_impl( self, db: &'db dyn Db, @@ -1012,6 +1050,24 @@ impl<'db> PartialSpecialization<'_, 'db> { types, } } + + pub(super) fn recursive_type_normalized( + &self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> PartialSpecialization<'db, 'db> { + let generic_context = self.generic_context.recursive_type_normalized(db, visitor); + let types: Cow<_> = self + .types + .iter() + .map(|ty| ty.recursive_type_normalized(db, visitor)) + .collect(); + + PartialSpecialization { + generic_context, + types, + } + } } /// Performs type inference between parameter annotations and argument types, producing a diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 22aeaba771557..486951aab1720 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -48,11 +48,14 @@ use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; use 
crate::semantic_index::definition::Definition; use crate::semantic_index::expression::Expression; use crate::semantic_index::scope::ScopeId; -use crate::semantic_index::{SemanticIndex, semantic_index}; +use crate::semantic_index::{SemanticIndex, semantic_index, use_def_map}; use crate::types::diagnostic::TypeCheckDiagnostics; use crate::types::generics::Specialization; use crate::types::unpacker::{UnpackResult, Unpacker}; -use crate::types::{ClassLiteral, KnownClass, Truthiness, Type, TypeAndQualifiers}; +use crate::types::{ + ClassLiteral, KnownClass, RecursiveTypeNormalizedVisitor, Truthiness, Type, TypeAndQualifiers, + UnionBuilder, +}; use crate::unpack::Unpack; use builder::TypeInferenceBuilder; @@ -83,10 +86,14 @@ pub(crate) fn infer_scope_types<'db>(db: &'db dyn Db, scope: ScopeId<'db>) -> Sc fn scope_cycle_recover<'db>( _db: &'db dyn Db, _value: &ScopeInference<'db>, - _count: u32, - _scope: ScopeId<'db>, + count: u32, + scope: ScopeId<'db>, ) -> salsa::CycleRecoveryAction> { - salsa::CycleRecoveryAction::Iterate + if count == ITERATIONS_BEFORE_FALLBACK { + salsa::CycleRecoveryAction::Fallback(ScopeInference::cycle_fallback(scope)) + } else { + salsa::CycleRecoveryAction::Iterate + } } fn scope_cycle_initial<'db>(_db: &'db dyn Db, scope: ScopeId<'db>) -> ScopeInference<'db> { @@ -190,6 +197,7 @@ pub(crate) fn infer_expression_types<'db>( infer_expression_types_impl(db, InferExpression::new(db, expression, tcx)) } +/// When using types in [`ExpressionInference`], you must use [`ExpressionInference::cycle_recovery`]. #[salsa::tracked(returns(ref), cycle_fn=expression_cycle_recover, cycle_initial=expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)] fn infer_expression_types_impl<'db>( db: &'db dyn Db, @@ -549,6 +557,8 @@ pub(crate) struct ScopeInference<'db> { /// The extra data that is only present for few inference regions.
extra: Option>>, + + scope: ScopeId<'db>, } #[derive(Debug, Eq, PartialEq, get_size2::GetSize, salsa::Update, Default)] @@ -558,13 +568,17 @@ struct ScopeInferenceExtra<'db> { /// The diagnostics for this region. diagnostics: TypeCheckDiagnostics, + + /// The returnees of this region (if this is a function body). + /// + /// These are stored in `Vec` to delay the creation of the union type as long as possible. + returnees: Vec>, } impl<'db> ScopeInference<'db> { fn cycle_initial(scope: ScopeId<'db>) -> Self { - let _ = scope; - Self { + scope, extra: Some(Box::new(ScopeInferenceExtra { cycle_recovery: Some(CycleRecovery::Initial), ..ScopeInferenceExtra::default() @@ -573,6 +587,17 @@ impl<'db> ScopeInference<'db> { } } + fn cycle_fallback(scope: ScopeId<'db>) -> Self { + Self { + scope, + extra: Some(Box::new(ScopeInferenceExtra { + cycle_recovery: Some(CycleRecovery::Divergent(scope)), + ..ScopeInferenceExtra::default() + })), + expressions: FxHashMap::default(), + } + } + pub(crate) fn diagnostics(&self) -> Option<&TypeCheckDiagnostics> { self.extra.as_deref().map(|extra| &extra.diagnostics) } @@ -597,6 +622,84 @@ impl<'db> ScopeInference<'db> { .as_ref() .and_then(|extra| extra.cycle_recovery.map(CycleRecovery::fallback_type)) } + + /// When using `ScopeInference` during type inference, + /// use this method to get the cycle recovery type so that divergent types are propagated. + pub(super) fn cycle_recovery(&self) -> Option> { + self.fallback_type() + } + + /// Returns the inferred return type of this function body (union of all possible return types). + /// This must only be called on the scope inference of a function body. + /// In the case of methods, the return type of the superclass method is further unioned. + /// If there is no superclass method and this method is not `final`, it will be unioned with `Unknown`.
+ pub(crate) fn infer_return_type(&self, db: &'db dyn Db, callee_ty: Type<'db>) -> Type<'db> { + // TODO: coroutine function type inference + // TODO: generator function type inference + if self.scope.is_coroutine_function(db) || self.scope.is_generator_function(db) { + return Type::unknown(); + } + + let mut union = UnionBuilder::new(db); + let div = Type::divergent(self.scope); + if let Some(cycle_recovery) = self.cycle_recovery() { + union = union.add(cycle_recovery); + } + let visitor = RecursiveTypeNormalizedVisitor::new(div); + // Here, we use the dynamic type `Divergent` to detect divergent type inference and ensure that we obtain finite results. + // For example, consider the following recursive function: + // ```py + // def div(n: int): + // if n == 0: + // return None + // else: + // return (div(n-1),) + // ``` + // If we try to infer the return type of this function naively, we will get `tuple[tuple[tuple[...] | None] | None] | None`, which never converges. + // So, when we detect a cycle, we set the cycle initial type to `Divergent`. Then the type obtained in the first cycle is `tuple[Divergent] | None`. + // Let's call such a type containing `Divergent` a "recursive type". + // Next, if there is a type containing a recursive type (let's call this a nested recursive type), we replace the inner recursive type with the `Divergent` type. + // All recursive types are flattened in the next cycle, resulting in a convergence of the return type in finite cycles. + // 0th: Divergent + // 1st: tuple[Divergent] | None + // 2nd: tuple[tuple[Divergent] | None] | None => tuple[Divergent] | None + if let Some(previous_cycle_value) = callee_ty.infer_return_type(db) { + // In fixed-point iteration of return type inference, the return type must be monotonically widened and not "oscillate". + // Here, monotonicity is guaranteed by pre-unioning the type of the previous iteration into the current result. 
+ union = union.add(previous_cycle_value.recursive_type_normalized(db, &visitor)); + } + + let Some(extra) = &self.extra else { + unreachable!( + "infer_return_type should only be called on a function body scope inference" + ); + }; + for returnee in &extra.returnees { + let ty = returnee.map_or(Type::none(db), |expression| { + self.expression_type(expression) + }); + union = union.add(ty.recursive_type_normalized(db, &visitor)); + } + let use_def = use_def_map(db, self.scope); + if use_def.can_implicitly_return_none(db) { + union = union.add(Type::none(db)); + } + if let Type::BoundMethod(method_ty) = callee_ty { + // If the method is not final and the typing is implicit, the inferred return type should be unioned with `Unknown`. + // If any method in a base class does not have an annotated return type, `base_return_type` will include `Unknown`. + // On the other hand, if the return types of all methods in the base classes are annotated, there is no need to include `Unknown`. + if !method_ty.is_final(db) { + union = union.add( + method_ty + .base_return_type(db) + .unwrap_or(Type::unknown()) + .recursive_type_normalized(db, &visitor), + ); + } + } + + union.build() + } } /// The inferred types for a definition region. @@ -743,6 +846,13 @@ impl<'db> DefinitionInference<'db> { .and_then(|extra| extra.cycle_recovery.map(CycleRecovery::fallback_type)) } + /// When using `DefinitionInference` during type inference, + /// use this method to get the cycle recovery type so that divergent types are propagated. 
+ #[allow(unused)] + pub(super) fn cycle_recovery(&self) -> Option> { + self.fallback_type() + } + pub(crate) fn undecorated_type(&self) -> Option> { self.extra.as_ref().and_then(|extra| extra.undecorated_type) } @@ -831,6 +941,13 @@ impl<'db> ExpressionInference<'db> { .and_then(|extra| extra.cycle_recovery.map(CycleRecovery::fallback_type)) } + /// When using `ExpressionInference` during type inference, + /// use this method to get the cycle recovery type so that divergent types are propagated. + #[allow(unused)] + pub(super) fn cycle_recovery(&self) -> Option> { + self.fallback_type() + } + /// Returns true if all places in this expression are definitely bound. pub(crate) fn all_places_definitely_bound(&self) -> bool { self.extra diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 6f20e78272351..0694286e35ba3 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -11,10 +11,10 @@ use ruff_text_size::{Ranged, TextRange}; use rustc_hash::{FxHashMap, FxHashSet}; use super::{ - CycleRecovery, DefinitionInference, DefinitionInferenceExtra, ExpressionInference, - ExpressionInferenceExtra, InferenceRegion, ScopeInference, ScopeInferenceExtra, - infer_deferred_types, infer_definition_types, infer_expression_types, - infer_same_file_expression_type, infer_scope_types, infer_unpack_types, + DefinitionInference, DefinitionInferenceExtra, ExpressionInference, ExpressionInferenceExtra, + InferenceRegion, ScopeInference, ScopeInferenceExtra, infer_deferred_types, + infer_definition_types, infer_expression_types, infer_same_file_expression_type, + infer_unpack_types, }; use crate::module_name::{ModuleName, ModuleNameResolutionError}; use crate::module_resolver::{ @@ -75,11 +75,13 @@ use crate::types::diagnostic::{ use crate::types::function::{ FunctionDecorators, FunctionLiteral, FunctionType, KnownFunction, OverloadLiteral, }; 
-use crate::types::generics::{GenericContext, bind_typevar}; -use crate::types::generics::{LegacyGenericBase, SpecializationBuilder}; +use crate::types::generics::{ + GenericContext, LegacyGenericBase, SpecializationBuilder, bind_typevar, +}; +use crate::types::infer::CycleRecovery; use crate::types::instance::SliceLiteral; use crate::types::mro::MroErrorKind; -use crate::types::signatures::Signature; +use crate::types::signatures::{Parameter, Parameters, Signature}; use crate::types::subclass_of::SubclassOfInner; use crate::types::tuple::{Tuple, TupleLength, TupleSpec, TupleType}; use crate::types::typed_dict::{ @@ -90,11 +92,11 @@ use crate::types::visitor::any_over_type; use crate::types::{ BoundTypeVarInstance, CallDunderError, CallableType, ClassLiteral, ClassType, DataclassParams, DynamicType, IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, - MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, - Parameters, SpecialFormType, SubclassOfType, TrackedConstraintSet, Truthiness, Type, - TypeAliasType, TypeAndQualifiers, TypeContext, TypeMapping, TypeQualifiers, - TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, - UnionBuilder, UnionType, binding_type, todo_type, + MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, ParameterForm, SpecialFormType, + SubclassOfType, TrackedConstraintSet, Truthiness, Type, TypeAliasType, TypeAndQualifiers, + TypeContext, TypeMapping, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, + TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, + infer_scope_types, todo_type, }; use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic}; use crate::unpack::{EvaluationMode, UnpackPosition}; @@ -113,8 +115,8 @@ enum IntersectionOn { } #[derive(Debug, Clone, Copy, Eq, PartialEq)] -struct TypeAndRange<'db> { - ty: Type<'db>, +struct Returnee { + expression: Option, range: 
TextRange, } @@ -216,8 +218,10 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> { /// The list should only contain one entry per deferred. deferred: VecSet>, - /// The returned types and their corresponding ranges of the region, if it is a function body. - return_types_and_ranges: Vec>, + /// The returnees of this region (if this is a function body). + /// + /// These are stored in `Vec` to delay the creation of the union type as long as possible. + returnees: Vec, /// A set of functions that have been defined **and** called in this region. /// @@ -288,7 +292,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { context: InferContext::new(db, scope, module), index, region, - return_types_and_ranges: vec![], + returnees: vec![], called_functions: FxHashSet::default(), deferred_state: DeferredExpressionState::None, scope, @@ -410,11 +414,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } /// Get the already-inferred type of an expression node, or Unknown. - fn expression_type(&self, expr: &ast::Expr) -> Type<'db> { + fn expression_type(&self, expr: impl Into) -> Type<'db> { self.try_expression_type(expr).unwrap_or_else(Type::unknown) } - fn try_expression_type(&self, expr: &ast::Expr) -> Option> { + fn try_expression_type(&self, expr: impl Into) -> Option> { self.expressions .get(&expr.into()) .copied() @@ -1530,7 +1534,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } // In the following cases, the bound type may not be the same as the RHS value type. if let AnyNodeRef::ExprAttribute(ast::ExprAttribute { value, attr, .. }) = node { - let value_ty = self.try_expression_type(value).unwrap_or_else(|| { + let value_ty = self.try_expression_type(value.as_ref()).unwrap_or_else(|| { self.infer_maybe_standalone_expression(value, TypeContext::default()) }); // If the member is a data descriptor, the RHS value may differ from the value actually assigned. 
@@ -1544,7 +1548,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } else if let AnyNodeRef::ExprSubscript(ast::ExprSubscript { value, .. }) = node { let value_ty = self - .try_expression_type(value) + .try_expression_type(value.as_ref()) .unwrap_or_else(|| self.infer_expression(value, TypeContext::default())); if !value_ty.is_typed_dict() && !is_safe_mutable_class(db, value_ty) { @@ -1707,9 +1711,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ); } - fn record_return_type(&mut self, ty: Type<'db>, range: TextRange) { - self.return_types_and_ranges - .push(TypeAndRange { ty, range }); + fn record_returnee(&mut self, expression: Option, range: TextRange) { + self.returnees.push(Returnee { expression, range }); } fn infer_module(&mut self, module: &ast::ModModule) { @@ -1888,8 +1891,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } } - let has_empty_body = - self.return_types_and_ranges.is_empty() && is_stub_suite(&function.body); + let has_empty_body = self.returnees.is_empty() && is_stub_suite(&function.body); let mut enclosing_class_context = None; @@ -1945,28 +1947,34 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { return; } - for invalid in self - .return_types_and_ranges + for (invalid_ty, range) in self + .returnees .iter() .copied() - .filter_map(|ty_range| match ty_range.ty { - // We skip `is_assignable_to` checks for `NotImplemented`, - // so we remove it beforehand. - Type::Union(union) => Some(TypeAndRange { - ty: union.filter(self.db(), |ty| !ty.is_notimplemented(self.db())), - range: ty_range.range, - }), - ty if ty.is_notimplemented(self.db()) => None, - _ => Some(ty_range), + .filter_map(|returnee| { + match returnee + .expression + .map_or(Type::none(self.db()), |expression| { + self.expression_type(expression) + }) { + // We skip `is_assignable_to` checks for `NotImplemented`, + // so we remove it beforehand. 
+ Type::Union(union) => Some(( + union.filter(self.db(), |ty| !ty.is_notimplemented(self.db())), + returnee.range, + )), + ty if ty.is_notimplemented(self.db()) => None, + ty => Some((ty, returnee.range)), + } }) - .filter(|ty_range| !ty_range.ty.is_assignable_to(self.db(), expected_ty)) + .filter(|(ty, _)| !ty.is_assignable_to(self.db(), expected_ty)) { report_invalid_return_type( &self.context, - invalid.range, + range, returns.range(), declared_ty, - invalid.ty, + invalid_ty, ); } if self @@ -1975,7 +1983,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .can_implicitly_return_none(self.db()) && !Type::none(self.db()).is_assignable_to(self.db(), expected_ty) { - let no_return = self.return_types_and_ranges.is_empty(); + let no_return = self.returnees.is_empty(); report_implicit_return_type( &self.context, returns.range(), @@ -4754,17 +4762,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } fn infer_return_statement(&mut self, ret: &ast::StmtReturn) { - if let Some(ty) = - self.infer_optional_expression(ret.value.as_deref(), TypeContext::default()) - { - let range = ret - .value - .as_ref() - .map_or(ret.range(), |value| value.range()); - self.record_return_type(ty, range); - } else { - self.record_return_type(Type::none(self.db()), ret.range()); - } + self.infer_optional_expression(ret.value.as_deref(), TypeContext::default()); + let range = ret + .value + .as_ref() + .map_or(ret.range(), |value| value.range()); + let expression = ret + .value + .as_ref() + .map(|expr| ExpressionNodeKey::from(&**expr)); + self.record_returnee(expression, range); } fn infer_delete_statement(&mut self, delete: &ast::StmtDelete) { @@ -5844,7 +5851,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // Special handling for `TypedDict` method calls if let ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. 
}) = func.as_ref() { - let value_type = self.expression_type(value); + let value_type = self.expression_type(value.as_ref()); if let Type::TypedDict(typed_dict_ty) = value_type { if matches!(attr.id.as_str(), "pop" | "setdefault") && !arguments.args.is_empty() { // Validate the key argument for `TypedDict` methods @@ -9104,7 +9111,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { called_functions: _, index: _, region: _, - return_types_and_ranges: _, + returnees: _, } = self; let diagnostics = context.finish(); @@ -9166,7 +9173,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { called_functions: _, index: _, region: _, - return_types_and_ranges: _, + returnees: _, } = self; let _ = scope; @@ -9215,6 +9222,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { pub(super) fn finish_scope(mut self) -> ScopeInference<'db> { self.infer_region(); + let db = self.db(); let Self { context, @@ -9237,22 +9245,33 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { called_functions: _, index: _, region: _, - return_types_and_ranges: _, + returnees, } = self; - let _ = scope; let diagnostics = context.finish(); - let extra = (!diagnostics.is_empty() || cycle_recovery.is_some()).then(|| { + let extra = (!diagnostics.is_empty() + || cycle_recovery.is_some() + || scope.is_non_lambda_function(db)) + .then(|| { + let returnees = returnees + .into_iter() + .map(|returnee| returnee.expression) + .collect(); Box::new(ScopeInferenceExtra { cycle_recovery, diagnostics, + returnees, }) }); expressions.shrink_to_fit(); - ScopeInference { expressions, extra } + ScopeInference { + expressions, + extra, + scope, + } } } diff --git a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs index c0c0483c7ed85..e48d0582017a6 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs @@ -7,14 +7,14 @@ use 
crate::types::diagnostic::{ report_invalid_arguments_to_annotated, report_invalid_arguments_to_callable, }; use crate::types::enums::is_enum_class; -use crate::types::signatures::{CallableSignature, Signature}; +use crate::types::signatures::{CallableSignature, Parameter, Parameters, Signature}; use crate::types::string_annotation::parse_string_annotation; use crate::types::tuple::{TupleSpecBuilder, TupleType}; use crate::types::visitor::any_over_type; use crate::types::{ CallableType, DynamicType, IntersectionBuilder, KnownClass, KnownInstanceType, - LintDiagnosticGuard, Parameter, Parameters, SpecialFormType, SubclassOfType, Type, - TypeAliasType, TypeContext, TypeIsType, UnionBuilder, UnionType, todo_type, + LintDiagnosticGuard, SpecialFormType, SubclassOfType, Type, TypeAliasType, TypeContext, + TypeIsType, UnionBuilder, UnionType, todo_type, }; /// Type expressions @@ -550,7 +550,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { // we do not store types for sub-expressions. Re-infer the type here. 
builder.infer_expression(value, TypeContext::default()) } else { - builder.expression_type(value) + builder.expression_type(value.as_ref()) }; value_ty == Type::SpecialForm(SpecialFormType::Unpack) diff --git a/crates/ty_python_semantic/src/types/instance.rs b/crates/ty_python_semantic/src/types/instance.rs index 87649cb6f4b71..e634c1eac174f 100644 --- a/crates/ty_python_semantic/src/types/instance.rs +++ b/crates/ty_python_semantic/src/types/instance.rs @@ -14,8 +14,8 @@ use crate::types::protocol_class::walk_protocol_interface; use crate::types::tuple::{TupleSpec, TupleType}; use crate::types::{ ApplyTypeMappingVisitor, ClassBase, ClassLiteral, FindLegacyTypeVarsVisitor, - HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, TypeMapping, - TypeRelation, VarianceInferable, + HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, + RecursiveTypeNormalizedVisitor, TypeMapping, TypeRelation, VarianceInferable, }; use crate::{Db, FxOrderSet}; @@ -337,6 +337,22 @@ impl<'db> NominalInstanceType<'db> { } } + pub(super) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + match self.0 { + NominalInstanceInner::ExactTuple(tuple) => Self(NominalInstanceInner::ExactTuple( + tuple.recursive_type_normalized(db, visitor), + )), + NominalInstanceInner::NonTuple(class) => Self(NominalInstanceInner::NonTuple( + class.recursive_type_normalized(db, visitor), + )), + NominalInstanceInner::Object => Self(NominalInstanceInner::Object), + } + } + pub(super) fn has_relation_to_impl( self, db: &'db dyn Db, @@ -649,6 +665,14 @@ impl<'db> ProtocolInstanceType<'db> { } } + pub(super) fn recursive_type_normalized( + self, + _db: &'db dyn Db, + _visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + self + } + /// Return `true` if this protocol type is equivalent to the protocol `other`. 
/// /// TODO: consider the types of the members as well as their existence diff --git a/crates/ty_python_semantic/src/types/protocol_class.rs b/crates/ty_python_semantic/src/types/protocol_class.rs index b877db65920a3..99abcc9c61136 100644 --- a/crates/ty_python_semantic/src/types/protocol_class.rs +++ b/crates/ty_python_semantic/src/types/protocol_class.rs @@ -507,7 +507,7 @@ impl<'a, 'db> ProtocolMember<'a, 'db> { self.qualifiers } - fn ty(&self) -> Type<'db> { + pub(super) fn ty(&self) -> Type<'db> { match &self.kind { ProtocolMemberKind::Method(callable) => Type::Callable(*callable), ProtocolMemberKind::Property(property) => Type::PropertyInstance(*property), diff --git a/crates/ty_python_semantic/src/types/signatures.rs b/crates/ty_python_semantic/src/types/signatures.rs index ddc06951fc8b3..711ec9b23f3b9 100644 --- a/crates/ty_python_semantic/src/types/signatures.rs +++ b/crates/ty_python_semantic/src/types/signatures.rs @@ -22,7 +22,7 @@ use crate::types::generics::{GenericContext, walk_generic_context}; use crate::types::{ ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsEquivalentVisitor, KnownClass, MaterializationKind, NormalizedVisitor, - TypeMapping, TypeRelation, VarianceInferable, todo_type, + RecursiveTypeNormalizedVisitor, TypeMapping, TypeRelation, VarianceInferable, todo_type, }; use crate::{Db, FxOrderSet}; use ruff_python_ast::{self as ast, name::Name}; @@ -74,6 +74,18 @@ impl<'db> CallableSignature<'db> { ) } + pub(super) fn recursive_type_normalized( + &self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + Self::from_overloads( + self.overloads + .iter() + .map(|signature| signature.recursive_type_normalized(db, visitor)), + ) + } + pub(crate) fn apply_type_mapping_impl<'a>( &self, db: &'db dyn Db, @@ -451,6 +463,30 @@ impl<'db> Signature<'db> { } } + pub(super) fn recursive_type_normalized( + &self, + db: &'db dyn Db, + visitor: 
&RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + Self { + generic_context: self + .generic_context + .map(|ctx| ctx.recursive_type_normalized(db, visitor)), + inherited_generic_context: self + .inherited_generic_context + .map(|ctx| ctx.recursive_type_normalized(db, visitor)), + definition: self.definition, + parameters: self + .parameters + .iter() + .map(|param| param.recursive_type_normalized(db, visitor)) + .collect(), + return_ty: self + .return_ty + .map(|return_ty| return_ty.recursive_type_normalized(db, visitor)), + } + } + pub(crate) fn apply_type_mapping<'a>( &self, db: &'db dyn Db, @@ -1552,6 +1588,47 @@ impl<'db> Parameter<'db> { } } + pub(super) fn recursive_type_normalized( + &self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + let Parameter { + annotated_type, + kind, + form, + } = self; + + let annotated_type = annotated_type.map(|ty| ty.recursive_type_normalized(db, visitor)); + + let kind = match kind { + ParameterKind::PositionalOnly { name, default_type } => ParameterKind::PositionalOnly { + name: name.clone(), + default_type: default_type.map(|ty| ty.recursive_type_normalized(db, visitor)), + }, + ParameterKind::PositionalOrKeyword { name, default_type } => { + ParameterKind::PositionalOrKeyword { + name: name.clone(), + default_type: default_type.map(|ty| ty.recursive_type_normalized(db, visitor)), + } + } + ParameterKind::KeywordOnly { name, default_type } => ParameterKind::KeywordOnly { + name: name.clone(), + default_type: default_type.map(|ty| ty.recursive_type_normalized(db, visitor)), + }, + ParameterKind::Variadic { name } => ParameterKind::Variadic { name: name.clone() }, + ParameterKind::KeywordVariadic { name } => { + ParameterKind::KeywordVariadic { name: name.clone() } + } + }; + + Self { + annotated_type, + kind, + form: *form, + } + } + fn from_node_and_kind( db: &'db dyn Db, definition: Definition<'db>, diff --git a/crates/ty_python_semantic/src/types/subclass_of.rs 
b/crates/ty_python_semantic/src/types/subclass_of.rs index 1b0db0c808bdd..29e906ef0ff2d 100644 --- a/crates/ty_python_semantic/src/types/subclass_of.rs +++ b/crates/ty_python_semantic/src/types/subclass_of.rs @@ -5,8 +5,8 @@ use crate::types::variance::VarianceInferable; use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarInstance, ClassType, DynamicType, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsDisjointVisitor, KnownClass, - MaterializationKind, MemberLookupPolicy, NormalizedVisitor, SpecialFormType, Type, TypeMapping, - TypeRelation, + MaterializationKind, MemberLookupPolicy, NormalizedVisitor, RecursiveTypeNormalizedVisitor, + SpecialFormType, Type, TypeMapping, TypeRelation, }; use crate::{Db, FxOrderSet}; @@ -181,6 +181,16 @@ impl<'db> SubclassOfType<'db> { } } + pub(super) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + Self { + subclass_of: self.subclass_of.recursive_type_normalized(db, visitor), + } + } + pub(crate) fn to_instance(self, db: &'db dyn Db) -> Type<'db> { match self.subclass_of { SubclassOfInner::Class(class) => Type::instance(db, class), @@ -254,6 +264,17 @@ impl<'db> SubclassOfInner<'db> { } } + pub(super) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + match self { + Self::Class(class) => Self::Class(class.recursive_type_normalized(db, visitor)), + Self::Dynamic(dynamic) => Self::Dynamic(dynamic.recursive_type_normalized()), + } + } + pub(crate) fn try_from_type(db: &'db dyn Db, ty: Type<'db>) -> Option { match ty { Type::Dynamic(dynamic) => Some(Self::Dynamic(dynamic)), diff --git a/crates/ty_python_semantic/src/types/tuple.rs b/crates/ty_python_semantic/src/types/tuple.rs index 2f10e79949716..48094c005e9de 100644 --- a/crates/ty_python_semantic/src/types/tuple.rs +++ b/crates/ty_python_semantic/src/types/tuple.rs @@ -27,8 +27,8 @@ use crate::types::class::{ClassType, KnownClass}; 
use crate::types::constraints::{ConstraintSet, IteratorConstraintsExtension}; use crate::types::{ ApplyTypeMappingVisitor, BoundTypeVarInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, - IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, Type, TypeMapping, TypeRelation, - UnionBuilder, UnionType, + IsDisjointVisitor, IsEquivalentVisitor, NormalizedVisitor, RecursiveTypeNormalizedVisitor, + Type, TypeMapping, TypeRelation, UnionBuilder, UnionType, }; use crate::util::subscript::{Nth, OutOfBoundsError, PyIndex, PySlice, StepSizeZeroError}; use crate::{Db, FxOrderSet, Program}; @@ -228,6 +228,14 @@ impl<'db> TupleType<'db> { TupleType::new(db, &self.tuple(db).normalized_impl(db, visitor)) } + pub(super) fn recursive_type_normalized( + self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + Self::new_internal(db, self.tuple(db).recursive_type_normalized(db, visitor)) + } + pub(crate) fn apply_type_mapping_impl<'a>( self, db: &'db dyn Db, @@ -386,6 +394,18 @@ impl<'db> FixedLengthTuple> { Self::from_elements(self.0.iter().map(|ty| ty.normalized_impl(db, visitor))) } + fn recursive_type_normalized( + &self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + Self::from_elements( + self.0 + .iter() + .map(|ty| ty.recursive_type_normalized(db, visitor)), + ) + } + fn apply_type_mapping_impl<'a>( &self, db: &'db dyn Db, @@ -703,6 +723,29 @@ impl<'db> VariableLengthTuple> { }) } + fn recursive_type_normalized( + &self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + let prefix = self + .prefix + .iter() + .map(|ty| ty.recursive_type_normalized(db, visitor)) + .collect::>(); + let suffix = self + .suffix + .iter() + .map(|ty| ty.recursive_type_normalized(db, visitor)) + .collect::>(); + let variable = self.variable.recursive_type_normalized(db, visitor); + Self { + prefix, + variable, + suffix, + } + } + fn apply_type_mapping_impl<'a>( &self, db: &'db dyn Db, 
@@ -1044,6 +1087,17 @@ impl<'db> Tuple> { } } + pub(super) fn recursive_type_normalized( + &self, + db: &'db dyn Db, + visitor: &RecursiveTypeNormalizedVisitor<'db>, + ) -> Self { + match self { + Tuple::Fixed(tuple) => Tuple::Fixed(tuple.recursive_type_normalized(db, visitor)), + Tuple::Variable(tuple) => Tuple::Variable(tuple.recursive_type_normalized(db, visitor)), + } + } + pub(crate) fn apply_type_mapping_impl<'a>( &self, db: &'db dyn Db,