Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support descriptors in dataclass transform #15006

Merged
merged 10 commits into from
Apr 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 55 additions & 6 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing_extensions import Final

from mypy import errorcodes, message_registry
from mypy.expandtype import expand_type
from mypy.expandtype import expand_type, expand_type_by_instance
from mypy.nodes import (
ARG_NAMED,
ARG_NAMED_OPT,
Expand All @@ -23,6 +23,7 @@
Context,
DataclassTransformSpec,
Expression,
FuncDef,
IfStmt,
JsonDict,
NameExpr,
Expand Down Expand Up @@ -98,7 +99,7 @@ def __init__(
self.has_default = has_default
self.line = line
self.column = column
self.type = type
self.type = type # Type as __init__ argument
self.info = info
self.kw_only = kw_only
self.is_neither_frozen_nor_nonfrozen = is_neither_frozen_nor_nonfrozen
Expand Down Expand Up @@ -535,9 +536,12 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
elif not isinstance(stmt.rvalue, TempNode):
has_default = True

if not has_default:
# Make all non-default attributes implicit because they are de-facto set
# on self in the generated __init__(), not in the class body.
if not has_default and self._spec is _TRANSFORM_SPEC_FOR_DATACLASSES:
# Make all non-default dataclass attributes implicit because they are de-facto
# set on self in the generated __init__(), not in the class body. On the other
# hand, we don't know how custom dataclass transforms initialize attributes,
# so we don't treat them as implicit. This is required to support descriptors
# (https://github.com/python/mypy/issues/14868).
sym.implicit = True

is_kw_only = kw_only
Expand Down Expand Up @@ -578,6 +582,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
)

current_attr_names.add(lhs.name)
init_type = self._infer_dataclass_attr_init_type(sym, lhs.name, stmt)
found_attrs[lhs.name] = DataclassAttribute(
name=lhs.name,
alias=alias,
Expand All @@ -586,7 +591,7 @@ def collect_attributes(self) -> list[DataclassAttribute] | None:
has_default=has_default,
line=stmt.line,
column=stmt.column,
type=sym.type,
type=init_type,
info=cls.info,
kw_only=is_kw_only,
is_neither_frozen_nor_nonfrozen=_has_direct_dataclass_transform_metaclass(
Expand Down Expand Up @@ -755,6 +760,50 @@ def _get_bool_arg(self, name: str, default: bool) -> bool:
return require_bool_literal_argument(self._api, expression, name, default)
return default

def _infer_dataclass_attr_init_type(
self, sym: SymbolTableNode, name: str, context: Context
) -> Type | None:
"""Infer __init__ argument type for an attribute.

In particular, possibly use the signature of __set__.
"""
default = sym.type
if sym.implicit:
return default
t = get_proper_type(sym.type)

# Perform a simple-minded inference from the signature of __set__, if present.
# We can't use mypy.checkmember here, since this plugin runs before type checking.
# We only support some basic scanerios here, which is hopefully sufficient for
# the vast majority of use cases.
if not isinstance(t, Instance):
return default
setter = t.type.get("__set__")
if setter:
if isinstance(setter.node, FuncDef):
super_info = t.type.get_containing_type_info("__set__")
assert super_info
if setter.type:
setter_type = get_proper_type(
map_type_from_supertype(setter.type, t.type, super_info)
)
else:
return AnyType(TypeOfAny.unannotated)
if isinstance(setter_type, CallableType) and setter_type.arg_kinds == [
ARG_POS,
ARG_POS,
ARG_POS,
]:
return expand_type_by_instance(setter_type.arg_types[2], t)
else:
self._api.fail(
f'Unsupported signature for "__set__" in "{t.type.name}"', context
)
else:
self._api.fail(f'Unsupported "__set__" in "{t.type.name}"', context)

return default


def add_dataclass_tag(info: TypeInfo) -> None:
# The value is ignored, only the existence matters.
Expand Down
214 changes: 214 additions & 0 deletions test-data/unit/check-dataclass-transform.test
Original file line number Diff line number Diff line change
Expand Up @@ -807,3 +807,217 @@ reveal_type(bar.base) # N: Revealed type is "builtins.int"

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformSimpleDescriptor]
# flags: --python-version 3.11

from typing import dataclass_transform, overload, Any

@dataclass_transform()
def my_dataclass(cls): ...

class Desc:
@overload
def __get__(self, instance: None, owner: Any) -> Desc: ...
@overload
def __get__(self, instance: object, owner: Any) -> str: ...
def __get__(self, instance: object | None, owner: Any) -> Desc | str: ...

def __set__(self, instance: Any, value: str) -> None: ...

@my_dataclass
class C:
x: Desc
y: int

C(x='x', y=1)
C(x=1, y=1) # E: Argument "x" to "C" has incompatible type "int"; expected "str"
reveal_type(C(x='x', y=1).x) # N: Revealed type is "builtins.str"
reveal_type(C(x='x', y=1).y) # N: Revealed type is "builtins.int"
reveal_type(C.x) # N: Revealed type is "__main__.Desc"

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformUnannotatedDescriptor]
# flags: --python-version 3.11

from typing import dataclass_transform, overload, Any

@dataclass_transform()
def my_dataclass(cls): ...

class Desc:
@overload
def __get__(self, instance: None, owner: Any) -> Desc: ...
@overload
def __get__(self, instance: object, owner: Any) -> str: ...
def __get__(self, instance: object | None, owner: Any) -> Desc | str: ...

def __set__(*args, **kwargs): ...

@my_dataclass
class C:
x: Desc
y: int

C(x='x', y=1)
C(x=1, y=1)
reveal_type(C(x='x', y=1).x) # N: Revealed type is "builtins.str"
reveal_type(C(x='x', y=1).y) # N: Revealed type is "builtins.int"
reveal_type(C.x) # N: Revealed type is "__main__.Desc"

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformGenericDescriptor]
# flags: --python-version 3.11

