From a3abde4ef50e4862b435d6ce1e6bb605dd64ac4f Mon Sep 17 00:00:00 2001 From: David Peter Date: Tue, 15 Apr 2025 16:08:01 +0200 Subject: [PATCH] [red-knot] Dataclasses: Proper `__init__` signature --- .../resources/mdtest/dataclasses.md | 359 ++++++++++++++++-- .../src/semantic_index/use_def.rs | 9 + .../src/types/class.rs | 232 +++++++++-- 3 files changed, 539 insertions(+), 61 deletions(-) diff --git a/crates/red_knot_python_semantic/resources/mdtest/dataclasses.md b/crates/red_knot_python_semantic/resources/mdtest/dataclasses.md index 842bac0a0858d..000c60031758f 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/dataclasses.md +++ b/crates/red_knot_python_semantic/resources/mdtest/dataclasses.md @@ -40,19 +40,191 @@ The signature of the `__init__` method is generated based on the classes attribu calls are not valid: ```py -# TODO: should be an error: too few arguments +# error: [missing-argument] Person() -# TODO: should be an error: too many arguments +# error: [too-many-positional-arguments] Person("Eve", 20, "too many arguments") -# TODO: should be an error: wrong argument type +# error: [invalid-argument-type] Person("Eve", "string instead of int") -# TODO: should be an error: wrong argument types +# error: [invalid-argument-type] +# error: [invalid-argument-type] Person(20, "Eve") ``` +## Signature of `__init__` + +TODO: All of the following tests are missing the `self` argument in the `__init__` signature. + +Declarations in the class body are used to generate the signature of the `__init__` method. If the +attributes are not just declarations, but also bindings, the type inferred from bindings is used as +the default value. + +```py +from dataclasses import dataclass + +@dataclass +class D: + x: int + y: str = "default" + z: int | None = 1 + 2 + +reveal_type(D.__init__) # revealed: (x: int, y: str = Literal["default"], z: int | None = Literal[3]) -> None +``` + +This also works if the declaration and binding are split: + +```py +@dataclass +class D: + x: int | None + x = None + +reveal_type(D.__init__) # revealed: (x: int | None = None) -> None +``` + +Non-fully static types are handled correctly: + +```py +from typing import Any + +@dataclass +class C: + x: Any + y: int | Any + z: tuple[int, Any] + +reveal_type(C.__init__) # revealed: (x: Any, y: int | Any, z: tuple[int, Any]) -> None +``` + +Variables without annotations are ignored: + +```py +@dataclass +class D: + x: int + y = 1 + +reveal_type(D.__init__) # revealed: (x: int) -> None +``` + +If attributes without default values are declared after attributes with default values, a +`TypeError` will be raised at runtime. Ideally, we would emit a diagnostic in that case: + +```py +@dataclass +class D: + x: int = 1 + # TODO: this should be an error: field without default defined after field with default + y: str +``` + +Pure class attributes (`ClassVar`) are not included in the signature of `__init__`: + +```py +from typing import ClassVar + +@dataclass +class D: + x: int + y: ClassVar[str] = "default" + z: bool + +reveal_type(D.__init__) # revealed: (x: int, z: bool) -> None + +d = D(1, True) +reveal_type(d.x) # revealed: int +reveal_type(d.y) # revealed: str +reveal_type(d.z) # revealed: bool +``` + +Function declarations do not affect the signature of `__init__`: + +```py +@dataclass +class D: + x: int + + def y(self) -> str: + return "" + +reveal_type(D.__init__) # revealed: (x: int) -> None +``` + +And neither do nested class declarations: + +```py +@dataclass +class D: + x: int + + class Nested: + y: str + +reveal_type(D.__init__) # revealed: (x: int) -> None +``` + +But if there is a variable annotation with a function or class literal type, the signature of +`__init__` will include this field: + +```py +from knot_extensions import TypeOf + +class SomeClass: ... + +def some_function() -> None: ... +@dataclass +class D: + function_literal: TypeOf[some_function] + class_literal: TypeOf[SomeClass] + class_subtype_of: type[SomeClass] + +# revealed: (function_literal: def some_function() -> None, class_literal: Literal[SomeClass], class_subtype_of: type[SomeClass]) -> None +reveal_type(D.__init__) +``` + +More realistically, dataclasses can have `Callable` attributes: + +```py +from typing import Callable + +@dataclass +class D: + c: Callable[[int], str] + +reveal_type(D.__init__) # revealed: (c: (int, /) -> str) -> None +``` + +Implicit instance attributes do not affect the signature of `__init__`: + +```py +@dataclass +class D: + x: int + + def f(self, y: str) -> None: + self.y: str = y + +reveal_type(D(1).y) # revealed: str + +reveal_type(D.__init__) # revealed: (x: int) -> None +``` + +Annotating expressions does not lead to an entry in `__annotations__` at runtime, and so it wouldn't +be included in the signature of `__init__`. This is a case that we currently don't detect: + +```py +@dataclass +class D: + # (x) is an expression, not a "simple name" + (x): int = 1 + +# TODO: should ideally not include a `x` parameter +reveal_type(D.__init__) # revealed: (x: int = Literal[1]) -> None +``` + ## `@dataclass` calls with arguments The `@dataclass` decorator can take several arguments to customize the existence of the generated @@ -241,7 +413,8 @@ class Derived(Base): d = Derived("a") -# TODO: should be an error: +# error: [too-many-positional-arguments] +# error: [invalid-argument-type] Derived(1, "a") ``` @@ -253,18 +426,47 @@ from dataclasses import dataclass @dataclass class Base: x: int + y: str @dataclass class Derived(Base): - y: str + z: bool -d = Derived(1, "a") # OK +d = Derived(1, "a", True) # OK reveal_type(d.x) # revealed: int reveal_type(d.y) # revealed: str +reveal_type(d.z) # revealed: bool + +# error: [missing-argument] +Derived(1, "a") + +# error: [missing-argument] +Derived(True) +``` + +### Overwriting attributes from base class + +The following example comes from the +[Python documentation](https://docs.python.org/3/library/dataclasses.html#inheritance). The `x` +attribute appears just once in the `__init__` signature, and the default value is taken from the +derived class + +```py +from dataclasses import dataclass +from typing import Any + +@dataclass +class Base: + x: Any = 15.0 + y: int = 0 + +@dataclass +class C(Base): + z: int = 10 + x: int = 15 -# TODO: should be an error: -Derived("a") +reveal_type(C.__init__) # revealed: (x: int = Literal[15], y: int = Literal[0], z: int = Literal[10]) -> None ``` ## Generic dataclasses @@ -283,33 +485,124 @@ d_int = DataWithDescription[int](1, "description") # OK reveal_type(d_int.data) # revealed: int reveal_type(d_int.description) # revealed: str -# TODO: should be an error: wrong argument type +# error: [invalid-argument-type] DataWithDescription[int](None, "description") ``` ## Descriptor-typed fields +### Same type in `__get__` and `__set__` + +For the following descriptor, the return type of `__get__` and the type of the `value` parameter in +`__set__` are the same. The generated `__init__` method takes an argument of this type (instead of +the type of the descriptor), and the default value is also of this type: + ```py +from typing import overload from dataclasses import dataclass -class Descriptor: - _value: int = 0 +class UppercaseString: + _value: str = "" - def __get__(self, instance, owner) -> str: - return str(self._value) + def __get__(self, instance: object, owner: None | type) -> str: + return self._value - def __set__(self, instance, value: int) -> None: - self._value = value + def __set__(self, instance: object, value: str) -> None: + self._value = value.upper() @dataclass class C: - d: Descriptor = Descriptor() + upper: UppercaseString = UppercaseString() -c = C(1) -reveal_type(c.d) # revealed: str +reveal_type(C.__init__) # revealed: (upper: str = str) -> None -# TODO: should be an error -C("a") +c = C("abc") +reveal_type(c.upper) # revealed: str + +# This is also okay: +C() + +# error: [invalid-argument-type] +C(1) + +# error: [too-many-positional-arguments] +C("a", "b") +``` + +### Different types in `__get__` and `__set__` + +In general, the type of the `__init__` parameter is determined by the `value` parameter type of the +`__set__` method (`str` in the example below). However, the default value is generated by calling +the descriptor's `__get__` method as if it had been called on the class itself, i.e. passing `None` +for the `instance` argument. + +```py +from typing import overload +from dataclasses import dataclass + +class ConvertToLength: + _len: int = 0 + + @overload + def __get__(self, instance: None, owner: type) -> str: ... + @overload + def __get__(self, instance: object, owner: type | None) -> int: ... + def __get__(self, instance: object | None, owner: type | None) -> str | int: + if instance is None: + return "" + + return self._len + + def __set__(self, instance, value: str) -> None: + self._len = len(value) + +@dataclass +class C: + converter: ConvertToLength = ConvertToLength() + +# TODO: Should be `(converter: str = Literal[""]) -> None` once we understand overloads +reveal_type(C.__init__) # revealed: (converter: str = str | int) -> None + +c = C("abc") +# TODO: Should be `int` once we understand overloads +reveal_type(c.converter) # revealed: str | int + +# This is also okay: +C() + +# error: [invalid-argument-type] +C(1) + +# error: [too-many-positional-arguments] +C("a", "b") +``` + +### With overloaded `__set__` method + +If the `__set__` method is overloaded, we determine the type for the `__init__` parameter as the +union of all possible `value` parameter types: + +```py +from typing import overload +from dataclasses import dataclass + +class AcceptsStrAndInt: + def __get__(self, instance, owner) -> int: + return 0 + + @overload + def __set__(self, instance: object, value: str) -> None: ... + @overload + def __set__(self, instance: object, value: int) -> None: ... + def __set__(self, instance: object, value) -> None: + pass + +@dataclass +class C: + field: AcceptsStrAndInt = AcceptsStrAndInt() + +# TODO: Should be `field: str | int = int` once we understand overloads +reveal_type(C.__init__) # revealed: (field: Unknown = int) -> None ``` ## `dataclasses.field` @@ -329,8 +622,7 @@ import dataclasses class C: x: str -# TODO: should show the proper signature -reveal_type(C.__init__) # revealed: (*args: Any, **kwargs: Any) -> None +reveal_type(C.__init__) # revealed: (x: str) -> None ``` ### Dataclass with custom `__init__` method @@ -349,7 +641,7 @@ class C: C(1) # OK -# TODO: should be an error +# error: [invalid-argument-type] C("a") ``` @@ -365,9 +657,20 @@ D(1) # OK D() # error: [missing-argument] ``` -### Dataclass with `ClassVar`s +### Accessing instance attributes on the class itself -To do +Just like for normal classes, accessing instance attributes on the class itself is not allowed: + +```py +from dataclasses import dataclass + +@dataclass +class C: + x: int + +# error: [unresolved-attribute] "Attribute `x` can only be accessed on instances, not on the class object `Literal[C]` itself." +C.x +``` ### Return type of `dataclass(...)` @@ -412,8 +715,8 @@ reveal_type(Person.__mro__) # revealed: tuple[Literal[Person], Literal[object]] The generated methods have the following signatures: ```py -# TODO: proper signature -reveal_type(Person.__init__) # revealed: (*args: Any, **kwargs: Any) -> None +# TODO: `self` is missing here +reveal_type(Person.__init__) # revealed: (name: str, age: int | None = None) -> None reveal_type(Person.__repr__) # revealed: def __repr__(self) -> str diff --git a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs index a52c675cb3924..50b6d804d5809 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/use_def.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/use_def.rs @@ -429,6 +429,14 @@ impl<'db> UseDefMap<'db> { self.declarations_iterator(declarations) } + pub(crate) fn all_public_declarations<'map>( + &'map self, + ) -> impl Iterator)> + 'map { + (0..self.public_symbols.len()) + .map(ScopedSymbolId::from_usize) + .map(|symbol_id| (symbol_id, self.public_declarations(symbol_id))) + } + /// This function is intended to be called only once inside `TypeInferenceBuilder::infer_function_body`. pub(crate) fn can_implicit_return(&self, db: &dyn crate::Db) -> bool { !self @@ -551,6 +559,7 @@ impl<'db> Iterator for ConstraintsIterator<'_, 'db> { impl std::iter::FusedIterator for ConstraintsIterator<'_, '_> {} +#[derive(Clone)] pub(crate) struct DeclarationsIterator<'map, 'db> { all_definitions: &'map IndexVec>>, pub(crate) predicates: &'map Predicates<'db>, diff --git a/crates/red_knot_python_semantic/src/types/class.rs b/crates/red_knot_python_semantic/src/types/class.rs index 57c076595d54e..b690f6ea8c306 100644 --- a/crates/red_knot_python_semantic/src/types/class.rs +++ b/crates/red_knot_python_semantic/src/types/class.rs @@ -1,3 +1,4 @@ +use std::hash::BuildHasherDefault; use std::sync::{LazyLock, Mutex}; use super::{ @@ -6,6 +7,7 @@ use super::{ Type, TypeAliasType, TypeQualifiers, TypeVarInstance, }; use crate::semantic_index::definition::Definition; +use crate::semantic_index::DeclarationWithConstraint; use crate::types::generics::{GenericContext, Specialization}; use crate::types::signatures::{Parameter, Parameters}; use crate::types::{CallableType, DataclassMetadata, Signature}; @@ -34,7 +36,9 @@ use itertools::Itertools as _; use ruff_db::files::File; use ruff_python_ast::name::Name; use ruff_python_ast::{self as ast, PythonVersion}; -use rustc_hash::FxHashSet; +use rustc_hash::{FxHashSet, FxHasher}; + +type FxOrderMap = ordermap::map::OrderMap>; fn explicit_bases_cycle_recover<'db>( _db: &'db dyn Db, @@ -822,38 +826,8 @@ impl<'db> ClassLiteralType<'db> { specialization: Option>, name: &str, ) -> SymbolAndQualifiers<'db> { - if let Some(metadata) = self.dataclass_metadata(db) { - if name == "__init__" && metadata.contains(DataclassMetadata::INIT) { - // TODO: Generate the signature from the attributes on the class - let init_signature = Signature::new( - Parameters::new([ - Parameter::variadic(Name::new_static("args")) - .with_annotated_type(Type::any()), - Parameter::keyword_variadic(Name::new_static("kwargs")) - .with_annotated_type(Type::any()), - ]), - Some(Type::none(db)), - ); - - return Symbol::bound(Type::Callable(CallableType::new(db, init_signature))).into(); - } else if matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__") { - if metadata.contains(DataclassMetadata::ORDER) { - let signature = Signature::new( - Parameters::new([Parameter::positional_or_keyword(Name::new_static( - "other", - )) - .with_annotated_type(Type::instance( - self.apply_optional_specialization(db, specialization), - ))]), - Some(KnownClass::Bool.to_instance(db)), - ); - return Symbol::bound(Type::Callable(CallableType::new(db, signature))).into(); - } - } - } - let body_scope = self.body_scope(db); - class_symbol(db, body_scope, name).map_type(|ty| { + let symbol = class_symbol(db, body_scope, name).map_type(|ty| { // The `__new__` and `__init__` members of a non-specialized generic class are handled // specially: they inherit the generic context of their class. That lets us treat them // as generic functions when constructing the class, and infer the specialization of @@ -876,7 +850,199 @@ impl<'db> ClassLiteralType<'db> { ), _ => ty, } - }) + }); + + if symbol.symbol.is_unbound() { + if let Some(metadata) = self.dataclass_metadata(db) { + if let Some(dataclass_member) = + self.own_dataclass_member(db, specialization, metadata, name) + { + return Symbol::bound(dataclass_member).into(); + } + } + } + + symbol + } + + /// Returns the type of a synthesized dataclass member like `__init__` or `__lt__`. + fn own_dataclass_member( + self, + db: &'db dyn Db, + specialization: Option>, + metadata: DataclassMetadata, + name: &str, + ) -> Option> { + if name == "__init__" && metadata.contains(DataclassMetadata::INIT) { + let mut parameters = vec![]; + + for (name, (mut attr_ty, mut default_ty)) in self.dataclass_fields(db, specialization) { + // The descriptor handling below is guarded by this fully-static check, because dynamic + // types like `Any` are valid (data) descriptors: since they have all possible attributes, + // they also have a (callable) `__set__` method. The problem is that we can't determine + // the type of the value parameter this way. Instead, we want to use the dynamic type + // itself in this case, so we skip the special descriptor handling. + if attr_ty.is_fully_static(db) { + let dunder_set = attr_ty.class_member(db, "__set__".into()); + if let Some(dunder_set) = dunder_set.symbol.ignore_possibly_unbound() { + // This type of this attribute is a data descriptor. Instead of overwriting the + // descriptor attribute, data-classes will (implicitly) call the `__set__` method + // of the descriptor. This means that the synthesized `__init__` parameter for + // this attribute is determined by possible `value` parameter types with which + // the `__set__` method can be called. We build a union of all possible options + // to account for possible overloads. + let mut value_types = UnionBuilder::new(db); + for signature in &dunder_set.signatures(db) { + for overload in signature { + if let Some(value_param) = overload.parameters().get_positional(2) { + value_types = value_types.add( + value_param.annotated_type().unwrap_or_else(Type::unknown), + ); + } else if overload.parameters().is_gradual() { + value_types = value_types.add(Type::unknown()); + } + } + } + attr_ty = value_types.build(); + + // The default value of the attribute is *not* determined by the right hand side + // of the class-body assignment. Instead, the runtime invokes `__get__` on the + // descriptor, as if it had been called on the class itself, i.e. it passes `None` + // for the `instance` argument. + + if let Some(ref mut default_ty) = default_ty { + *default_ty = default_ty + .try_call_dunder_get(db, Type::none(db), Type::ClassLiteral(self)) + .map(|(return_ty, _)| return_ty) + .unwrap_or_else(Type::unknown); + } + } + } + + let mut parameter = + Parameter::positional_or_keyword(name).with_annotated_type(attr_ty); + + if let Some(default_ty) = default_ty { + parameter = parameter.with_default_type(default_ty); + } + + parameters.push(parameter); + } + + let init_signature = Signature::new(Parameters::new(parameters), Some(Type::none(db))); + + return Some(Type::Callable(CallableType::new(db, init_signature))); + } else if matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__") { + if metadata.contains(DataclassMetadata::ORDER) { + let signature = Signature::new( + Parameters::new([Parameter::positional_or_keyword(Name::new_static("other")) + // TODO: could be `Self`. + .with_annotated_type(Type::instance( + self.apply_optional_specialization(db, specialization), + ))]), + Some(KnownClass::Bool.to_instance(db)), + ); + + return Some(Type::Callable(CallableType::new(db, signature))); + } + } + + None + } + + /// Returns a list of all annotated attributes defined in this class, or any of its superclasses. + /// + /// See [`ClassLiteralType::own_dataclass_fields`] for more details. + fn dataclass_fields( + self, + db: &'db dyn Db, + specialization: Option>, + ) -> FxOrderMap, Option>)> { + let dataclasses_in_mro: Vec<_> = self + .iter_mro(db, specialization) + .filter_map(|superclass| { + if let Some(class) = superclass.into_class() { + let class_literal = class.class_literal(db).0; + if class_literal.dataclass_metadata(db).is_some() { + Some(class_literal) + } else { + None + } + } else { + None + } + }) + // We need to collect into a `Vec` here because we iterate the MRO in reverse order + .collect(); + + dataclasses_in_mro + .into_iter() + .rev() + .flat_map(|class| class.own_dataclass_fields(db)) + // We collect into a FxOrderMap here to deduplicate attributes + .collect() + } + + /// Returns a list of all annotated attributes defined in the body of this class. This is similar + /// to the `__annotations__` attribute at runtime, but also contains default values. + /// + /// For a class body like + /// ```py + /// @dataclass + /// class C: + /// x: int + /// y: str = "a" + /// ``` + /// we return a map `{"x": (int, None), "y": (str, Some(Literal["a"]))}`. + fn own_dataclass_fields( + self, + db: &'db dyn Db, + ) -> FxOrderMap, Option>)> { + let mut attributes = FxOrderMap::default(); + + let class_body_scope = self.body_scope(db); + let table = symbol_table(db, class_body_scope); + + let use_def = use_def_map(db, class_body_scope); + for (symbol_id, declarations) in use_def.all_public_declarations() { + // Here, we exclude all declarations that are not annotated assignments. We need this because + // things like function definitions and nested classes would otherwise be considered dataclass + // fields. The check is too broad in the sense that it also excludes (weird) constructs where + // a symbol would have multiple declarations, one of which is an annotated assignment. If we + // want to improve this, we could instead pass a definition-kind filter to the use-def map + // query, or to the `symbol_from_declarations` call below. Doing so would potentially require + // us to generate a union of `__init__` methods. + if !declarations + .clone() + .all(|DeclarationWithConstraint { declaration, .. }| { + declaration.is_some_and(|declaration| { + matches!( + declaration.kind(db), + DefinitionKind::AnnotatedAssignment(..) + ) + }) + }) + { + continue; + } + + let symbol = table.symbol(symbol_id); + + if let Ok(attr) = symbol_from_declarations(db, declarations) { + if attr.is_class_var() { + continue; + } + + if let Some(attr_ty) = attr.symbol.ignore_possibly_unbound() { + let bindings = use_def.public_bindings(symbol_id); + let default_ty = symbol_from_bindings(db, bindings).ignore_possibly_unbound(); + + attributes.insert(symbol.name().clone(), (attr_ty, default_ty)); + } + } + } + + attributes } /// Returns the `name` attribute of an instance of this class.