Skip to content

Commit

Permalink
fix: serialization improvements for nested unions. Related to #763 (#770
Browse files Browse the repository at this point in the history
)
  • Loading branch information
marcosschroh authored Oct 9, 2024
1 parent 1bf1c06 commit e67279f
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 48 deletions.
23 changes: 23 additions & 0 deletions dataclasses_avroschema/schema_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,29 @@ def get_metadata(cls: "Type[CT]") -> SchemaMetadata:
cls._metadata = SchemaMetadata.create(meta)
return cls._metadata

@classmethod
def get_fullname(cls) -> str:
"""
Fullname is composed of two parts: a name and a namespace
separated by a dot. A namespace is a dot-separated sequence of such names.
The empty string may also be used as a namespace to indicate the null namespace.
Equality of names (including field names and enum symbols)
as well as fullnames is case-sensitive.
"""
# we need to make sure that the schema has been generated
cls.generate_schema()
metadata = cls.get_metadata()

if metadata.namespace:
# if the current record has a namespace we use it
return f"{metadata.namespace}.{cls.__name__}"
elif cls._parent is not None:
# if the record has a parent then we try to use the parent namespace
parent_metadata = cls._parent.get_metadata()
if parent_metadata.namespace:
return f"{parent_metadata.namespace}.{cls.__name__}"
return cls.__name__

@classmethod
def generate_schema(cls: "Type[CT]", schema_type: str = "avro") -> Optional[OrderedDict]:
if cls._parser is None or cls.__mro__[1] != AvroModel:
Expand Down
29 changes: 20 additions & 9 deletions dataclasses_avroschema/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ def deserialize(

input_stream.flush()

return sanitize_unions(data=payload, model=model) # type: ignore
return sanitize_payload(data=payload, model=model) # type: ignore


def sanitize_unions(*, data: JsonDict, model: "CT") -> JsonDict:
def sanitize_payload(*, data: JsonDict, model: "CT") -> JsonDict:
"""
This function tries to convert cast all the cases that have
`unions` with a Tuple format (AvroType, payload), for example
Expand All @@ -81,20 +81,31 @@ def sanitize_unions(*, data: JsonDict, model: "CT") -> JsonDict:
cleaned_data = {}
for field_name, field_value in data.items():
if isinstance(field_value, dict):
field_value = sanitize_unions(data=field_value, model=model)
field_value = sanitize_payload(data=field_value, model=model)
elif isinstance(field_value, tuple) and len(field_value) == 2:
# the first value is the model/record name and the second
# is its payload
model_name, model_dict_value = field_value
avro_model = model.get_user_defined_type(name=model_name)
if avro_model is not None:
field_value = avro_model.parse_obj(model_dict_value)
field_value = sanitize_union(union=field_value, model=model)

cleaned_data[field_name] = field_value

return cleaned_data


def sanitize_union(*, union: typing.Tuple, model: "CT") -> typing.Optional["CT"]:
# the first value is the model/record name and the second is its payload
model_name, model_value = union
if isinstance(model_value, dict):
# it can be a dict again so we need to sanitize
model_value = sanitize_payload(data=model_value, model=model)

model_name = model_name.split(".")[-1]

avro_model = model.get_user_defined_type(name=model_name)
if avro_model is not None:
return avro_model.parse_obj(model_value)

return model_value


def datetime_to_str(value: datetime.datetime) -> str:
return value.strftime(DATETIME_STR_FORMAT)

Expand Down
2 changes: 1 addition & 1 deletion dataclasses_avroschema/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def standardize_custom_type(
asdict = value.asdict()

if is_union(model.__annotations__[field_name]) and include_type:
asdict["-type"] = value.__class__.__name__
asdict["-type"] = value.get_fullname()
return asdict

return value
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ cli = [
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.mypy]
check_untyped_defs = true

[[tool.mypy.overrides]]
module = "stringcase.*"
ignore_missing_imports = true
Expand Down
Loading

0 comments on commit e67279f

Please sign in to comment.