Skip to content

Commit

Permalink
Implement support for returning TypedDict for dataclasses.asdict
Browse files Browse the repository at this point in the history
Relates to python#5152
  • Loading branch information
syastrov committed Mar 26, 2020
1 parent 52c0a63 commit fd1cc92
Show file tree
Hide file tree
Showing 7 changed files with 461 additions and 17 deletions.
15 changes: 14 additions & 1 deletion docs/source/additional_features.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,22 @@ and :pep:`557`.
Caveats/Known Issues
====================

Some functions in the :py:mod:`dataclasses` module, such as :py:func:`~dataclasses.replace` and :py:func:`~dataclasses.asdict`,
Some functions in the :py:mod:`dataclasses` module, such as :py:func:`~dataclasses.replace`,
have imprecise (too permissive) types. This will be fixed in future releases.

Calls to :py:func:`~dataclasses.asdict` will return a ``TypedDict`` based on the original dataclass
definition, transforming it recursively. There are, however, some limitations:

* Subclasses of ``List``, ``Dict``, and ``Tuple`` appearing within dataclasses are transformed into reparameterized
versions of the respective base class, rather than a transformed version of the original subclass.

* Recursion (e.g. dataclasses which reference each other) is not supported and results in an error.

* ``NamedTuples`` appearing within dataclasses are transformed to ``Any``

* A more precise return type cannot be inferred for calls where ``dict_factory`` is set.


