From da954bee94501ad25127c30a66536a2369d1cb80 Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Wed, 8 Feb 2023 23:19:21 +0000 Subject: [PATCH 1/7] [dataclass_transform] support subclass/metaclass-based transforms --- mypy/nodes.py | 11 ++ mypy/plugins/dataclasses.py | 168 +++++++++++------- mypy/semanal.py | 10 ++ mypy/semanal_main.py | 10 ++ mypy/semanal_shared.py | 18 ++ test-data/unit/check-dataclass-transform.test | 46 +++++ test-data/unit/fixtures/dataclasses.pyi | 1 + 7 files changed, 199 insertions(+), 65 deletions(-) diff --git a/mypy/nodes.py b/mypy/nodes.py index 534ba7f82607..19055794e305 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1079,6 +1079,7 @@ class ClassDef(Statement): "has_incompatible_baseclass", "deco_line", "removed_statements", + "dataclass_transform_spec", ) __match_args__ = ("name", "defs") @@ -1125,6 +1126,7 @@ def __init__( # Used for error reporting (to keep backwad compatibility with pre-3.8) self.deco_line: int | None = None self.removed_statements = [] + self.dataclass_transform_spec: DataclassTransformSpec | None = None @property def fullname(self) -> str: @@ -1148,6 +1150,11 @@ def serialize(self) -> JsonDict: "name": self.name, "fullname": self.fullname, "type_vars": [v.serialize() for v in self.type_vars], + "dataclass_transform_spec": ( + self.dataclass_transform_spec.serialize() + if self.dataclass_transform_spec is not None + else None + ), } @classmethod @@ -1163,6 +1170,10 @@ def deserialize(self, data: JsonDict) -> ClassDef: ], ) res.fullname = data["fullname"] + if data.get("dataclass_transform_spec") is not None: + res.dataclass_transform_spec = DataclassTransformSpec.deserialize( + data["dataclass_transform_spec"] + ) return res diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index 4683b8c1ffaf..db3f0c08947d 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -17,6 +17,7 @@ Argument, AssignmentStmt, CallExpr, + ClassDef, Context, DataclassTransformSpec, Expression, @@ -25,6 +26,7 @@ Node, PlaceholderNode, RefExpr, + Statement, SymbolTableNode, TempNode, TypeAlias, @@ -36,7 +38,7 @@ from mypy.plugins.common import ( _get_decorator_bool_argument, add_attribute_to_class, - add_method, + add_method_to_class, deserialize_and_fixup_type, ) from mypy.semanal_shared import find_dataclass_transform_spec @@ -161,17 +163,26 @@ class DataclassTransformer: there are no placeholders. """ - def __init__(self, ctx: ClassDefContext) -> None: - self._ctx = ctx - self._spec = _get_transform_spec(ctx.reason) + def __init__( + self, + cls: ClassDef, + # Statement must also be accepted since class definition itself may be passed as the reason + # for subclass/metaclass-based uses of `typing.dataclass_transform` + reason: Expression | Statement, + spec: DataclassTransformSpec, + api: SemanticAnalyzerPluginInterface, + ) -> None: + self._cls = cls + self._reason = reason + self._spec = spec + self._api = api def transform(self) -> bool: """Apply all the necessary transformations to the underlying dataclass so as to ensure it is fully type checked according to the rules in PEP 557. """ - ctx = self._ctx - info = self._ctx.cls.info + info = self._cls.info attributes = self.collect_attributes() if attributes is None: # Some definitions are not ready. We need another pass. @@ -180,14 +191,14 @@ def transform(self) -> bool: if attr.type is None: return False decorator_arguments = { - "init": _get_decorator_bool_argument(self._ctx, "init", True), - "eq": _get_decorator_bool_argument(self._ctx, "eq", self._spec.eq_default), - "order": _get_decorator_bool_argument(self._ctx, "order", self._spec.order_default), - "frozen": _get_decorator_bool_argument(self._ctx, "frozen", self._spec.frozen_default), - "slots": _get_decorator_bool_argument(self._ctx, "slots", False), - "match_args": _get_decorator_bool_argument(self._ctx, "match_args", True), + "init": self._get_bool_arg("init", True), + "eq": self._get_bool_arg("eq", self._spec.eq_default), + "order": self._get_bool_arg("order", self._spec.order_default), + "frozen": self._get_bool_arg("frozen", self._spec.frozen_default), + "slots": self._get_bool_arg("slots", False), + "match_args": self._get_bool_arg("match_args", True), } - py_version = self._ctx.api.options.python_version + py_version = self._api.options.python_version # If there are no attributes, it may be that the semantic analyzer has not # processed them yet. In order to work around this, we can simply skip generating @@ -199,7 +210,7 @@ def transform(self) -> bool: and attributes ): - with state.strict_optional_set(ctx.api.options.strict_optional): + with state.strict_optional_set(self._api.options.strict_optional): args = [ attr.to_argument(info) for attr in attributes @@ -221,7 +232,9 @@ def transform(self) -> bool: Argument(nameless_var, AnyType(TypeOfAny.explicit), None, ARG_STAR2), ] - add_method(ctx, "__init__", args=args, return_type=NoneType()) + add_method_to_class( + self._api, self._cls, "__init__", args=args, return_type=NoneType() + ) if ( decorator_arguments["eq"] @@ -229,7 +242,7 @@ def transform(self) -> bool: or decorator_arguments["order"] ): # Type variable for self types in generated methods. - obj_type = ctx.api.named_type("builtins.object") + obj_type = self._api.named_type("builtins.object") self_tvar_expr = TypeVarExpr( SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, [], obj_type ) @@ -238,16 +251,16 @@ def transform(self) -> bool: # Add <, >, <=, >=, but only if the class has an eq method. if decorator_arguments["order"]: if not decorator_arguments["eq"]: - ctx.api.fail('"eq" must be True if "order" is True', ctx.reason) + self._api.fail('"eq" must be True if "order" is True', self._reason) for method_name in ["__lt__", "__gt__", "__le__", "__ge__"]: # Like for __eq__ and __ne__, we want "other" to match # the self type. - obj_type = ctx.api.named_type("builtins.object") + obj_type = self._api.named_type("builtins.object") order_tvar_def = TypeVarType( SELF_TVAR_NAME, info.fullname + "." + SELF_TVAR_NAME, -1, [], obj_type ) - order_return_type = ctx.api.named_type("builtins.bool") + order_return_type = self._api.named_type("builtins.bool") order_args = [ Argument(Var("other", order_tvar_def), order_tvar_def, None, ARG_POS) ] @@ -255,13 +268,14 @@ def transform(self) -> bool: existing_method = info.get(method_name) if existing_method is not None and not existing_method.plugin_generated: assert existing_method.node - ctx.api.fail( + self._api.fail( f'You may not have a custom "{method_name}" method when "order" is True', existing_method.node, ) - add_method( - ctx, + add_method_to_class( + self._api, + self._cls, method_name, args=order_args, return_type=order_return_type, @@ -277,12 +291,12 @@ def transform(self) -> bool: if decorator_arguments["frozen"]: if any(not parent["frozen"] for parent in parent_decorator_arguments): - ctx.api.fail("Cannot inherit frozen dataclass from a non-frozen one", info) + self._api.fail("Cannot inherit frozen dataclass from a non-frozen one", info) self._propertize_callables(attributes, settable=False) self._freeze(attributes) else: if any(parent["frozen"] for parent in parent_decorator_arguments): - ctx.api.fail("Cannot inherit non-frozen dataclass from a frozen one", info) + self._api.fail("Cannot inherit non-frozen dataclass from a frozen one", info) self._propertize_callables(attributes) if decorator_arguments["slots"]: @@ -298,12 +312,12 @@ def transform(self) -> bool: and attributes and py_version >= (3, 10) ): - str_type = ctx.api.named_type("builtins.str") + str_type = self._api.named_type("builtins.str") literals: list[Type] = [ LiteralType(attr.name, str_type) for attr in attributes if attr.is_in_init ] - match_args_type = TupleType(literals, ctx.api.named_type("builtins.tuple")) - add_attribute_to_class(ctx.api, ctx.cls, "__match_args__", match_args_type) + match_args_type = TupleType(literals, self._api.named_type("builtins.tuple")) + add_attribute_to_class(self._api, self._cls, "__match_args__", match_args_type) self._add_dataclass_fields_magic_attribute() @@ -320,10 +334,10 @@ def add_slots( if not correct_version: # This means that version is lower than `3.10`, # it is just a non-existent argument for `dataclass` function. - self._ctx.api.fail( + self._api.fail( 'Keyword argument "slots" for "dataclass" ' "is only valid in Python 3.10 and higher", - self._ctx.reason, + self._reason, ) return @@ -335,11 +349,11 @@ def add_slots( # Class explicitly specifies a different `__slots__` field. # And `@dataclass(slots=True)` is used. # In runtime this raises a type error. - self._ctx.api.fail( + self._api.fail( '"{}" both defines "__slots__" and is used with "slots=True"'.format( - self._ctx.cls.name + self._cls.name ), - self._ctx.cls, + self._cls, ) return @@ -375,8 +389,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: Return None if some dataclass base class hasn't been processed yet and thus we'll need to ask for another pass. """ - ctx = self._ctx - cls = self._ctx.cls + cls = self._cls # First, collect attributes belonging to any class in the MRO, ignoring duplicates. # @@ -397,30 +410,30 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: continue # Each class depends on the set of attributes in its dataclass ancestors. - ctx.api.add_plugin_dependency(make_wildcard_trigger(info.fullname)) + self._api.add_plugin_dependency(make_wildcard_trigger(info.fullname)) found_dataclass_supertype = True for data in info.metadata["dataclass"]["attributes"]: name: str = data["name"] - attr = DataclassAttribute.deserialize(info, data, ctx.api) + attr = DataclassAttribute.deserialize(info, data, self._api) # TODO: We shouldn't be performing type operations during the main # semantic analysis pass, since some TypeInfo attributes might # still be in flux. This should be performed in a later phase. - with state.strict_optional_set(ctx.api.options.strict_optional): - attr.expand_typevar_from_subtype(ctx.cls.info) + with state.strict_optional_set(self._api.options.strict_optional): + attr.expand_typevar_from_subtype(cls.info) found_attrs[name] = attr sym_node = cls.info.names.get(name) if sym_node and sym_node.node and not isinstance(sym_node.node, Var): - ctx.api.fail( + self._api.fail( "Dataclass attribute may only be overridden by another attribute", sym_node.node, ) # Second, collect attributes belonging to the current class. current_attr_names: set[str] = set() - kw_only = _get_decorator_bool_argument(ctx, "kw_only", self._spec.kw_only_default) + kw_only = self._get_bool_arg("kw_only", self._spec.kw_only_default) for stmt in cls.defs.body: # Any assignment that doesn't use the new type declaration # syntax can be ignored out of hand. @@ -442,7 +455,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: assert not isinstance(node, PlaceholderNode) if isinstance(node, TypeAlias): - ctx.api.fail( + self._api.fail( ("Type aliases inside dataclass definitions are not supported at runtime"), node, ) @@ -470,13 +483,13 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: if self._is_kw_only_type(node_type): kw_only = True - has_field_call, field_args = self._collect_field_args(stmt.rvalue, ctx) + has_field_call, field_args = self._collect_field_args(stmt.rvalue) is_in_init_param = field_args.get("init") if is_in_init_param is None: is_in_init = True else: - is_in_init = bool(ctx.api.parse_bool(is_in_init_param)) + is_in_init = bool(self._api.parse_bool(is_in_init_param)) has_default = False # Ensure that something like x: int = field() is rejected @@ -498,7 +511,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: # kw_only value from the decorator parameter. field_kw_only_param = field_args.get("kw_only") if field_kw_only_param is not None: - is_kw_only = bool(ctx.api.parse_bool(field_kw_only_param)) + is_kw_only = bool(self._api.parse_bool(field_kw_only_param)) if sym.type is None and node.is_final and node.is_inferred: # This is a special case, assignment like x: Final = 42 is classified @@ -506,11 +519,11 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: # We do not support inferred types in dataclasses, so we can try inferring # type for simple literals, and otherwise require an explicit type # argument for Final[...]. - typ = ctx.api.analyze_simple_literal_type(stmt.rvalue, is_final=True) + typ = self._api.analyze_simple_literal_type(stmt.rvalue, is_final=True) if typ: node.type = typ else: - ctx.api.fail( + self._api.fail( "Need type argument for Final[...] with non-literal default in dataclass", stmt, ) @@ -545,19 +558,21 @@ def collect_attributes(self) -> list[DataclassAttribute] | None: if found_default and attr.is_in_init and not attr.has_default and not attr.kw_only: # If the issue comes from merging different classes, report it # at the class definition point. - context: Context = ctx.cls + context: Context = cls if attr.name in current_attr_names: context = Context(line=attr.line, column=attr.column) - ctx.api.fail( + self._api.fail( "Attributes without a default cannot follow attributes with one", context ) found_default = found_default or (attr.has_default and attr.is_in_init) if found_kw_sentinel and self._is_kw_only_type(attr.type): - context = ctx.cls + context = cls if attr.name in current_attr_names: context = Context(line=attr.line, column=attr.column) - ctx.api.fail("There may not be more than one field with the KW_ONLY type", context) + self._api.fail( + "There may not be more than one field with the KW_ONLY type", context + ) found_kw_sentinel = found_kw_sentinel or self._is_kw_only_type(attr.type) return all_attrs @@ -565,7 +580,7 @@ def _freeze(self, attributes: list[DataclassAttribute]) -> None: """Converts all attributes to @property methods in order to emulate frozen classes. """ - info = self._ctx.cls.info + info = self._cls.info for attr in attributes: sym_node = info.names.get(attr.name) if sym_node is not None: @@ -589,7 +604,7 @@ def _propertize_callables( `self` argument (it is not). """ - info = self._ctx.cls.info + info = self._cls.info for attr in attributes: if isinstance(get_proper_type(attr.type), CallableType): var = attr.to_var(info) @@ -611,21 +626,19 @@ def _is_kw_only_type(self, node: Type | None) -> bool: def _add_dataclass_fields_magic_attribute(self) -> None: attr_name = "__dataclass_fields__" any_type = AnyType(TypeOfAny.explicit) - field_type = self._ctx.api.named_type_or_none("dataclasses.Field", [any_type]) or any_type - attr_type = self._ctx.api.named_type( - "builtins.dict", [self._ctx.api.named_type("builtins.str"), field_type] + field_type = self._api.named_type_or_none("dataclasses.Field", [any_type]) or any_type + attr_type = self._api.named_type( + "builtins.dict", [self._api.named_type("builtins.str"), field_type] ) var = Var(name=attr_name, type=attr_type) - var.info = self._ctx.cls.info - var._fullname = self._ctx.cls.info.fullname + "." + attr_name + var.info = self._cls.info + var._fullname = self._cls.info.fullname + "." + attr_name var.is_classvar = True - self._ctx.cls.info.names[attr_name] = SymbolTableNode( + self._cls.info.names[attr_name] = SymbolTableNode( kind=MDEF, node=var, plugin_generated=True ) - def _collect_field_args( - self, expr: Expression, ctx: ClassDefContext - ) -> tuple[bool, dict[str, Expression]]: + def _collect_field_args(self, expr: Expression) -> tuple[bool, dict[str, Expression]]: """Returns a tuple where the first value represents whether or not the expression is a call to dataclass.field and the second is a dictionary of the keyword arguments that field() was called with. @@ -646,13 +659,37 @@ def _collect_field_args( message = 'Unpacking **kwargs in "field()" is not supported' else: message = '"field()" does not accept positional arguments' - ctx.api.fail(message, expr) + self._api.fail(message, expr) return True, {} assert name is not None args[name] = arg return True, args return False, {} + def _get_bool_arg(self, name: str, default: bool) -> bool: + # Expressions are always CallExprs (either directly or via a wrapper like Decorator), so + # we can use the helpers from common + if isinstance(self._reason, Expression): + return _get_decorator_bool_argument( + ClassDefContext(self._cls, self._reason, self._api), name, default + ) + + # Subclass/metaclass use of `typing.dataclass_transform` reads the parameters from the + # class's keyword arguments (ie `class Subclass(Parent, kwarg1=..., kwarg2=...)`) + expression = self._cls.keywords.get(name) + if expression is not None: + value = self._api.parse_bool(self._cls.keywords[name]) + if value is not None: + return value + else: + self._api.fail(f'"{name}" argument must be True or False.', expression) + return default + + +def add_dataclass_tag(info: TypeInfo) -> None: + # The value is ignored, only the existence matters. + info.metadata["dataclass_tag"] = {} + def dataclass_tag_callback(ctx: ClassDefContext) -> None: """Record that we have a dataclass in the main semantic analysis pass. @@ -660,13 +697,14 @@ def dataclass_tag_callback(ctx: ClassDefContext) -> None: The later pass implemented by DataclassTransformer will use this to detect dataclasses in base classes. """ - # The value is ignored, only the existence matters. - ctx.cls.info.metadata["dataclass_tag"] = {} + add_dataclass_tag(ctx.cls.info) def dataclass_class_maker_callback(ctx: ClassDefContext) -> bool: """Hooks into the class typechecking process to add support for dataclasses.""" - transformer = DataclassTransformer(ctx) + transformer = DataclassTransformer( + ctx.cls, ctx.reason, _get_transform_spec(ctx.reason), ctx.api + ) return transformer.transform() diff --git a/mypy/semanal.py b/mypy/semanal.py index cd5b82f80b1d..6fefe00b130e 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -1749,6 +1749,12 @@ def apply_class_plugin_hooks(self, defn: ClassDef) -> None: if hook: hook(ClassDefContext(defn, base_expr, self)) + # Check if the class definition itself triggers a dataclass transform (via a parent class/ + # metaclass) + spec = find_dataclass_transform_spec(defn) + if spec is not None: + dataclasses_plugin.add_dataclass_tag(defn.info) + def get_fullname_for_hook(self, expr: Expression) -> str | None: if isinstance(expr, CallExpr): return self.get_fullname_for_hook(expr.callee) @@ -1796,6 +1802,10 @@ def analyze_class_decorator(self, defn: ClassDef, decorator: Expression) -> None self.fail("@runtime_checkable can only be used with protocol classes", defn) elif decorator.fullname in FINAL_DECORATOR_NAMES: defn.info.is_final = True + elif isinstance(decorator, CallExpr) and refers_to_fullname( + decorator.callee, DATACLASS_TRANSFORM_NAMES + ): + defn.dataclass_transform_spec = self.parse_dataclass_transform_spec(decorator) def clean_up_bases_and_infer_type_variables( self, defn: ClassDef, base_type_exprs: list[Expression], context: Context diff --git a/mypy/semanal_main.py b/mypy/semanal_main.py index 796a862c35e7..a5e85878e931 100644 --- a/mypy/semanal_main.py +++ b/mypy/semanal_main.py @@ -472,6 +472,16 @@ def apply_hooks_to_class( if hook: ok = ok and hook(ClassDefContext(defn, decorator, self)) + + # Check if the class definition itself triggers a dataclass transform (via a parent class/ + # metaclass) + spec = find_dataclass_transform_spec(info) + if spec is not None: + with self.file_context(file_node, options, info): + # We can't use the normal hook because reason = defn, and ClassDefContext only accepts + # an Expression for reason + ok = ok and dataclasses_plugin.DataclassTransformer(defn, defn, spec, self).transform() + return ok diff --git a/mypy/semanal_shared.py b/mypy/semanal_shared.py index 05edf2ac073f..e97d1ec8e5e0 100644 --- a/mypy/semanal_shared.py +++ b/mypy/semanal_shared.py @@ -12,6 +12,7 @@ from mypy.errorcodes import ErrorCode from mypy.nodes import ( CallExpr, + ClassDef, Context, DataclassTransformSpec, Decorator, @@ -378,7 +379,24 @@ def find_dataclass_transform_spec(node: Node | None) -> DataclassTransformSpec | # `@dataclass_transform(...)` syntax and never `@dataclass_transform` node = node.func + # For functions, we can directly consult the AST field for the spec if isinstance(node, FuncDef): return node.dataclass_transform_spec + if isinstance(node, ClassDef): + node = node.info + if isinstance(node, TypeInfo): + # Search all parent classes to see if any are decorated with `typing.dataclass_transform` + for base in node.mro[1:]: + if base.defn.dataclass_transform_spec is not None: + return base.defn.dataclass_transform_spec + + # Check if there is a metaclass that is decorated with `typing.dataclass_transform` + metaclass_type = node.metaclass_type + if ( + metaclass_type is not None + and metaclass_type.type.defn.dataclass_transform_spec is not None + ): + return metaclass_type.type.defn.dataclass_transform_spec + return None diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index 01e8935b0745..a5e85da41ffd 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -76,12 +76,19 @@ def my_dataclass(*, eq: bool = True, order: bool = False) -> Callable[[Type], Ty def transform(cls: Type) -> Type: return cls return transform +@dataclass_transform() +class BaseClass: + def __init_subclass__(cls, *, eq: bool): ... +@dataclass_transform() +class Metaclass(type): ... BOOL_CONSTANT = True @my_dataclass(eq=BOOL_CONSTANT) # E: "eq" argument must be True or False. class A: ... @my_dataclass(order=not False) # E: "order" argument must be True or False. class B: ... +class C(BaseClass, eq=BOOL_CONSTANT): ... # E: "eq" argument must be True or False. +class D(metaclass=Metaclass, order=not False): ... # E: "order" argument must be True or False. [typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] @@ -202,3 +209,42 @@ Foo(5) [typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] + +[case testDataclassTransformViaBaseClass] +# flags: --python-version 3.11 +from typing import dataclass_transform + +@dataclass_transform(frozen_default=True) +class Dataclass: + def __init_subclass__(cls, *, kw_only: bool): ... + +class Person(Dataclass, kw_only=True): + name: str + age: int + +reveal_type(Person) # N: Revealed type is "def (*, name: builtins.str, age: builtins.int) -> __main__.Person" +Person('Jonh', 21) # E: Too many positional arguments for "Person" +person = Person(name='John', age=32) +person.name = "John Smith" # E: Property "name" defined in "Person" is read-only + +[typing fixtures/typing-full.pyi] +[builtins fixtures/dataclasses.pyi] + +[case testDataclassTransformViaMetaclass] +# flags: --python-version 3.11 +from typing import dataclass_transform + +@dataclass_transform(frozen_default=True) +class Dataclass(type): ... + +class Person(metaclass=Dataclass, kw_only=True): + name: str + age: int + +reveal_type(Person) # N: Revealed type is "def (*, name: builtins.str, age: builtins.int) -> __main__.Person" +Person('Jonh', 21) # E: Too many positional arguments for "Person" +person = Person(name='John', age=32) +person.name = "John Smith" # E: Property "name" defined in "Person" is read-only + +[typing fixtures/typing-full.pyi] +[builtins fixtures/dataclasses.pyi] diff --git a/test-data/unit/fixtures/dataclasses.pyi b/test-data/unit/fixtures/dataclasses.pyi index ab692302a8b6..e9394c84ba7d 100644 --- a/test-data/unit/fixtures/dataclasses.pyi +++ b/test-data/unit/fixtures/dataclasses.pyi @@ -10,6 +10,7 @@ VT = TypeVar('VT') class object: def __init__(self) -> None: pass + def __init_subclass__(cls) -> None: pass def __eq__(self, o: object) -> bool: pass def __ne__(self, o: object) -> bool: pass From b8720166a84fafd3df602c3816fa1ac694c130b2 Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Thu, 9 Feb 2023 23:00:14 +0000 Subject: [PATCH 2/7] move dataclass_transform_spec field from ClassDef to TypeInfo --- mypy/nodes.py | 25 ++++++++++++++----------- mypy/semanal.py | 2 +- mypy/semanal_shared.py | 11 ++++------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/mypy/nodes.py b/mypy/nodes.py index 19055794e305..2f2aa6a3efbe 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -1079,7 +1079,6 @@ class ClassDef(Statement): "has_incompatible_baseclass", "deco_line", "removed_statements", - "dataclass_transform_spec", ) __match_args__ = ("name", "defs") @@ -1126,7 +1125,6 @@ def __init__( # Used for error reporting (to keep backwad compatibility with pre-3.8) self.deco_line: int | None = None self.removed_statements = [] - self.dataclass_transform_spec: DataclassTransformSpec | None = None @property def fullname(self) -> str: @@ -1150,11 +1148,6 @@ def serialize(self) -> JsonDict: "name": self.name, "fullname": self.fullname, "type_vars": [v.serialize() for v in self.type_vars], - "dataclass_transform_spec": ( - self.dataclass_transform_spec.serialize() - if self.dataclass_transform_spec is not None - else None - ), } @classmethod @@ -1170,10 +1163,6 @@ def deserialize(self, data: JsonDict) -> ClassDef: ], ) res.fullname = data["fullname"] - if data.get("dataclass_transform_spec") is not None: - res.dataclass_transform_spec = DataclassTransformSpec.deserialize( - data["dataclass_transform_spec"] - ) return res @@ -2841,6 +2830,7 @@ class is generic then it will be a type constructor of higher kind. "type_var_tuple_prefix", "type_var_tuple_suffix", "self_type", + "dataclass_transform_spec", ) _fullname: str # Fully qualified name @@ -2988,6 +2978,9 @@ class is generic then it will be a type constructor of higher kind. # Shared type variable for typing.Self in this class (if used, otherwise None). self_type: mypy.types.TypeVarType | None + # Added if the corresponding class is directly decorated with `typing.dataclass_transform` + dataclass_transform_spec: DataclassTransformSpec | None + FLAGS: Final = [ "is_abstract", "is_enum", @@ -3043,6 +3036,7 @@ def __init__(self, names: SymbolTable, defn: ClassDef, module_name: str) -> None self.is_intersection = False self.metadata = {} self.self_type = None + self.dataclass_transform_spec = None def add_type_vars(self) -> None: self.has_type_var_tuple_type = False @@ -3262,6 +3256,11 @@ def serialize(self) -> JsonDict: "slots": list(sorted(self.slots)) if self.slots is not None else None, "deletable_attributes": self.deletable_attributes, "self_type": self.self_type.serialize() if self.self_type is not None else None, + "dataclass_transform_spec": ( + self.dataclass_transform_spec.serialize() + if self.dataclass_transform_spec is not None + else None + ), } return data @@ -3325,6 +3324,10 @@ def deserialize(cls, data: JsonDict) -> TypeInfo: set_flags(ti, data["flags"]) st = data["self_type"] ti.self_type = mypy.types.TypeVarType.deserialize(st) if st is not None else None + if data.get("dataclass_transform_spec") is not None: + ti.dataclass_transform_spec = DataclassTransformSpec.deserialize( + data["dataclass_transform_spec"] + ) return ti diff --git a/mypy/semanal.py b/mypy/semanal.py index 6fefe00b130e..8dcea36f41b9 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -1805,7 +1805,7 @@ def analyze_class_decorator(self, defn: ClassDef, decorator: Expression) -> None elif isinstance(decorator, CallExpr) and refers_to_fullname( decorator.callee, DATACLASS_TRANSFORM_NAMES ): - defn.dataclass_transform_spec = self.parse_dataclass_transform_spec(decorator) + defn.info.dataclass_transform_spec = self.parse_dataclass_transform_spec(decorator) def clean_up_bases_and_infer_type_variables( self, defn: ClassDef, base_type_exprs: list[Expression], context: Context diff --git a/mypy/semanal_shared.py b/mypy/semanal_shared.py index e97d1ec8e5e0..b008b79f2738 100644 --- a/mypy/semanal_shared.py +++ b/mypy/semanal_shared.py @@ -388,15 +388,12 @@ def find_dataclass_transform_spec(node: Node | None) -> DataclassTransformSpec | if isinstance(node, TypeInfo): # Search all parent classes to see if any are decorated with `typing.dataclass_transform` for base in node.mro[1:]: - if base.defn.dataclass_transform_spec is not None: - return base.defn.dataclass_transform_spec + if base.dataclass_transform_spec is not None: + return base.dataclass_transform_spec # Check if there is a metaclass that is decorated with `typing.dataclass_transform` metaclass_type = node.metaclass_type - if ( - metaclass_type is not None - and metaclass_type.type.defn.dataclass_transform_spec is not None - ): - return metaclass_type.type.defn.dataclass_transform_spec + if metaclass_type is not None and metaclass_type.type.dataclass_transform_spec is not None: + return metaclass_type.type.dataclass_transform_spec return None From b388a8f90cea7cf8fd2b24f3ffdd663e215116c7 Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Thu, 9 Feb 2023 23:10:07 +0000 Subject: [PATCH 3/7] add test coverage for subclass of subclass/using metaclass --- test-data/unit/check-dataclass-transform.test | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index a5e85da41ffd..32b5bab40c75 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -216,7 +216,7 @@ from typing import dataclass_transform @dataclass_transform(frozen_default=True) class Dataclass: - def __init_subclass__(cls, *, kw_only: bool): ... + def __init_subclass__(cls, *, kw_only: bool = False): ... class Person(Dataclass, kw_only=True): name: str @@ -227,6 +227,12 @@ Person('Jonh', 21) # E: Too many positional arguments for "Person" person = Person(name='John', age=32) person.name = "John Smith" # E: Property "name" defined in "Person" is read-only +class Contact(Person): + email: str + +reveal_type(Contact) # N: Revealed type is "def (email: builtins.str, *, name: builtins.str, age: builtins.int) -> __main__.Contact" +Contact('john@john.com', name='John', age=32) + [typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] @@ -246,5 +252,11 @@ Person('Jonh', 21) # E: Too many positional arguments for "Person" person = Person(name='John', age=32) person.name = "John Smith" # E: Property "name" defined in "Person" is read-only +class Contact(Person): + email: str + +reveal_type(Contact) # N: Revealed type is "def (email: builtins.str, *, name: builtins.str, age: builtins.int) -> __main__.Contact" +Contact('john@john.com', name='John', age=32) + [typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] From 486cae9042ef0f3f34fc6be4a852c4a19c8b139e Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Thu, 9 Feb 2023 23:27:39 +0000 Subject: [PATCH 4/7] comment explaining why we don't search the metaclass's MRO --- mypy/semanal_shared.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mypy/semanal_shared.py b/mypy/semanal_shared.py index b008b79f2738..28ec8d0857ff 100644 --- a/mypy/semanal_shared.py +++ b/mypy/semanal_shared.py @@ -392,6 +392,17 @@ def find_dataclass_transform_spec(node: Node | None) -> DataclassTransformSpec | return base.dataclass_transform_spec # Check if there is a metaclass that is decorated with `typing.dataclass_transform` + # + # Note that PEP 681 only discusses using a metaclass that is directly decorated with + # `typing.dataclass_transform`; subclasses thereof should be treated with dataclass + # semantics rather than as transforms: + # + # > If dataclass_transform is applied to a class, dataclass-like semantics will be assumed + # > for any class that directly or indirectly derives from the decorated class or uses the + # > decorated class as a metaclass. + # + # The wording doesn't make this entirely explicit, but Pyright (the reference + # implementation for this PEP) only handles directly-decorated metaclasses. metaclass_type = node.metaclass_type if metaclass_type is not None and metaclass_type.type.dataclass_transform_spec is not None: return metaclass_type.type.dataclass_transform_spec From 8c52fc9f56f4e80814b0677fbc0f99880979b5d9 Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Thu, 9 Feb 2023 23:32:52 +0000 Subject: [PATCH 5/7] add test case for subclass of metaclass as transform --- test-data/unit/check-dataclass-transform.test | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index 32b5bab40c75..c59cdecb2fad 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -260,3 +260,22 @@ Contact('john@john.com', name='John', age=32) [typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi] + +[case testDataclassTransformViaSubclassOfMetaclass] +# flags: --python-version 3.11 +from typing import dataclass_transform + +@dataclass_transform(frozen_default=True) +class BaseMeta(type): ... +class SubMeta(BaseMeta): ... + +# MyPy does *not* recognize this as a dataclass because the metaclass is not directly decorated with +# dataclass_transform +class Foo(metaclass=SubMeta): + foo: int + +reveal_type(Foo) # N: Revealed type is "def () -> __main__.Foo" +Foo(1) # E: Too many arguments for "Foo" + +[typing fixtures/typing-full.pyi] +[builtins fixtures/dataclasses.pyi] From 51598eb45828f140eadcc7d7b67946c0635c1e52 Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Thu, 9 Feb 2023 23:42:20 +0000 Subject: [PATCH 6/7] remove period from error message --- mypy/plugins/dataclasses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/plugins/dataclasses.py b/mypy/plugins/dataclasses.py index db3f0c08947d..3feb644dc8ea 100644 --- a/mypy/plugins/dataclasses.py +++ b/mypy/plugins/dataclasses.py @@ -682,7 +682,7 @@ def _get_bool_arg(self, name: str, default: bool) -> bool: if value is not None: return value else: - self._api.fail(f'"{name}" argument must be True or False.', expression) + self._api.fail(f'"{name}" argument must be True or False', expression) return default From 16d9f0837052280f291fa34701e0fb4751faef38 Mon Sep 17 00:00:00 2001 From: Wesley Collin Wright Date: Thu, 9 Feb 2023 19:36:28 -0600 Subject: [PATCH 7/7] fix tests after removing period --- test-data/unit/check-dataclass-transform.test | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test-data/unit/check-dataclass-transform.test b/test-data/unit/check-dataclass-transform.test index c59cdecb2fad..075302762041 100644 --- a/test-data/unit/check-dataclass-transform.test +++ b/test-data/unit/check-dataclass-transform.test @@ -87,8 +87,8 @@ BOOL_CONSTANT = True class A: ... @my_dataclass(order=not False) # E: "order" argument must be True or False. class B: ... -class C(BaseClass, eq=BOOL_CONSTANT): ... # E: "eq" argument must be True or False. -class D(metaclass=Metaclass, order=not False): ... # E: "order" argument must be True or False. +class C(BaseClass, eq=BOOL_CONSTANT): ... # E: "eq" argument must be True or False +class D(metaclass=Metaclass, order=not False): ... # E: "order" argument must be True or False [typing fixtures/typing-full.pyi] [builtins fixtures/dataclasses.pyi]