Skip to content

Commit

Permalink
fix: serialization whit schema inheritance. Closes #800 (#810)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcosschroh authored Dec 2, 2024
1 parent 5b5fd30 commit 3ffaf94
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 14 deletions.
24 changes: 19 additions & 5 deletions dataclasses_avroschema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,28 +100,42 @@ def standardize_custom_type(
) -> typing.Any:
if isinstance(value, dict):
return {
k: standardize_custom_type(field_name=field_name, value=v, model=model, base_class=base_class)
k: standardize_custom_type(
field_name=field_name, value=v, model=model, base_class=base_class, include_type=include_type
)
for k, v in value.items()
}
elif isinstance(value, list):
return [
standardize_custom_type(field_name=field_name, value=v, model=model, base_class=base_class) for v in value
standardize_custom_type(
field_name=field_name, value=v, model=model, base_class=base_class, include_type=include_type
)
for v in value
]
elif isinstance(value, tuple):
return tuple(
standardize_custom_type(field_name=field_name, value=v, model=model, base_class=base_class) for v in value
standardize_custom_type(
field_name=field_name, value=v, model=model, base_class=base_class, include_type=include_type
)
for v in value
)
elif isinstance(value, enum.Enum):
return value.value
elif isinstance(value, base_class):
if is_faust_record(type(value)): # type: ignore[arg-type]
# we need to do a trick because we can not overrride asdict from faust..
# once the function interface is introduced we can remove this check
asdict = value.standardize_type() # type: ignore
asdict = value.standardize_type(include_type=False) # type: ignore
else:
asdict = value.asdict()

if is_union(model.__annotations__[field_name]) and include_type:
annotations = model.__annotations__
# This is a hack to get the annotations from the parent class
# https://github.com/marcosschroh/dataclasses-avroschema/issues/800
if model.__class__.mro()[1] != base_class:
annotations.update(typing.get_type_hints(model.__class__))

if is_union(annotations[field_name]) and include_type:
asdict["-type"] = value.get_fullname()
return asdict

Expand Down
21 changes: 20 additions & 1 deletion tests/serialization/test_nested_schema_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
(AvroModel, dataclasses.dataclass),
(AvroBaseModel, lambda f: f),
(AvroBaseModelV1, lambda f: f),
(AvroRecord, lambda f: f),
(AvroRecord, dataclasses.dataclass),
],
)

Expand Down Expand Up @@ -360,3 +360,22 @@ class EventManager(model_class):

assert event_serialized == b"\x02\x1ehello Event two\x10EventTwo\xac\x02"
assert EventManager.deserialize(event_serialized) == event


@parametrize_base_model
def test_inheritance(model_class: typing.Type[AvroModel], decorator: typing.Callable) -> None:
@decorator
class A(model_class):
a: int

@decorator
class Parent(model_class):
p: A

@decorator
class Child(Parent):
c: int

child = Child(p=A(a=1), c=1)
ser = child.serialize()
assert Child.deserialize(ser) == child
33 changes: 25 additions & 8 deletions tests/serialization/test_recursive_schema_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,16 @@
from dataclasses_avroschema import AvroModel
from dataclasses_avroschema.faust import AvroRecord
from dataclasses_avroschema.pydantic import AvroBaseModel
from dataclasses_avroschema.pydantic.v1 import AvroBaseModel as AvroBaseModelV1

parametrize_base_model = pytest.mark.parametrize(
"model_class, decorator",
[(AvroModel, dataclasses.dataclass), (AvroBaseModel, lambda f: f), (AvroRecord, lambda f: f)],
[
(AvroModel, dataclasses.dataclass),
(AvroBaseModel, lambda f: f),
(AvroBaseModelV1, lambda f: f),
(AvroRecord, lambda f: f),
],
)


Expand Down Expand Up @@ -66,7 +72,13 @@ class User(model_class):


# don't test AvroBaseModel due to typing incompatibilites with Pydantic
@pytest.mark.parametrize("model_class, decorator", [(AvroModel, dataclasses.dataclass)])
@pytest.mark.parametrize(
"model_class, decorator",
[
(AvroModel, dataclasses.dataclass),
(AvroRecord, lambda f: f),
],
)
def test_self_one_to_many_relationship(model_class: typing.Type[AvroModel], decorator: typing.Callable):
"""
Test self relationship one-to-many serialization
Expand All @@ -78,7 +90,7 @@ class User(model_class):

name: str
age: int
friends: typing.List[typing.Type["User"]]
friends: typing.Optional[typing.List[typing.Type["User"]]] = None

data_friend = {"name": "john", "age": 20, "friends": []}
friend = User(**data_friend)
Expand All @@ -96,7 +108,7 @@ class User(model_class):
"friends": [data_friend],
}

avro_binary = b"\x08juan(\x02\x08john(\x00\x00"
avro_binary = b"\x08juan(\x02\x02\x08john(\x02\x00\x00"
# avro_json = b'{"name": "juan", "age": 20, "friends": [{"name": "john", "age": 20, "friends": []}]}'

assert user.serialize() == avro_binary
Expand All @@ -115,8 +127,13 @@ class User(model_class):
assert user.to_dict() == expected


# don't test AvroBaseModel due to typing incompatibilites with Pydantic
@pytest.mark.parametrize("model_class, decorator", [(AvroModel, dataclasses.dataclass)])
@pytest.mark.parametrize(
"model_class, decorator",
[
(AvroModel, dataclasses.dataclass),
(AvroRecord, lambda f: f),
],
)
def test_self_one_to_many_map_relationship(model_class: typing.Type[AvroModel], decorator: typing.Callable):
"""
Test self relationship one-to-many Map serialization
Expand All @@ -128,7 +145,7 @@ class User(model_class):

name: str
age: int
friends: typing.Dict[str, typing.Type["User"]] = None
friends: typing.Optional[typing.Dict[str, typing.Type["User"]]] = None

data_friend = {
"name": "john",
Expand All @@ -144,7 +161,7 @@ class User(model_class):
}
user = User(**data_user)

avro_binary = b"\x08juan(\x02\x10a_friend\x08john(\x00\x00"
avro_binary = b"\x08juan(\x02\x02\x10a_friend\x08john(\x02\x00\x00"
# avro_json = b'{"name": "juan", "age": 20, "friends": {"a_friend": {"name": "john", "age": 20, "friends": {}}}}'

expected = {
Expand Down

0 comments on commit 3ffaf94

Please sign in to comment.