Mypy does not yet recognize aliases of :py:func:`dataclasses.dataclass <dataclasses.dataclass>`, and will
probably never recognize dynamically computed decorators. The following examples
do **not** work:
Expand Down
1 change: 1 addition & 0 deletions mypy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ class CheckerPluginInterface:
docstrings in checker.py for more details.
"""

modules = None # type: Dict[str, MypyFile]
msg = None # type: MessageBuilder
options = None # type: Options
path = None # type: str
Expand Down
27 changes: 23 additions & 4 deletions mypy/plugins/common.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from typing import List, Optional, Union
from collections import OrderedDict
from typing import List, Optional, Union, Set

from mypy.nodes import (
ARG_POS, MDEF, Argument, Block, CallExpr, ClassDef, Expression, SYMBOL_FUNCBASE_TYPES,
FuncDef, PassStmt, RefExpr, SymbolTableNode, Var, JsonDict,
)
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface, CheckerPluginInterface
from mypy.semanal import set_callable_name
from mypy.types import (
CallableType, Overloaded, Type, TypeVarDef, deserialize_type, get_proper_type,
)
TypedDictType, Instance, TPDICT_FB_NAMES)
from mypy.typevars import fill_typevars
from mypy.util import get_unique_redefinition_name
from mypy.typeops import try_getting_str_literals # noqa: F401 # Part of public API
Expand Down Expand Up @@ -155,8 +156,26 @@ def add_method_to_class(


def deserialize_and_fixup_type(
data: Union[str, JsonDict], api: SemanticAnalyzerPluginInterface
data: Union[str, JsonDict],
api: Union[SemanticAnalyzerPluginInterface, CheckerPluginInterface]
) -> Type:
typ = deserialize_type(data)
typ.accept(TypeFixer(api.modules, allow_missing=False))
return typ


def get_anonymous_typeddict_type(api: CheckerPluginInterface) -> Instance:
for type_fullname in TPDICT_FB_NAMES:
try:
anonymous_typeddict_type = api.named_generic_type(type_fullname, [])
if anonymous_typeddict_type is not None:
return anonymous_typeddict_type
except KeyError:
continue
raise RuntimeError("No TypedDict fallback type found")


def make_anonymous_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, Type]',
required_keys: Set[str]) -> TypedDictType:
return TypedDictType(fields, required_keys=required_keys,
fallback=get_anonymous_typeddict_type(api))
127 changes: 116 additions & 11 deletions mypy/plugins/dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
"""Plugin that provides support for dataclasses."""

from typing import Dict, List, Set, Tuple, Optional
from collections import OrderedDict
from typing import Dict, List, Set, Tuple, Optional, FrozenSet, Callable, Union

from typing_extensions import Final

from mypy.maptype import map_instance_to_supertype
from mypy.nodes import (
ARG_OPT, ARG_POS, MDEF, Argument, AssignmentStmt, CallExpr,
Context, Expression, JsonDict, NameExpr, RefExpr,
SymbolTableNode, TempNode, TypeInfo, Var, TypeVarExpr, PlaceholderNode
)
from mypy.plugin import ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.plugins.common import (
add_method, _get_decorator_bool_argument, deserialize_and_fixup_type,
ARG_OPT, ARG_POS, MDEF, Argument, AssignmentStmt, CallExpr, Context,
Expression, JsonDict, NameExpr, RefExpr, SymbolTableNode, TempNode,
TypeInfo, Var, TypeVarExpr, PlaceholderNode
)
from mypy.types import Type, Instance, NoneType, TypeVarDef, TypeVarType, get_proper_type
from mypy.plugin import ClassDefContext, FunctionContext, CheckerPluginInterface
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.plugins.common import (add_method, _get_decorator_bool_argument,
make_anonymous_typeddict, deserialize_and_fixup_type)
from mypy.server.trigger import make_wildcard_trigger
from mypy.types import (Instance, NoneType, TypeVarDef, TypeVarType, get_proper_type, Type,
TupleType, UnionType, AnyType, TypeOfAny)

# The set of decorators that generate dataclasses.
dataclass_makers = {
Expand All @@ -24,6 +28,10 @@
SELF_TVAR_NAME = '_DT' # type: Final


def is_type_dataclass(info: TypeInfo) -> bool:
return 'dataclass' in info.metadata


class DataclassAttribute:
def __init__(
self,
Expand Down Expand Up @@ -68,7 +76,8 @@ def serialize(self) -> JsonDict:

@classmethod
def deserialize(
cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface
cls, info: TypeInfo, data: JsonDict,
api: Union[SemanticAnalyzerPluginInterface, CheckerPluginInterface]
) -> 'DataclassAttribute':
data = data.copy()
typ = deserialize_and_fixup_type(data.pop('type'), api)
Expand Down Expand Up @@ -297,7 +306,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
# we'll have unmodified attrs laying around.
all_attrs = attrs.copy()
for info in cls.info.mro[1:-1]:
if 'dataclass' not in info.metadata:
if not is_type_dataclass(info):
continue

super_attrs = []
Expand Down Expand Up @@ -386,3 +395,99 @@ def _collect_field_args(expr: Expression) -> Tuple[bool, Dict[str, Expression]]:
args[name] = arg
return True, args
return False, {}


def asdict_callback(ctx: FunctionContext) -> Type:
positional_arg_types = ctx.arg_types[0]

if positional_arg_types:
if len(ctx.arg_types) == 2:
# We can't infer a more precise for calls where dict_factory is set.
# At least for now, typeshed stubs for asdict don't allow you to pass in `dict` as
# dict_factory, so we can't special-case that.
return ctx.default_return_type
dataclass_instance = positional_arg_types[0]
dataclass_instance = get_proper_type(dataclass_instance)
if isinstance(dataclass_instance, Instance):
info = dataclass_instance.type
if not is_type_dataclass(info):
ctx.api.fail('asdict() should be called on dataclass instances',
dataclass_instance)
return _asdictify(ctx.api, ctx.context, dataclass_instance)
return ctx.default_return_type


def _transform_type_args(*, typ: Instance, transform: Callable[[Instance], Type]) -> List[Type]:
"""For each type arg used in the Instance, call transform function on it if the arg is an
Instance."""
new_args = []
for arg in typ.args:
proper_arg = get_proper_type(arg)
if isinstance(proper_arg, Instance):
new_args.append(transform(proper_arg))
else:
new_args.append(arg)
return new_args


def _asdictify(api: CheckerPluginInterface, context: Context, typ: Type) -> Type:
"""Convert dataclasses into TypedDicts, recursively looking into built-in containers.
It will look for dataclasses inside of tuples, lists, and dicts and convert them to TypedDicts.
"""

def _asdictify_inner(typ: Type, seen_dataclasses: FrozenSet[str]) -> Type:
typ = get_proper_type(typ)
if isinstance(typ, UnionType):
return UnionType([_asdictify_inner(item, seen_dataclasses) for item in typ.items])
if isinstance(typ, Instance):
info = typ.type
if is_type_dataclass(info):
if info.fullname in seen_dataclasses:
api.fail(
"Recursive types are not supported in call to asdict, so falling back to "
"Dict[str, Any]",
context)
# Note: Would be nicer to fallback to default_return_type, but that is Any
# (due to overloads?)
return api.named_generic_type('builtins.dict',
[api.named_generic_type('builtins.str', []),
AnyType(TypeOfAny.implementation_artifact)])
seen_dataclasses |= {info.fullname}
attrs = info.metadata['dataclass']['attributes']
fields = OrderedDict() # type: OrderedDict[str, Type]
for data in attrs:
attr = DataclassAttribute.deserialize(info, data, api)
sym_node = info.names[attr.name]
attr_type = sym_node.type
assert attr_type is not None
fields[attr.name] = _asdictify_inner(attr_type, seen_dataclasses)
return make_anonymous_typeddict(api, fields=fields,
required_keys=set(fields.keys()))
elif info.has_base('builtins.list'):
supertype_instance = map_instance_to_supertype(typ, api.named_generic_type(
'builtins.list', []).type)
new_args = _transform_type_args(
typ=supertype_instance,
transform=lambda arg: _asdictify_inner(arg, seen_dataclasses))
return api.named_generic_type('builtins.list', new_args)
elif info.has_base('builtins.dict'):
supertype_instance = map_instance_to_supertype(typ, api.named_generic_type(
'builtins.dict', []).type)
new_args = _transform_type_args(
typ=supertype_instance,
transform=lambda arg: _asdictify_inner(arg, seen_dataclasses))
return api.named_generic_type('builtins.dict', new_args)
elif isinstance(typ, TupleType):
if typ.partial_fallback.type.is_named_tuple:
# For namedtuples, return Any. To properly support transforming namedtuples,
# we would have to generate a partial_fallback type for the TupleType and add it
# to the symbol table. It's not currently possibl to do this via the
# CheckerPluginInterface. Ideally it would use the same code as
# NamedTupleAnalyzer.build_namedtuple_typeinfo.
return AnyType(TypeOfAny.implementation_artifact)
return TupleType([_asdictify_inner(item, seen_dataclasses) for item in typ.items],
api.named_generic_type('builtins.tuple', []), implicit=typ.implicit)
return typ

return _asdictify_inner(typ, seen_dataclasses=frozenset())
3 changes: 3 additions & 0 deletions mypy/plugins/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,16 @@ class DefaultPlugin(Plugin):
def get_function_hook(self, fullname: str
) -> Optional[Callable[[FunctionContext], Type]]:
from mypy.plugins import ctypes
from mypy.plugins import dataclasses

if fullname == 'contextlib.contextmanager':
return contextmanager_callback
elif fullname == 'builtins.open' and self.python_version[0] == 3:
return open_callback
elif fullname == 'ctypes.Array':
return ctypes.array_constructor_callback
elif fullname == 'dataclasses.asdict':
return dataclasses.asdict_callback
return None

def get_method_signature_hook(self, fullname: str
Expand Down
Loading

0 comments on commit fd1cc92

Please sign in to comment.