Skip to content

Commit

Permalink
Support descriptors in dataclass transform (#15006)
Browse files Browse the repository at this point in the history
Infer `__init__` argument types from the signatures of descriptor
`__set__` methods, if present. We can't (easily) perform full type
inference in a plugin, so we cheat and use a simplified implementation
that should still cover most use cases. Here we assume that `__set__` is
not decorated or overloaded, in particular.

Fixes #14868.
  • Loading branch information
JukkaL committed Apr 5, 2023
1 parent a7a995a commit 7beaec2
Show file tree
Hide file tree
Showing 2 changed files with 269 additions and 6 deletions.
61 changes: 55 additions & 6 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import Final

from mypy import errorcodes, message_registry
from mypy.expandtype import expand_type
from mypy.expandtype import expand_type, expand_type_by_instance
from mypy.nodes import (
ARG_NAMED,
ARG_NAMED_OPT,
Expand All @@ -23,6 +23,7 @@
Context,
DataclassTransformSpec,
Expression,
FuncDef,
IfStmt,
JsonDict,
NameExpr,
Expand Down Expand Up @@ -98,7 +99,7 @@ def __init__(
self.has_default = has_default
self.line = line
self.column = column
self.type = type
self.type = type # Type as __init__ argument
self.info = info
self.kw_only = kw_only
self.is_neither_frozen_nor_nonfrozen = is_neither_frozen_nor_nonfrozen
Expand Down Expand Up @@ -535,9 +536,12 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
elif not isinstance(stmt.rvalue, TempNode):
has_default = True

if not has_default:
# Make all non-default attributes implicit because they are de-facto set
# on self in the generated __init__(), not in the class body.
if not has_default and self._spec is _TRANSFORM_SPEC_FOR_DATACLASSES:
# Make all non-default dataclass attributes implicit because they are de-facto
# set on self in the generated __init__(), not in the class body. On the other
# hand, we don't know how custom dataclass transforms initialize attributes,
# so we don't treat them as implicit. This is required to support descriptors
# (https://github.com/python/mypy/issues/14868).
sym.implicit = True

is_kw_only = kw_only
Expand Down Expand Up @@ -578,6 +582,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
)

current_attr_names.add(lhs.name)
init_type = self._infer_dataclass_attr_init_type(sym, lhs.name, stmt)
found_attrs[lhs.name] = DataclassAttribute(
name=lhs.name,
alias=alias,
Expand All @@ -586,7 +591,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
has_default=has_default,
line=stmt.line,
column=stmt.column,
type=sym.type,
type=init_type,
info=cls.info,
kw_only=is_kw_only,
is_neither_frozen_nor_nonfrozen=_has_direct_dataclass_transform_metaclass(
Expand Down Expand Up @@ -755,6 +760,50 @@ def _get_bool_arg(self, name: str, default: bool) -> bool:
return require_bool_literal_argument(self._api, expression, name, default)
return default

def _infer_dataclass_attr_init_type(
self, sym: SymbolTableNode, name: str, context: Context
) -> Type | None:
"""Infer __init__ argument type for an attribute.
In particular, possibly use the signature of __set__.
"""
default = sym.type
if sym.implicit:
return default
t = get_proper_type(sym.type)

# Perform a simple-minded inference from the signature of __set__, if present.
# We can't use mypy.checkmember here, since this plugin runs before type checking.
# We only support some basic scanerios here, which is hopefully sufficient for
# the vast majority of use cases.
if not isinstance(t, Instance):
return default
setter = t.type.get("__set__")
if setter:
if isinstance(setter.node, FuncDef):
super_info = t.type.get_containing_type_info("__set__")
assert super_info
if setter.type:
setter_type = get_proper_type(
map_type_from_supertype(setter.type, t.type, super_info)
)
else:
return AnyType(TypeOfAny.unannotated)
if isinstance(setter_type, CallableType) and setter_type.arg_kinds == [
ARG_POS,
ARG_POS,
ARG_POS,
]:
return expand_type_by_instance(setter_type.arg_types[2], t)
else:
self._api.fail(
f'Unsupported signature for "__set__" in "{t.type.name}"', context
)
else:
self._api.fail(f'Unsupported "__set__" in "{t.type.name}"', context)

return default


def add_dataclass_tag(info: TypeInfo) -> None:
# The value is ignored, only the existence matters.
Expand Down
214 changes: 214 additions & 0 deletions test-data/unit/check-dataclass-transform.test
Original file line number Diff line number Diff line change
Expand Up @@ -807,3 +807,217 @@ reveal_type(bar.base) # N: Revealed type is "builtins.int"

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformSimpleDescriptor]
# flags: --python-version 3.11

from typing import dataclass_transform, overload, Any

@dataclass_transform()
def my_dataclass(cls): ...

class Desc:
@overload
def __get__(self, instance: None, owner: Any) -> Desc: ...
@overload
def __get__(self, instance: object, owner: Any) -> str: ...
def __get__(self, instance: object | None, owner: Any) -> Desc | str: ...

def __set__(self, instance: Any, value: str) -> None: ...

@my_dataclass
class C:
x: Desc
y: int

C(x='x', y=1)
C(x=1, y=1) # E: Argument "x" to "C" has incompatible type "int"; expected "str"
reveal_type(C(x='x', y=1).x) # N: Revealed type is "builtins.str"
reveal_type(C(x='x', y=1).y) # N: Revealed type is "builtins.int"
reveal_type(C.x) # N: Revealed type is "__main__.Desc"

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformUnannotatedDescriptor]
# flags: --python-version 3.11

from typing import dataclass_transform, overload, Any

@dataclass_transform()
def my_dataclass(cls): ...

class Desc:
@overload
def __get__(self, instance: None, owner: Any) -> Desc: ...
@overload
def __get__(self, instance: object, owner: Any) -> str: ...
def __get__(self, instance: object | None, owner: Any) -> Desc | str: ...

def __set__(*args, **kwargs): ...

@my_dataclass
class C:
x: Desc
y: int

C(x='x', y=1)
C(x=1, y=1)
reveal_type(C(x='x', y=1).x) # N: Revealed type is "builtins.str"
reveal_type(C(x='x', y=1).y) # N: Revealed type is "builtins.int"
reveal_type(C.x) # N: Revealed type is "__main__.Desc"

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformGenericDescriptor]
# flags: --python-version 3.11

from typing import dataclass_transform, overload, Any, TypeVar, Generic

@dataclass_transform()
def my_dataclass(frozen: bool = False): ...

T = TypeVar("T")

class Desc(Generic[T]):
@overload
def __get__(self, instance: None, owner: Any) -> Desc[T]: ...
@overload
def __get__(self, instance: object, owner: Any) -> T: ...
def __get__(self, instance: object | None, owner: Any) -> Desc | T: ...

def __set__(self, instance: Any, value: T) -> None: ...

@my_dataclass()
class C:
x: Desc[str]

C(x='x')
C(x=1) # E: Argument "x" to "C" has incompatible type "int"; expected "str"
reveal_type(C(x='x').x) # N: Revealed type is "builtins.str"
reveal_type(C.x) # N: Revealed type is "__main__.Desc[builtins.str]"

@my_dataclass()
class D(C):
y: Desc[int]

d = D(x='x', y=1)
reveal_type(d.x) # N: Revealed type is "builtins.str"
reveal_type(d.y) # N: Revealed type is "builtins.int"
reveal_type(D.x) # N: Revealed type is "__main__.Desc[builtins.str]"
reveal_type(D.y) # N: Revealed type is "__main__.Desc[builtins.int]"

@my_dataclass(frozen=True)
class F:
x: Desc[str] = Desc()

F(x='x')
F(x=1) # E: Argument "x" to "F" has incompatible type "int"; expected "str"
reveal_type(F(x='x').x) # N: Revealed type is "builtins.str"
reveal_type(F.x) # N: Revealed type is "__main__.Desc[builtins.str]"

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformGenericDescriptorWithInheritance]
# flags: --python-version 3.11