from typing import dataclass_transform, overload, Any, TypeVar, Generic

@dataclass_transform()
def my_dataclass(frozen: bool = False): ...

T = TypeVar("T")

class Desc(Generic[T]):
@overload
def __get__(self, instance: None, owner: Any) -> Desc[T]: ...
@overload
def __get__(self, instance: object, owner: Any) -> T: ...
def __get__(self, instance: object | None, owner: Any) -> Desc | T: ...

def __set__(self, instance: Any, value: T) -> None: ...

@my_dataclass()
class C:
x: Desc[str]

C(x='x')
C(x=1) # E: Argument "x" to "C" has incompatible type "int"; expected "str"
reveal_type(C(x='x').x) # N: Revealed type is "builtins.str"
reveal_type(C.x) # N: Revealed type is "__main__.Desc[builtins.str]"

@my_dataclass()
class D(C):
y: Desc[int]

d = D(x='x', y=1)
reveal_type(d.x) # N: Revealed type is "builtins.str"
reveal_type(d.y) # N: Revealed type is "builtins.int"
reveal_type(D.x) # N: Revealed type is "__main__.Desc[builtins.str]"
reveal_type(D.y) # N: Revealed type is "__main__.Desc[builtins.int]"

@my_dataclass(frozen=True)
class F:
x: Desc[str] = Desc()

F(x='x')
F(x=1) # E: Argument "x" to "F" has incompatible type "int"; expected "str"
reveal_type(F(x='x').x) # N: Revealed type is "builtins.str"
reveal_type(F.x) # N: Revealed type is "__main__.Desc[builtins.str]"

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformGenericDescriptorWithInheritance]
# flags: --python-version 3.11

from typing import dataclass_transform, overload, Any, TypeVar, Generic

@dataclass_transform()
def my_dataclass(cls): ...

T = TypeVar("T")

class Desc(Generic[T]):
@overload
def __get__(self, instance: None, owner: Any) -> Desc[T]: ...
@overload
def __get__(self, instance: object, owner: Any) -> T: ...
def __get__(self, instance: object | None, owner: Any) -> Desc | T: ...

def __set__(self, instance: Any, value: T) -> None: ...

class Desc2(Desc[str]):
pass

@my_dataclass
class C:
x: Desc2

C(x='x')
C(x=1) # E: Argument "x" to "C" has incompatible type "int"; expected "str"
reveal_type(C(x='x').x) # N: Revealed type is "builtins.str"
reveal_type(C.x) # N: Revealed type is "__main__.Desc[builtins.str]"

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformDescriptorWithDifferentGetSetTypes]
# flags: --python-version 3.11

from typing import dataclass_transform, overload, Any

@dataclass_transform()
def my_dataclass(cls): ...

class Desc:
@overload
def __get__(self, instance: None, owner: Any) -> int: ...
@overload
def __get__(self, instance: object, owner: Any) -> str: ...
def __get__(self, instance, owner): ...

def __set__(self, instance: Any, value: bytes) -> None: ...

@my_dataclass
class C:
x: Desc

c = C(x=b'x')
C(x=1) # E: Argument "x" to "C" has incompatible type "int"; expected "bytes"
reveal_type(c.x) # N: Revealed type is "builtins.str"
reveal_type(C.x) # N: Revealed type is "builtins.int"
c.x = b'x'
c.x = 1 # E: Incompatible types in assignment (expression has type "int", variable has type "bytes")

[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]

[case testDataclassTransformUnsupportedDescriptors]
# flags: --python-version 3.11

from typing import dataclass_transform, overload, Any

@dataclass_transform()
def my_dataclass(cls): ...

class Desc:
@overload
def __get__(self, instance: None, owner: Any) -> int: ...
@overload
def __get__(self, instance: object, owner: Any) -> str: ...
def __get__(self, instance, owner): ...

def __set__(*args, **kwargs) -> None: ...

class Desc2:
@overload
def __get__(self, instance: None, owner: Any) -> int: ...
@overload
def __get__(self, instance: object, owner: Any) -> str: ...
def __get__(self, instance, owner): ...

@overload
def __set__(self, instance: Any, value: bytes) -> None: ...
@overload
def __set__(self) -> None: ...
def __set__(self, *args, **kawrga) -> None: ...

@my_dataclass
class C:
x: Desc # E: Unsupported signature for "__set__" in "Desc"
y: Desc2 # E: Unsupported "__set__" in "Desc2"
[typing fixtures/typing-full.pyi]
[builtins fixtures/dataclasses.pyi]