Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Allow modification of data on [de]serialization #34

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,69 @@ Once setup the schema should act like any other schema. If it does not then plea
data.get('main'),
data.get('others')
)

ExplicitPolyField
-----------------

The ``ExplicitPolyField`` class adds an additional layer to the serialized data to embed a string used to disambiguate the type of the serialized data.
This avoids any uncertainty when faced with similarly serialized classes.
A mapping from each supported class to the schema used to process it must be provided.
By default the serialized type names are taken from ``cls.__name__``, but this can be overridden per class via ``class_to_name_overrides``.

.. code:: python

import json

from marshmallow import Schema, decorators, fields
from marshmallow_polyfield import ExplicitPolyField
from six import text_type


class TopClass:
    """Example container holding a list of polymorphic values."""

    def __init__(self, polyfield_list):
        self.polyfield_list = polyfield_list

    def __eq__(self, other):
        """
        Compare by ``polyfield_list``; only instances of the exact same type
        can be equal.
        """
        # Returning NotImplemented (instead of False) lets Python try the
        # reflected comparison on ``other`` before concluding "not equal".
        if type(self) is not type(other):
            return NotImplemented

        return self.polyfield_list == other.polyfield_list

class TopSchema(Schema):
    """Schema for ``TopClass`` whose list items may be strings or integers."""

    # Each list item is wrapped in an explicit ``type``/``value`` layer; text
    # values are labelled with a custom name instead of the class name.
    polyfield_list = fields.List(
        ExplicitPolyField(
            class_to_schema_mapping={text_type: fields.String, int: fields.Integer},
            class_to_name_overrides={text_type: u'my string name'},
        )
    )

    @decorators.post_load
    def make_object(self, data, many=None, partial=None):
        """Turn the loaded data into a ``TopClass`` instance."""
        loaded = TopClass(**data)
        return loaded

# Round-trip an example object through the schema: dump it, then load it back.
top_schema = TopSchema()

top_class_example = TopClass(polyfield_list=[u'epf', 37])
top_class_example_dumped = top_schema.dump(top_class_example)
top_class_example_loaded = top_schema.load(top_class_example_dumped)

# ``post_load`` rebuilds a ``TopClass``, so the loaded result must compare
# equal to the original; the dumped form is then pretty-printed as JSON.
assert top_class_example_loaded == top_class_example
print(json.dumps(top_class_example_dumped, indent=4))

.. code:: json

{
"polyfield_list": [
{
"type": "my string name",
"value": "epf"
},
{
"type": "int",
"value": 37
}
]
}
8 changes: 6 additions & 2 deletions marshmallow_polyfield/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from marshmallow_polyfield.polyfield import PolyField, PolyFieldBase
from marshmallow_polyfield.polyfield import (
PolyField,
PolyFieldBase,
ExplicitPolyField,
)

__all__ = ['PolyField', 'PolyFieldBase']
__all__ = ['PolyField', 'PolyFieldBase', 'ExplicitPolyField']
116 changes: 115 additions & 1 deletion marshmallow_polyfield/polyfield.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import abc
import itertools


from marshmallow import Schema, ValidationError
from marshmallow.fields import Field
from marshmallow.fields import Field, String


class PolyFieldBase(Field, metaclass=abc.ABCMeta):
Expand All @@ -18,6 +20,7 @@ def _deserialize(self, value, attr, parent, **kwargs):
deserializer = None
try:
deserializer = self.deserialization_schema_selector(v, parent)
v = self.deserialization_value_modifier(v, parent)
if isinstance(deserializer, type):
deserializer = deserializer()
if not isinstance(deserializer, (Field, Schema)):
Expand Down Expand Up @@ -66,11 +69,13 @@ def _serialize(self, value, key, obj, **kwargs):
res = []
for v in value:
schema = self.serialization_schema_selector(v, obj)
v = self.serialization_value_modifier(v, obj)
schema.context.update(getattr(self, 'context', {}))
res.append(schema.dump(v))
return res
else:
schema = self.serialization_schema_selector(value, obj)
value = self.serialization_value_modifier(value, obj)
schema.context.update(getattr(self, 'context', {}))
return schema.dump(value)
except Exception as err:
Expand All @@ -88,6 +93,12 @@ def serialization_schema_selector(self, value, obj):
def deserialization_schema_selector(self, value, obj):
raise NotImplementedError

def serialization_value_modifier(self, value, obj):
    """
    Hook for transforming a value after the serialization schema has been
    selected and before that schema dumps it.

    The default implementation returns ``value`` unchanged.
    """
    return value

def deserialization_value_modifier(self, value, obj):
    """
    Hook for transforming a value after the deserialization schema has been
    selected from it.

    The default implementation returns ``value`` unchanged.
    """
    return value


