From 4d932109ac93fba17daca49e5c6be50f8c63c6ab Mon Sep 17 00:00:00 2001 From: marcosschroh Date: Mon, 2 Dec 2024 21:21:17 +0100 Subject: [PATCH] fix: serialization whit schema inheritance. Closes #800 --- dataclasses_avroschema/utils.py | 24 +++++++++++--- .../test_nested_schema_serialization.py | 21 +++++++++++- .../test_recursive_schema_serialization.py | 33 ++++++++++++++----- 3 files changed, 64 insertions(+), 14 deletions(-) diff --git a/dataclasses_avroschema/utils.py b/dataclasses_avroschema/utils.py index aecd6b69..03e94497 100644 --- a/dataclasses_avroschema/utils.py +++ b/dataclasses_avroschema/utils.py @@ -100,16 +100,24 @@ def standardize_custom_type( ) -> typing.Any: if isinstance(value, dict): return { - k: standardize_custom_type(field_name=field_name, value=v, model=model, base_class=base_class) + k: standardize_custom_type( + field_name=field_name, value=v, model=model, base_class=base_class, include_type=include_type + ) for k, v in value.items() } elif isinstance(value, list): return [ - standardize_custom_type(field_name=field_name, value=v, model=model, base_class=base_class) for v in value + standardize_custom_type( + field_name=field_name, value=v, model=model, base_class=base_class, include_type=include_type + ) + for v in value ] elif isinstance(value, tuple): return tuple( - standardize_custom_type(field_name=field_name, value=v, model=model, base_class=base_class) for v in value + standardize_custom_type( + field_name=field_name, value=v, model=model, base_class=base_class, include_type=include_type + ) + for v in value ) elif isinstance(value, enum.Enum): return value.value @@ -117,11 +125,17 @@ def standardize_custom_type( if is_faust_record(type(value)): # type: ignore[arg-type] # we need to do a trick because we can not overrride asdict from faust.. # once the function interface is introduced we can remove this check - asdict = value.standardize_type() # type: ignore + asdict = value.standardize_type(include_type=False) # type: ignore else: asdict = value.asdict() - if is_union(model.__annotations__[field_name]) and include_type: + annotations = model.__annotations__ + # This is a hack to get the annotations from the parent class + # https://github.com/marcosschroh/dataclasses-avroschema/issues/800 + if model.__class__.mro()[1] != base_class: + annotations.update(typing.get_type_hints(model.__class__)) + + if is_union(annotations[field_name]) and include_type: asdict["-type"] = value.get_fullname() return asdict diff --git a/tests/serialization/test_nested_schema_serialization.py b/tests/serialization/test_nested_schema_serialization.py index d9fd7b19..d17033e7 100644 --- a/tests/serialization/test_nested_schema_serialization.py +++ b/tests/serialization/test_nested_schema_serialization.py @@ -15,7 +15,7 @@ (AvroModel, dataclasses.dataclass), (AvroBaseModel, lambda f: f), (AvroBaseModelV1, lambda f: f), - (AvroRecord, lambda f: f), + (AvroRecord, dataclasses.dataclass), ], ) @@ -360,3 +360,22 @@ class EventManager(model_class): assert event_serialized == b"\x02\x1ehello Event two\x10EventTwo\xac\x02" assert EventManager.deserialize(event_serialized) == event + + +@parametrize_base_model +def test_inheritance(model_class: typing.Type[AvroModel], decorator: typing.Callable) -> None: + @decorator + class A(model_class): + a: int + + @decorator + class Parent(model_class): + p: A + + @decorator + class Child(Parent): + c: int + + child = Child(p=A(a=1), c=1) + ser = child.serialize() + assert Child.deserialize(ser) == child diff --git a/tests/serialization/test_recursive_schema_serialization.py b/tests/serialization/test_recursive_schema_serialization.py index ebc35114..e6e8e389 100644 --- a/tests/serialization/test_recursive_schema_serialization.py +++ b/tests/serialization/test_recursive_schema_serialization.py @@ -6,10 +6,16 @@ from dataclasses_avroschema import AvroModel from dataclasses_avroschema.faust import AvroRecord from dataclasses_avroschema.pydantic import AvroBaseModel +from dataclasses_avroschema.pydantic.v1 import AvroBaseModel as AvroBaseModelV1 parametrize_base_model = pytest.mark.parametrize( "model_class, decorator", - [(AvroModel, dataclasses.dataclass), (AvroBaseModel, lambda f: f), (AvroRecord, lambda f: f)], + [ + (AvroModel, dataclasses.dataclass), + (AvroBaseModel, lambda f: f), + (AvroBaseModelV1, lambda f: f), + (AvroRecord, lambda f: f), + ], ) @@ -66,7 +72,13 @@ class User(model_class): # don't test AvroBaseModel due to typing incompatibilites with Pydantic -@pytest.mark.parametrize("model_class, decorator", [(AvroModel, dataclasses.dataclass)]) +@pytest.mark.parametrize( + "model_class, decorator", + [ + (AvroModel, dataclasses.dataclass), + (AvroRecord, lambda f: f), + ], +) def test_self_one_to_many_relationship(model_class: typing.Type[AvroModel], decorator: typing.Callable): """ Test self relationship one-to-many serialization @@ -78,7 +90,7 @@ class User(model_class): name: str age: int - friends: typing.List[typing.Type["User"]] + friends: typing.Optional[typing.List[typing.Type["User"]]] = None data_friend = {"name": "john", "age": 20, "friends": []} friend = User(**data_friend) @@ -96,7 +108,7 @@ class User(model_class): "friends": [data_friend], } - avro_binary = b"\x08juan(\x02\x08john(\x00\x00" + avro_binary = b"\x08juan(\x02\x02\x08john(\x02\x00\x00" # avro_json = b'{"name": "juan", "age": 20, "friends": [{"name": "john", "age": 20, "friends": []}]}' assert user.serialize() == avro_binary @@ -115,8 +127,13 @@ class User(model_class): assert user.to_dict() == expected -# don't test AvroBaseModel due to typing incompatibilites with Pydantic -@pytest.mark.parametrize("model_class, decorator", [(AvroModel, dataclasses.dataclass)]) +@pytest.mark.parametrize( + "model_class, decorator", + [ + (AvroModel, dataclasses.dataclass), + (AvroRecord, lambda f: f), + ], +) def test_self_one_to_many_map_relationship(model_class: typing.Type[AvroModel], decorator: typing.Callable): """ Test self relationship one-to-many Map serialization @@ -128,7 +145,7 @@ class User(model_class): name: str age: int - friends: typing.Dict[str, typing.Type["User"]] = None + friends: typing.Optional[typing.Dict[str, typing.Type["User"]]] = None data_friend = { "name": "john", @@ -144,7 +161,7 @@ class User(model_class): } user = User(**data_user) - avro_binary = b"\x08juan(\x02\x10a_friend\x08john(\x00\x00" + avro_binary = b"\x08juan(\x02\x02\x10a_friend\x08john(\x02\x00\x00" # avro_json = b'{"name": "juan", "age": 20, "friends": {"a_friend": {"name": "john", "age": 20, "friends": {}}}}' expected = {