from typing import dataclass_transform, overload, Any, TypeVar, Generic

@dataclass_transform()
def my_dataclass(cls): ...

T = TypeVar("T")

class Desc(Generic[T]):
@overload
def __get__(self, instance: None, owner: Any) -> Desc[T]: ...
@overload
def __get__(self, instance: object, owner: Any) -> T: ...
def __get__(self, instance: object | None, owner: Any) -> Desc | T: ...

def __set__(self, instance: Any, value: T) -> None: ...

class Desc2(Desc[str]):
pass

@my_dataclass
class C:
x: Desc2

C(x='x')
C(x=1) # E: Argument "x" to "C" has incompatible type "int"; expected "str"
reveal_type(C(x='x').x) # N: Revealed type is "builtins.str"
reveal_type(C.x) # N: Revealed type is "__main__.Desc[builtins.str]"

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformDescriptorWithDifferentGetSetTypes]
# flags: --python-version 3.11

from typing import dataclass_transform, overload, Any

@dataclass_transform()
def my_dataclass(cls): ...

class Desc:
@overload
def __get__(self, instance: None, owner: Any) -> int: ...
@overload
def __get__(self, instance: object, owner: Any) -> str: ...
def __get__(self, instance, owner): ...

def __set__(self, instance: Any, value: bytes) -> None: ...

@my_dataclass
class C:
x: Desc

c = C(x=b'x')
C(x=1) # E: Argument "x" to "C" has incompatible type "int"; expected "bytes"
reveal_type(c.x) # N: Revealed type is "builtins.str"
reveal_type(C.x) # N: Revealed type is "builtins.int"
c.x = b'x'
c.x = 1 # E: Incompatible types in assignment (expression has type "int", variable has type "bytes")

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformUnsupportedDescriptors]
# flags: --python-version 3.11

from typing import dataclass_transform, overload, Any

@dataclass_transform()
def my_dataclass(cls): ...

class Desc:
@overload
def __get__(self, instance: None, owner: Any) -> int: ...
@overload
def __get__(self, instance: object, owner: Any) -> str: ...
def __get__(self, instance, owner): ...

def __set__(*args, **kwargs) -> None: ...

class Desc2:
@overload
def __get__(self, instance: None, owner: Any) -> int: ...
@overload
def __get__(self, instance: object, owner: Any) -> str: ...
def __get__(self, instance, owner): ...

@overload
def __set__(self, instance: Any, value: bytes) -> None: ...
@overload
def __set__(self) -> None: ...
def __set__(self, *args, **kawrga) -> None: ...

@my_dataclass
class C:
x: Desc # E: Unsupported signature for "__set__" in "Desc"
y: Desc2 # E: Unsupported "__set__" in "Desc2"
[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

0 comments on commit 7beaec2

Please sign in to comment.