From e67279fa41a0b2a929062d9c33711e8cffa7d6e8 Mon Sep 17 00:00:00 2001 From: Marcos Schroh <2828842+marcosschroh@users.noreply.github.com> Date: Wed, 9 Oct 2024 14:10:36 +0200 Subject: [PATCH] fix: serialization improvements for nested unions. Related to #763 (#770) --- dataclasses_avroschema/schema_generator.py | 23 +++++ dataclasses_avroschema/serialization.py | 29 ++++-- dataclasses_avroschema/utils.py | 2 +- pyproject.toml | 3 + tests/schemas/test_fastavro_paser_schema.py | 80 ++++++++++++++++- tests/serialization/test_model_utils.py | 2 +- .../test_nested_schema_serialization.py | 89 ++++++++++++------- 7 files changed, 180 insertions(+), 48 deletions(-) diff --git a/dataclasses_avroschema/schema_generator.py b/dataclasses_avroschema/schema_generator.py index 1f9bf7c2..4e9e5f84 100644 --- a/dataclasses_avroschema/schema_generator.py +++ b/dataclasses_avroschema/schema_generator.py @@ -46,6 +46,29 @@ def get_metadata(cls: "Type[CT]") -> SchemaMetadata: cls._metadata = SchemaMetadata.create(meta) return cls._metadata + @classmethod + def get_fullname(cls) -> str: + """ + Fullname is composed of two parts: a name and a namespace + separated by a dot. A namespace is a dot-separated sequence of such names. + The empty string may also be used as a namespace to indicate the null namespace. + Equality of names (including field names and enum symbols) + as well as fullnames is case-sensitive. + """ + # we need to make sure that the schema has been generated + cls.generate_schema() + metadata = cls.get_metadata() + + if metadata.namespace: + # if the current record has a namespace we use it + return f"{metadata.namespace}.{cls.__name__}" + elif cls._parent is not None: + # if the record has a parent then we try to use the parent namespace + parent_metadata = cls._parent.get_metadata() + if parent_metadata.namespace: + return f"{parent_metadata.namespace}.{cls.__name__}" + return cls.__name__ + @classmethod def generate_schema(cls: "Type[CT]", schema_type: str = "avro") -> Optional[OrderedDict]: if cls._parser is None or cls.__mro__[1] != AvroModel: diff --git a/dataclasses_avroschema/serialization.py b/dataclasses_avroschema/serialization.py index a6c4a502..59da3762 100644 --- a/dataclasses_avroschema/serialization.py +++ b/dataclasses_avroschema/serialization.py @@ -68,10 +68,10 @@ def deserialize( input_stream.flush() - return sanitize_unions(data=payload, model=model) # type: ignore + return sanitize_payload(data=payload, model=model) # type: ignore -def sanitize_unions(*, data: JsonDict, model: "CT") -> JsonDict: +def sanitize_payload(*, data: JsonDict, model: "CT") -> JsonDict: """ This function tries to convert cast all the cases that have `unions` with a Tuple format (AvroType, payload), for example @@ -81,20 +81,31 @@ def sanitize_unions(*, data: JsonDict, model: "CT") -> JsonDict: cleaned_data = {} for field_name, field_value in data.items(): if isinstance(field_value, dict): - field_value = sanitize_unions(data=field_value, model=model) + field_value = sanitize_payload(data=field_value, model=model) elif isinstance(field_value, tuple) and len(field_value) == 2: - # the first value is the model/record name and the second - # is its payload - model_name, model_dict_value = field_value - avro_model = model.get_user_defined_type(name=model_name) - if avro_model is not None: - field_value = avro_model.parse_obj(model_dict_value) + field_value = sanitize_union(union=field_value, model=model) cleaned_data[field_name] = field_value return cleaned_data +def sanitize_union(*, union: typing.Tuple, model: "CT") -> typing.Optional["CT"]: + # the first value is the model/record name and the second is its payload + model_name, model_value = union + if isinstance(model_value, dict): + # it can be a dict again so we need to sanitize + model_value = sanitize_payload(data=model_value, model=model) + + model_name = model_name.split(".")[-1] + + avro_model = model.get_user_defined_type(name=model_name) + if avro_model is not None: + return avro_model.parse_obj(model_value) + + return model_value + + def datetime_to_str(value: datetime.datetime) -> str: return value.strftime(DATETIME_STR_FORMAT) diff --git a/dataclasses_avroschema/utils.py b/dataclasses_avroschema/utils.py index 924eec0d..b7b2ef28 100644 --- a/dataclasses_avroschema/utils.py +++ b/dataclasses_avroschema/utils.py @@ -124,7 +124,7 @@ def standardize_custom_type( asdict = value.asdict() if is_union(model.__annotations__[field_name]) and include_type: - asdict["-type"] = value.__class__.__name__ + asdict["-type"] = value.get_fullname() return asdict return value diff --git a/pyproject.toml b/pyproject.toml index 0f04e310..b9fcd83d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,9 @@ cli = [ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" +[tool.mypy] +check_untyped_defs = true + [[tool.mypy.overrides]] module = "stringcase.*" ignore_missing_imports = true diff --git a/tests/schemas/test_fastavro_paser_schema.py b/tests/schemas/test_fastavro_paser_schema.py index b99a3140..1627be6c 100644 --- a/tests/schemas/test_fastavro_paser_schema.py +++ b/tests/schemas/test_fastavro_paser_schema.py @@ -115,12 +115,14 @@ def test_one_to_one_schema(): Test relationship one-to-one """ + @dataclasses.dataclass class Address(AvroModel): "An Address" street: str street_number: int + @dataclasses.dataclass class User(AvroModel): "An User with Address" @@ -136,6 +138,7 @@ def test_one_to_one_repeated_schema(): Test relationship one-to-one with more than once schema """ + @dataclasses.dataclass class Location(AvroModel): latitude: float longitude: float @@ -143,6 +146,7 @@ class Location(AvroModel): class Meta: namespace = "types.location_type" + @dataclasses.dataclass class Trip(AvroModel): start_time: datetime.datetime start_location: Location @@ -155,11 +159,13 @@ class Trip(AvroModel): def test_repeated_schema_without_namespace(): + @dataclasses.dataclass class Bus(AvroModel): "A Bus" engine_name: str + @dataclasses.dataclass class UnionSchema(AvroModel): "Some Unions" @@ -174,6 +180,7 @@ def test_one_to_one_repeated_schema_in_array(): Test relationship one-to-one with more than once schema """ + @dataclasses.dataclass class Location(AvroModel): latitude: float longitude: float @@ -181,6 +188,7 @@ class Location(AvroModel): class Meta: namespace = "types.location_type" + @dataclasses.dataclass class Trip(AvroModel): start_time: datetime.datetime start_location: Location @@ -196,6 +204,7 @@ def test_one_to_one_repeated_schema_in_map(): Test relationship one-to-one with more than once schema """ + @dataclasses.dataclass class Location(AvroModel): latitude: float longitude: float @@ -203,6 +212,7 @@ class Location(AvroModel): class Meta: namespace = "types.location_type" + @dataclasses.dataclass class Trip(AvroModel): start_time: datetime.datetime start_location: Location @@ -214,6 +224,7 @@ class Trip(AvroModel): def test_one_to_many_repeated_schema_in_array_and_map(): + @dataclasses.dataclass class User(AvroModel): name: str @@ -221,6 +232,7 @@ class Meta: schema_doc = False namespace = "types.user" + @dataclasses.dataclass class UserAdvance(AvroModel): users: typing.List[User] accounts: typing.Dict[str, User] @@ -237,12 +249,14 @@ def test_one_to_many_schema(): Test relationship one-to-many """ + @dataclasses.dataclass class Address(AvroModel): "An Address" street: str street_number: int + @dataclasses.dataclass class User(AvroModel): "User with multiple Address" @@ -258,12 +272,14 @@ def test_one_to_many_with_map_schema(): Test relationship one-to-many using a map """ + @dataclasses.dataclass class Address(AvroModel): "An Address" street: str street_number: int + @dataclasses.dataclass class User(AvroModel): "User with multiple Address" @@ -279,6 +295,7 @@ def test_one_to_one_self_relationship(): Test self relationship one-to-one """ + @dataclasses.dataclass class User(AvroModel): "User with self reference as friend" @@ -294,6 +311,7 @@ def test_one_to_many_self_reference_schema(): Test self relationship one-to-many using an array """ + @dataclasses.dataclass class User(AvroModel): "User with self reference as friends" @@ -309,6 +327,7 @@ def test_one_to_many_self_reference_map_schema(): Test self relationship one-to-many using a map """ + @dataclasses.dataclass class User(AvroModel): "User with self reference as friends" @@ -325,6 +344,7 @@ def test_logical_types_schema(): """ a_datetime = datetime.datetime(2019, 10, 12, 17, 57, 42, 179133) + @dataclasses.dataclass class LogicalTypes(AvroModel): "Some logical types" @@ -342,6 +362,7 @@ def test_logical_micro_types_schema(): """ a_datetime = datetime.datetime(2019, 10, 12, 17, 57, 42) + @dataclasses.dataclass class LogicalTypesMicro(AvroModel): "Some logical types" @@ -357,6 +378,7 @@ class LogicalTypesMicro(AvroModel): def test_schema_with_union_types(): + @dataclasses.dataclass class UnionSchema(AvroModel): "Some Unions" @@ -398,24 +420,64 @@ class ArrayUnionSchema(AvroModel): assert parse_schema(ArrayUnionSchema.avro_schema_to_python()) -def test_namespaces(): +def test_schema_fullname(): + class B(AvroModel): + ... + + class Meta: + namespace = "my.namespace" + + class A(AvroModel): ... + + assert A.get_fullname() == "A" + assert B.get_fullname() == "my.namespace.B" + + # check that A is inside the `__named_schemas` + parsed_schema = parse_schema(A.avro_schema_to_python()) + assert A.get_fullname() in parsed_schema["__named_schemas"].keys() + + # check that B is inside the `__named_schemas` + parsed_schema = parse_schema(B.avro_schema_to_python()) + assert B.get_fullname() in parsed_schema["__named_schemas"].keys() + + +def test_schema_fullname_nested_records(): + @dataclasses.dataclass class C(AvroModel): - pass + name: str + age: int + @dataclasses.dataclass class B(AvroModel): c: C class Meta: namespace = "my.namespace" + @dataclasses.dataclass class A(AvroModel): b1: B b2: B - parse_schema(A.avro_schema_to_python()) + assert A.get_fullname() == "A" + assert B.get_fullname() == "my.namespace.B" + assert C.get_fullname() == "C" + + # check that A is inside the `__named_schemas` + parsed_schema = parse_schema(A.avro_schema_to_python()) + assert A.get_fullname() in parsed_schema["__named_schemas"].keys() + + # check that B is inside the `__named_schemas` + parsed_schema = parse_schema(B.avro_schema_to_python()) + assert B.get_fullname() in parsed_schema["__named_schemas"].keys() + + # check that C is inside the `__named_schemas` + # parsed_schema = parse_schema(C.avro_schema_to_python()) + assert C.get_fullname() in parsed_schema["__named_schemas"].keys() def test_use_of_same_type_in_nested_list(): + @dataclasses.dataclass class Address(AvroModel): "An Address" @@ -425,9 +487,11 @@ class Address(AvroModel): class Meta: namespace = "types.test" + @dataclasses.dataclass class PreviousAddresses(AvroModel): addresses: typing.List[Address] + @dataclasses.dataclass class User(AvroModel): "An User with Address and previous addresses" @@ -440,6 +504,7 @@ class User(AvroModel): def test_two_different_child_records(): + @dataclasses.dataclass class Location(AvroModel): lat: float long: float @@ -449,6 +514,7 @@ class Location(AvroModel): class Meta: namespace = "test.namespace" + @dataclasses.dataclass class Photo(AvroModel): filename: str data: bytes @@ -456,12 +522,14 @@ class Photo(AvroModel): height: int geo_tag: typing.Optional[Location] = None + @dataclasses.dataclass class Video(AvroModel): filename: str data: bytes duration: int geo_tag: typing.Optional[Location] = None + @dataclasses.dataclass class HolidayAlbum(AvroModel): album_name: str photos: typing.List[Photo] = dataclasses.field(default_factory=list) @@ -479,13 +547,16 @@ def test_nested_schemas_splitted() -> None: used in a separate way. """ + @dataclasses.dataclass class A(AvroModel): class Meta: namespace = "namespace" + @dataclasses.dataclass class B(AvroModel): a: A + @dataclasses.dataclass class C(AvroModel): b: B a: A @@ -498,6 +569,7 @@ class C(AvroModel): def test_nested_scheamas_splitted_with_intermediates() -> None: + @dataclasses.dataclass class A(AvroModel): class Meta: namespace = "namespace" @@ -505,9 +577,11 @@ class Meta: class B(AvroModel): a: A + @dataclasses.dataclass class C(AvroModel): a: A + @dataclasses.dataclass class D(AvroModel): b: B c: C diff --git a/tests/serialization/test_model_utils.py b/tests/serialization/test_model_utils.py index c58d49cc..eee522f9 100644 --- a/tests/serialization/test_model_utils.py +++ b/tests/serialization/test_model_utils.py @@ -16,7 +16,7 @@ def test_to_dict_to_json(klass, data, avro_binary, avro_json, instance_json, python_dict): instance = klass(**data) - assert instance.to_dict() == python_dict + assert instance.to_dict() == instance.asdict() == python_dict assert instance.to_json() == json.dumps(instance_json) assert instance == klass.parse_obj(data) assert instance == klass.parse_obj(json.loads(json.dumps(instance_json))) diff --git a/tests/serialization/test_nested_schema_serialization.py b/tests/serialization/test_nested_schema_serialization.py index 2610712e..7a9c2f87 100644 --- a/tests/serialization/test_nested_schema_serialization.py +++ b/tests/serialization/test_nested_schema_serialization.py @@ -229,31 +229,68 @@ class C(model_class): assert c.serialize() == b"" +@parametrize_base_model +def test_nested_schemas_splitted_with_intermediates( + model_class: typing.Type[AvroModel], decorator: typing.Callable +) -> None: + @decorator + class A(model_class): + class Meta: + namespace = "namespace" + + @decorator + class B(model_class): + a: A + + @decorator + class C(model_class): + a: A + + @decorator + class D(model_class): + b: B + c: C + + a = A() + b = B(a=a) + c = C(a=a) + d = D(b=b, c=c) + + assert d.serialize() == b"" + assert c.serialize() == b"" + + @parametrize_base_model def test_nested_schemas_splitted_with_unions(model_class: typing.Type[AvroModel], decorator: typing.Callable) -> None: """ This test will cover the cases when nested schemas with Unions that are used in a separate way. """ + if model_class == AvroBaseModelV1: + pytest.skip(reason="Smart Unions are not supported properly in `AvroBaseModelV1` (pydantic v1)") @decorator class S1(model_class): - pass + ... + + class Meta: + namespace = "my_namespace" @decorator class S2(model_class): - pass + age: int = 10 @decorator class A(model_class): s: typing.Union[S1, S2] - # class Meta: - # namespace = "namespace" + @decorator + class D(model_class): + name: str = "" @decorator class B(model_class): - a: A + a: typing.Union[A, D] @decorator class C(model_class): @@ -262,40 +299,24 @@ class C(model_class): b = B(a=A(s=S1())) c = C(b=B(a=A(s=S1())), a=A(s=S1())) + c2 = C(b=B(a=A(s=S1())), a=A(s=S2())) - assert b.serialize() == b"\x00" - assert c.serialize() == b"\x00\x00" + from fastavro import parse_schema + parsed_schema = parse_schema(B.avro_schema_to_python()) + print(parsed_schema["__named_schemas"].keys()) -@parametrize_base_model -def test_nested_schemas_splitted_with_intermediates( - model_class: typing.Type[AvroModel], decorator: typing.Callable -) -> None: - @decorator - class A(model_class): - class Meta: - namespace = "namespace" + ser = b.serialize() + assert ser == b"\x00\x00" + assert B.deserialize(ser) == b - @decorator - class B(model_class): - a: A + ser = c.serialize() + assert ser == b"\x00\x00\x00" + assert C.deserialize(ser) == c - @decorator - class C(model_class): - a: A - - @decorator - class D(model_class): - b: B - c: C - - a = A() - b = B(a=a) - c = C(a=a) - d = D(b=b, c=c) - - assert d.serialize() == b"" - assert c.serialize() == b"" + ser = c2.serialize() + assert ser == b"\x00\x00\x02\x14" + assert C.deserialize(ser) == c2 @parametrize_base_model