Skip to content

Commit

Permalink
Merge pull request #121 from Fatal1ty/fix-post-deserialize-hook-pass-…
Browse files Browse the repository at this point in the history
…through

Fix generated code for dataclass with hook and pass_through fields
  • Loading branch information
Fatal1ty authored Jun 23, 2023
2 parents abaef00 + 0f69984 commit 769d7fc
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 53 deletions.
72 changes: 21 additions & 51 deletions mashumaro/core/meta/code/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,6 @@ def _add_unpack_method_lines(self, method_name: str) -> None:
f"{type_name(self.cls)}] signature"
)
filtered_fields = []
kwargs_only = post_deserialize is not None
pos_args = []
kw_args = []
missing_kw_only = False
Expand Down Expand Up @@ -375,7 +374,7 @@ def _add_unpack_method_lines(self, method_name: str) -> None:
filtered_fields.append((fname, ftype))
if filtered_fields:
with self.indent("try:"):
if kwargs_only or can_be_kwargs:
if can_be_kwargs:
self.add_line("kwargs = {}")
for fname, ftype in filtered_fields:
self.add_type_modules(ftype)
Expand All @@ -384,11 +383,7 @@ def _add_unpack_method_lines(self, method_name: str) -> None:
if alias is None:
alias = config.aliases.get(fname)
self._unpack_method_set_value(
fname,
ftype,
metadata,
alias=alias,
kwargs_only=kwargs_only,
fname, ftype, metadata, alias=alias
)
with self.indent("except TypeError:"):
with self.indent("if not isinstance(d, dict):"):
Expand All @@ -399,19 +394,18 @@ def _add_unpack_method_lines(self, method_name: str) -> None:
)
with self.indent("else:"):
self.add_line("raise")
else:
self.add_line("kwargs = {}")

args = [f"__{f}" for f in pos_args]
for kw_arg in kw_args:
args.append(f"{kw_arg}=__{kw_arg}")
if can_be_kwargs:
args.append("**kwargs")
cls_inst = f"cls({', '.join(args)})"

if post_deserialize:
self.add_line(
f"return cls.{__POST_DESERIALIZE__}(cls(**kwargs))"
)
self.add_line(f"return cls.{__POST_DESERIALIZE__}({cls_inst})")
else:
args = [f"__{f}" for f in pos_args]
for kw_arg in kw_args:
args.append(f"{kw_arg}=__{kw_arg}")
if can_be_kwargs:
args.append("**kwargs")
self.add_line(f"return cls({', '.join(args)})")
self.add_line(f"return {cls_inst}")

def _add_unpack_method_with_dialect_lines(self, method_name: str) -> None:
if self.decoder is not None:
Expand Down Expand Up @@ -485,7 +479,6 @@ def _unpack_method_set_value(
metadata: typing.Mapping,
*,
alias: typing.Optional[str] = None,
kwargs_only: bool = False,
) -> None:
default = self.get_field_default(fname)
has_default = default is not MISSING
Expand Down Expand Up @@ -531,71 +524,48 @@ def _unpack_method_set_value(
if could_be_none:
with self.indent(f"if {packed_value} is not None:"):
self.__unpack_try_set_value(
fname,
field_type,
unpacked_value,
kwargs_only,
has_default,
fname, field_type, unpacked_value, has_default
)
with self.indent("else:"):
self.__unpack_set_value(
fname, "None", kwargs_only or has_default
)
self.__unpack_set_value(fname, "None", has_default)
else:
self.__unpack_try_set_value(
fname,
field_type,
unpacked_value,
kwargs_only,
has_default,
fname, field_type, unpacked_value, has_default
)
else:
with self.indent(f"if {packed_value} is not MISSING:"):
if could_be_none:
with self.indent(f"if {packed_value} is not None:"):
self.__unpack_try_set_value(
fname,
field_type,
unpacked_value,
kwargs_only,
has_default,
fname, field_type, unpacked_value, has_default
)
if default is not None:
with self.indent("else:"):
self.__unpack_set_value(
fname, "None", kwargs_only or has_default
)
self.__unpack_set_value(fname, "None", has_default)
else:
self.__unpack_try_set_value(
fname,
field_type,
unpacked_value,
kwargs_only,
has_default,
fname, field_type, unpacked_value, has_default
)

def __unpack_try_set_value(
self,
field_name: str,
field_type_name: str,
unpacked_value: str,
kwargs_only: bool,
has_default: bool,
) -> None:
with self.indent("try:"):
self.__unpack_set_value(
field_name, unpacked_value, kwargs_only or has_default
)
self.__unpack_set_value(field_name, unpacked_value, has_default)
with self.indent("except:"):
self.add_line(
f"raise InvalidFieldValue("
f"'{field_name}',{field_type_name},value,cls)"
)

def __unpack_set_value(
self, fname: str, unpacked_value: str, kwargs_only: bool
self, fname: str, unpacked_value: str, in_kwargs: bool
) -> None:
if kwargs_only:
if in_kwargs:
self.add_line(f"kwargs['{fname}'] = {unpacked_value}")
else:
self.add_line(f"__{fname} = {unpacked_value}")
Expand Down
26 changes: 24 additions & 2 deletions tests/test_hooks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, ClassVar, Dict, Optional, no_type_check

import pytest

from mashumaro import DataClassDictMixin
from mashumaro import DataClassDictMixin, field_options, pass_through
from mashumaro.config import ADD_SERIALIZATION_CONTEXT, BaseConfig
from mashumaro.exceptions import BadHookSignature

Expand Down Expand Up @@ -193,6 +193,28 @@ class B(A, DataClassDictMixin):
post_serialize_hook.assert_called_once()


def test_post_deserialize_hook_with_pass_through_field():
@dataclass
class MyClass(DataClassDictMixin):
x: int = field(metadata=field_options(deserialize=pass_through))

@classmethod
def __post_deserialize__(cls, obj):
return obj

assert MyClass.from_dict({"x": 42}) == MyClass(42)


def test_post_deserialize_hook_with_empty_dataclass():
@dataclass
class MyClass(DataClassDictMixin):
@classmethod
def __post_deserialize__(cls, obj):
return obj # pragma no cover

assert MyClass.from_dict({}) == MyClass()


def test_passing_context_into_hook():
foo = FooBarBaz(foo=Foo(1), bar=Bar(baz=2), baz=3)
assert foo.to_dict() == {"foo": {"baz": 1}, "bar": {"baz": 2}, "baz": 3}
Expand Down

0 comments on commit 769d7fc

Please sign in to comment.