Skip to content

Commit

Permalink
[dataclass_transform] support field_specifiers (#14667)
Browse files Browse the repository at this point in the history
These are analogous to `dataclasses.field`/`dataclasses.Field`.

Like most dataclass_transform features so far, this commit mostly just
plumbs through the necessary metadata so that we can re-use the existing
`dataclasses` plugin logic. It also adds support for the `alias=` and
`factory=` kwargs for fields, which are small; we rely on typeshed to
enforce that these aren't used with `dataclasses.field`.
  • Loading branch information
wesleywright authored Feb 15, 2023
1 parent ec511c6 commit 4635a8c
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 5 deletions.
4 changes: 4 additions & 0 deletions mypy/message_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,7 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
CLASS_PATTERN_UNKNOWN_KEYWORD: Final = 'Class "{}" has no attribute "{}"'
MULTIPLE_ASSIGNMENTS_IN_PATTERN: Final = 'Multiple assignments to name "{}" in pattern'
CANNOT_MODIFY_MATCH_ARGS: Final = 'Cannot assign to "__match_args__"'

DATACLASS_FIELD_ALIAS_MUST_BE_LITERAL: Final = (
'"alias" argument to dataclass field must be a string literal'
)
4 changes: 4 additions & 0 deletions mypy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,10 @@ def parse_bool(self, expr: Expression) -> bool | None:
"""Parse True/False literals."""
raise NotImplementedError

@abstractmethod
def parse_str_literal(self, expr: Expression) -> str | None:
"""Parse string literals."""

@abstractmethod
def fail(
self,
Expand Down
44 changes: 41 additions & 3 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Optional
from typing_extensions import Final

from mypy import errorcodes, message_registry
from mypy.expandtype import expand_type
from mypy.nodes import (
ARG_NAMED,
Expand Down Expand Up @@ -77,6 +78,7 @@ class DataclassAttribute:
def __init__(
self,
name: str,
alias: str | None,
is_in_init: bool,
is_init_var: bool,
has_default: bool,
Expand All @@ -87,6 +89,7 @@ def __init__(
kw_only: bool,
) -> None:
self.name = name
self.alias = alias
self.is_in_init = is_in_init
self.is_init_var = is_init_var
self.has_default = has_default
Expand Down Expand Up @@ -121,12 +124,13 @@ def expand_type(self, current_info: TypeInfo) -> Optional[Type]:
return self.type

def to_var(self, current_info: TypeInfo) -> Var:
return Var(self.name, self.expand_type(current_info))
return Var(self.alias or self.name, self.expand_type(current_info))

def serialize(self) -> JsonDict:
assert self.type
return {
"name": self.name,
"alias": self.alias,
"is_in_init": self.is_in_init,
"is_init_var": self.is_init_var,
"has_default": self.has_default,
Expand Down Expand Up @@ -495,7 +499,12 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
# Ensure that something like x: int = field() is rejected
# after an attribute with a default.
if has_field_call:
has_default = "default" in field_args or "default_factory" in field_args
has_default = (
"default" in field_args
or "default_factory" in field_args
# alias for default_factory defined in PEP 681
or "factory" in field_args
)

# All other assignments are already type checked.
elif not isinstance(stmt.rvalue, TempNode):
Expand All @@ -511,7 +520,11 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
# kw_only value from the decorator parameter.
field_kw_only_param = field_args.get("kw_only")
if field_kw_only_param is not None:
is_kw_only = bool(self._api.parse_bool(field_kw_only_param))
value = self._api.parse_bool(field_kw_only_param)
if value is not None:
is_kw_only = value
else:
self._api.fail('"kw_only" argument must be a boolean literal', stmt.rvalue)

if sym.type is None and node.is_final and node.is_inferred:
# This is a special case, assignment like x: Final = 42 is classified
Expand All @@ -529,9 +542,20 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
)
node.type = AnyType(TypeOfAny.from_error)

alias = None
if "alias" in field_args:
alias = self._api.parse_str_literal(field_args["alias"])
if alias is None:
self._api.fail(
message_registry.DATACLASS_FIELD_ALIAS_MUST_BE_LITERAL,
stmt.rvalue,
code=errorcodes.LITERAL_REQ,
)

current_attr_names.add(lhs.name)
found_attrs[lhs.name] = DataclassAttribute(
name=lhs.name,
alias=alias,
is_in_init=is_in_init,
is_init_var=is_init_var,
has_default=has_default,
Expand Down Expand Up @@ -624,6 +648,14 @@ def _is_kw_only_type(self, node: Type | None) -> bool:
return node_type.type.fullname == "dataclasses.KW_ONLY"

def _add_dataclass_fields_magic_attribute(self) -> None:
# Only add if the class is a dataclasses dataclass, and omit it for dataclass_transform
# classes.
# It would be nice if this condition were reified rather than using an `is` check.
# Only add if the class is a dataclasses dataclass, and omit it for dataclass_transform
# classes.
if self._spec is not _TRANSFORM_SPEC_FOR_DATACLASSES:
return

attr_name = "__dataclass_fields__"
any_type = AnyType(TypeOfAny.explicit)
field_type = self._api.named_type_or_none("dataclasses.Field", [any_type]) or any_type
Expand Down Expand Up @@ -657,6 +689,12 @@ def _collect_field_args(self, expr: Expression) -> tuple[bool, dict[str, Express
# the best we can do for now is not to fail.
# TODO: we can infer what's inside `**` and try to collect it.
message = 'Unpacking **kwargs in "field()" is not supported'
elif self._spec is not _TRANSFORM_SPEC_FOR_DATACLASSES:
# dataclasses.field can only be used with keyword args, but this
# restriction is only enforced for the *standardized* arguments to
# dataclass_transform field specifiers. If this is not a
# dataclasses.dataclass class, we can just skip positional args safely.
continue
else:
message = '"field()" does not accept positional arguments'
self._api.fail(message, expr)
Expand Down
30 changes: 28 additions & 2 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@
remove_dups,
type_constructors,
)
from mypy.typeops import function_type, get_type_vars
from mypy.typeops import function_type, get_type_vars, try_getting_str_literals_from_type
from mypy.types import (
ASSERT_TYPE_NAMES,
DATACLASS_TRANSFORM_NAMES,
Expand Down Expand Up @@ -6462,6 +6462,17 @@ def parse_bool(self, expr: Expression) -> bool | None:
return False
return None

def parse_str_literal(self, expr: Expression) -> str | None:
"""Attempt to find the string literal value of the given expression. Returns `None` if no
literal value can be found."""
if isinstance(expr, StrExpr):
return expr.value
if isinstance(expr, RefExpr) and isinstance(expr.node, Var) and expr.node.type is not None:
values = try_getting_str_literals_from_type(expr.node.type)
if values is not None and len(values) == 1:
return values[0]
return None

def set_future_import_flags(self, module_name: str) -> None:
if module_name in FUTURE_IMPORTS:
self.modules[self.cur_mod_id].future_import_flags.add(FUTURE_IMPORTS[module_name])
Expand All @@ -6482,7 +6493,9 @@ def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSp
# field_specifiers is currently the only non-boolean argument; check for it first so
# so the rest of the block can fail through to handling booleans
if name == "field_specifiers":
self.fail('"field_specifiers" support is currently unimplemented', call)
parameters.field_specifiers = self.parse_dataclass_transform_field_specifiers(
value
)
continue

boolean = require_bool_literal_argument(self, value, name)
Expand All @@ -6502,6 +6515,19 @@ def parse_dataclass_transform_spec(self, call: CallExpr) -> DataclassTransformSp

return parameters

def parse_dataclass_transform_field_specifiers(self, arg: Expression) -> tuple[str, ...]:
if not isinstance(arg, TupleExpr):
self.fail('"field_specifiers" argument must be a tuple literal', arg)
return tuple()

names = []
for specifier in arg.items:
if not isinstance(specifier, RefExpr):
self.fail('"field_specifiers" must only contain identifiers', specifier)
return tuple()
names.append(specifier.fullname)
return tuple(names)


def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike:
if isinstance(sig, CallableType):
Expand Down
119 changes: 119 additions & 0 deletions test-data/unit/check-dataclass-transform.test
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,125 @@ Foo(5)
[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformFieldSpecifierRejectMalformed]
# flags: --python-version 3.11
from typing import dataclass_transform, Any, Callable, Final, Type

def some_type() -> Type: ...
def some_function() -> Callable[[], None]: ...

def field(*args, **kwargs): ...
def fields_tuple() -> tuple[type | Callable[..., Any], ...]: return (field,)
CONSTANT: Final = (field,)

@dataclass_transform(field_specifiers=(some_type(),)) # E: "field_specifiers" must only contain identifiers
def bad_dataclass1() -> None: ...
@dataclass_transform(field_specifiers=(some_function(),)) # E: "field_specifiers" must only contain identifiers
def bad_dataclass2() -> None: ...
@dataclass_transform(field_specifiers=CONSTANT) # E: "field_specifiers" argument must be a tuple literal
def bad_dataclass3() -> None: ...
@dataclass_transform(field_specifiers=fields_tuple()) # E: "field_specifiers" argument must be a tuple literal
def bad_dataclass4() -> None: ...

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformFieldSpecifierParams]
# flags: --python-version 3.11
from typing import dataclass_transform, Any, Callable, Type, Final

def field(
*,
init: bool = True,
kw_only: bool = False,
alias: str | None = None,
default: Any | None = None,
default_factory: Callable[[], Any] | None = None,
factory: Callable[[], Any] | None = None,
): ...
@dataclass_transform(field_specifiers=(field,))
def my_dataclass(cls: Type) -> Type:
return cls

B: Final = 'b_'
@my_dataclass
class Foo:
a: int = field(alias='a_')
b: int = field(alias=B)
# cannot be passed as a positional
kwonly: int = field(kw_only=True, default=0)
# Safe to omit from constructor, error to pass
noinit: int = field(init=False, default=1)
# It should be safe to call the constructor without passing any of these
unused1: int = field(default=0)
unused2: int = field(factory=lambda: 0)
unused3: int = field(default_factory=lambda: 0)

Foo(a=5, b_=1) # E: Unexpected keyword argument "a" for "Foo"
Foo(a_=1, b_=1, noinit=1) # E: Unexpected keyword argument "noinit" for "Foo"
Foo(1, 2, 3) # E: Too many positional arguments for "Foo"
foo = Foo(1, 2, kwonly=3)
reveal_type(foo.noinit) # N: Revealed type is "builtins.int"
reveal_type(foo.unused1) # N: Revealed type is "builtins.int"
Foo(a_=5, b_=1, unused1=2, unused2=3, unused3=4)

def some_str() -> str: ...
def some_bool() -> bool: ...
@my_dataclass
class Bad:
bad1: int = field(alias=some_str()) # E: "alias" argument to dataclass field must be a string literal
bad2: int = field(kw_only=some_bool()) # E: "kw_only" argument must be a boolean literal

# this metadata should only exist for dataclasses.dataclass classes
Foo.__dataclass_fields__ # E: "Type[Foo]" has no attribute "__dataclass_fields__"

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformFieldSpecifierExtraArgs]
# flags: --python-version 3.11
from typing import dataclass_transform

def field(extra1, *, kw_only=False, extra2=0): ...
@dataclass_transform(field_specifiers=(field,))
def my_dataclass(cls):
return cls

@my_dataclass
class Good:
a: int = field(5)
b: int = field(5, extra2=1)
c: int = field(5, kw_only=True)

@my_dataclass
class Bad:
a: int = field(kw_only=True) # E: Missing positional argument "extra1" in call to "field"

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformMultipleFieldSpecifiers]
# flags: --python-version 3.11
from typing import dataclass_transform

def field1(*, default: int) -> int: ...
def field2(*, default: str) -> str: ...

@dataclass_transform(field_specifiers=(field1, field2))
def my_dataclass(cls): return cls

@my_dataclass
class Foo:
a: int = field1(default=0)
b: str = field2(default='hello')

reveal_type(Foo) # N: Revealed type is "def (a: builtins.int =, b: builtins.str =) -> __main__.Foo"
Foo()
Foo(a=1, b='bye')

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformOverloadsDecoratorOnOverload]
# flags: --python-version 3.11
from typing import dataclass_transform, overload, Any, Callable, Type, Literal
Expand Down

0 comments on commit 4635a8c

Please sign in to comment.