class PolyField(PolyFieldBase):
"""
Expand All @@ -100,6 +111,8 @@ def __init__(
self,
serialization_schema_selector=None,
deserialization_schema_selector=None,
serialization_value_modifier=lambda value, obj: value,
deserialization_value_modifier=lambda value, obj: value,
many=False,
**metadata
):
Expand All @@ -115,9 +128,110 @@ def __init__(
super().__init__(many=many, **metadata)
self._serialization_schema_selector_arg = serialization_schema_selector
self._deserialization_schema_selector_arg = deserialization_schema_selector
self.serialization_value_modifier = serialization_value_modifier
self.deserialization_value_modifier = deserialization_value_modifier

def serialization_schema_selector(self, value, obj):
    """Delegate to the serialization selector callable stored on the field."""
    selector = self._serialization_schema_selector_arg
    return selector(value, obj)

def deserialization_schema_selector(self, value, obj):
    """Delegate to the deserialization selector callable stored on the field."""
    selector = self._deserialization_schema_selector_arg
    return selector(value, obj)


def create_label_schema(schema):
    """
    Build the schema class for the extra ``{'type': ..., 'value': ...}`` layer.

    :param schema: The field class or ``Schema`` subclass used to process the
        wrapped value.
    :return: A ``Schema`` subclass with a string ``type`` field and a
        ``value`` field that processes the wrapped value with ``schema``.
    """
    # Imported locally so this function does not depend on the module's
    # top-level import list.
    from marshmallow.fields import Nested

    # A Schema instance is not a field, so assigning ``schema()`` directly
    # would leave ``value`` undeclared and drop it from the output. Schema
    # subclasses are therefore wrapped in ``Nested``; field classes are
    # instantiated as before.
    if isinstance(schema, type) and issubclass(schema, Schema):
        value_field = Nested(schema)
    else:
        value_field = schema()

    class LabelSchema(Schema):
        type = String()
        value = value_field

    return LabelSchema


class ExplicitNamesNotUniqueError(Exception):
    """Raised when more than one class maps to the same serialized type name."""


class ExplicitPolyField(PolyFieldBase):
    """
    A polymorphic field that records the type of every value explicitly.

    Like PolyField, but instead of inspecting the data to choose a schema,
    each serialized value is wrapped in an extra layer carrying a type name,
    by default ``{'type': cls.__name__, 'value': <serialized value>}``. That
    name is read back to choose the schema when deserializing.
    """
    def __init__(
        self,
        class_to_schema_mapping,
        create_label_schema=create_label_schema,
        class_to_name_overrides=None,
        many=False,
        **metadata
    ):
        """
        :param class_to_schema_mapping: Classes as keys mapped to the schema
            to be used for each.
        :param create_label_schema: Callable returning a schema used to create
            the extra serialized layer including the type name.
        :param class_to_name_overrides: Classes as keys mapped to the name to
            use as the serialized type name. Default is to use ``cls.__name__``.
        :raises ExplicitNamesNotUniqueError: If two or more classes end up
            with the same serialized type name.
        """
        super().__init__(many=many, **metadata)

        overrides = (
            {} if class_to_name_overrides is None else class_to_name_overrides
        )

        self._class_to_schema_mapping = class_to_schema_mapping

        # Start from the class names, then let explicit overrides win.
        self._class_to_name = dict(
            (cls, cls.__name__)
            for cls in self._class_to_schema_mapping.keys()
        )
        self._class_to_name.update(overrides)

        def name_of(pair):
            return pair[1]

        # Invert the mapping: group classes by their serialized name so that
        # name collisions become visible. groupby needs input sorted by name.
        pairs_in_name_order = sorted(self._class_to_name.items(), key=name_of)
        name_to_classes = {}
        for name, group in itertools.groupby(pairs_in_name_order, key=name_of):
            name_to_classes[name] = sorted(
                [cls for cls, _ in group],
                key=lambda cls: cls.__name__,
            )

        reused_names = {}
        for name, classes in name_to_classes.items():
            if len(classes) > 1:
                reused_names[name] = classes

        if reused_names:
            raise ExplicitNamesNotUniqueError(repr(reused_names))

        # Every name now maps to exactly one class.
        self._name_to_class = dict(
            (name, classes[0])
            for name, classes in name_to_classes.items()
        )

        self.create_label_schema = create_label_schema

    def serialization_schema_selector(self, base_object, parent_obj):
        """
        Return an instance of the label schema wrapping the schema registered
        for the exact type of ``base_object``.
        """
        inner_schema = self._class_to_schema_mapping[type(base_object)]
        return self.create_label_schema(schema=inner_schema)()

    def serialization_value_modifier(self, base_object, parent_obj):
        """Wrap ``base_object`` in the ``type``/``value`` layer to be dumped."""
        type_name = self._class_to_name[type(base_object)]
        return {u'type': type_name, u'value': base_object}

    def deserialization_schema_selector(self, object_dict, parent_object_dict):
        """
        Return an instance of the schema registered for the class named by
        ``object_dict['type']``.
        """
        target_cls = self._name_to_class[object_dict[u'type']]
        return self._class_to_schema_mapping[target_cls]()

    def deserialization_value_modifier(self, object_dict, parent_object_dict):
        """Strip the extra layer, returning only the wrapped ``value``."""
        return object_dict[u'value']
109 changes: 108 additions & 1 deletion tests/polyclasses.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from marshmallow_polyfield import PolyFieldBase
from collections import namedtuple

from marshmallow import Schema, fields
import pytest

from marshmallow_polyfield import PolyFieldBase, ExplicitPolyField

from tests.shapes import (
shape_schema_serialization_disambiguation,
Expand Down Expand Up @@ -33,3 +38,105 @@ def serialization_schema_selector(self, value, obj):

def deserialization_schema_selector(self, value, obj):
return shape_property_schema_deserialization_disambiguation(value, obj)


class BadStringValueModifierSchema(Schema):
    """Schema with a single string field ``a``."""

    a = fields.String()


class BadStringValueModifierPolyField(PolyFieldBase):
    """
    Polyfield whose value modifiers discard the incoming value and substitute
    ``{'a': bad_string_value}`` in both directions, always processed with
    ``BadStringValueModifierSchema``.
    """
    def __init__(self, bad_string_value, many=False, **metadata):
        super().__init__(many=many, **metadata)

        self.bad_string_value = bad_string_value

    def _replacement_value(self):
        # Fresh dict on every call, matching the inline literals it replaces.
        return {'a': self.bad_string_value}

    def serialization_schema_selector(self, value, obj):
        """Always select ``BadStringValueModifierSchema``."""
        return BadStringValueModifierSchema()

    def deserialization_schema_selector(self, value, obj):
        """Always select ``BadStringValueModifierSchema``."""
        return BadStringValueModifierSchema()

    def serialization_value_modifier(self, value, obj):
        """Ignore ``value`` and return the configured replacement."""
        return self._replacement_value()

    def deserialization_value_modifier(self, value, obj):
        """Ignore ``value`` and return the configured replacement."""
        return self._replacement_value()


# Shared fixture: a polyfield supporting str, int and dict values that also
# passes an explicit name override for ``str`` (here the same as the class
# name).
explicit_poly_field_with_overrides = ExplicitPolyField(
class_to_schema_mapping={
str: fields.String,
int: fields.Integer,
dict: fields.Dict,
},
class_to_name_overrides={
str: 'str',
},
)


# Shared fixture: a polyfield supporting int and dict values that relies
# purely on the default names, since no overrides are passed.
explicit_poly_field_without_overrides = ExplicitPolyField(
class_to_schema_mapping={
int: fields.Integer,
dict: fields.Dict,
},
)


# Record describing one parametrized example:
#   type_name - the serialized type name
#   value     - the plain value
#   layer     - the ``{'type': ..., 'value': ...}`` dict built from the two
#   field     - the marshmallow field class associated with the value
#   polyfield - the ExplicitPolyField instance under test
ExplicitPolyFieldExample = namedtuple(
'ExplicitPolyFieldExample',
[
'type_name',
'value',
'layer',
'field',
'polyfield',
],
)


def create_explicit_poly_field_example(type_name, value, field, polyfield):
    """
    Build an ``ExplicitPolyFieldExample``, deriving its ``layer`` entry
    (``{'type': type_name, 'value': value}``) from the given name and value.
    """
    labelled_layer = {'type': type_name, 'value': value}
    example = ExplicitPolyFieldExample(
        polyfield=polyfield,
        field=field,
        layer=labelled_layer,
        value=value,
        type_name=type_name,
    )
    return example


# Reusable pytest decorator that parametrizes a test over a single ``example``
# argument: str, int and dict examples for the polyfield with overrides, then
# int and dict examples for the polyfield without overrides.
parametrize_explicit_poly_field_type_name_and_value = pytest.mark.parametrize(
['example'],
[
[create_explicit_poly_field_example(
type_name=u'str',
value=u'red',
field=fields.String,
polyfield=explicit_poly_field_with_overrides,
)],
[create_explicit_poly_field_example(
type_name=u'int',
value=42,
field=fields.Integer,
polyfield=explicit_poly_field_with_overrides,
)],
[create_explicit_poly_field_example(
type_name=u'dict',
value={u'puppy': 3.9},
field=fields.Dict,
polyfield=explicit_poly_field_with_overrides,
)],
[create_explicit_poly_field_example(
type_name=u'int',
value=42,
field=fields.Integer,
polyfield=explicit_poly_field_without_overrides,
)],
[create_explicit_poly_field_example(
type_name=u'dict',
value={u'puppy': 3.9},
field=fields.Dict,
polyfield=explicit_poly_field_without_overrides,
)],
],
)
Loading