Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Allow modification of data on [de]serialization #34

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 74 additions & 1 deletion marshmallow_polyfield/polyfield.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from six import raise_from, with_metaclass

from marshmallow import Schema, ValidationError
from marshmallow.fields import Field
from marshmallow.fields import Field, String


class PolyFieldBase(with_metaclass(abc.ABCMeta, Field)):
Expand All @@ -19,6 +19,7 @@ def _deserialize(self, value, attr, parent, **kwargs):
deserializer = None
try:
deserializer = self.deserialization_schema_selector(v, parent)
v = self.deserialization_value_modifier(v, parent)
if isinstance(deserializer, type):
deserializer = deserializer()
if not isinstance(deserializer, (Field, Schema)):
Expand Down Expand Up @@ -70,11 +71,13 @@ def _serialize(self, value, key, obj, **kwargs):
res = []
for v in value:
schema = self.serialization_schema_selector(v, obj)
v = self.serialization_value_modifier(v, obj)
schema.context.update(getattr(self, 'context', {}))
res.append(schema.dump(v))
return res
else:
schema = self.serialization_schema_selector(value, obj)
value = self.serialization_value_modifier(value, obj)
schema.context.update(getattr(self, 'context', {}))
return schema.dump(value)
except Exception as err:
Expand All @@ -92,6 +95,12 @@ def serialization_schema_selector(self, value, obj):
def deserialization_schema_selector(self, value, obj):
raise NotImplementedError

def serialization_value_modifier(self, value, obj):
return value

def deserialization_value_modifier(self, value, obj):
return value


class PolyField(PolyFieldBase):
"""
Expand All @@ -104,6 +113,8 @@ def __init__(
self,
serialization_schema_selector=None,
deserialization_schema_selector=None,
serialization_value_modifier=lambda value, obj: value,
deserialization_value_modifier=lambda value, obj: value,
many=False,
**metadata
):
Expand All @@ -119,9 +130,71 @@ def __init__(
super(PolyField, self).__init__(many=many, **metadata)
self._serialization_schema_selector_arg = serialization_schema_selector
self._deserialization_schema_selector_arg = deserialization_schema_selector
self.serialization_value_modifier = serialization_value_modifier
self.deserialization_value_modifier = deserialization_value_modifier

def serialization_schema_selector(self, value, obj):
return self._serialization_schema_selector_arg(value, obj)

def deserialization_schema_selector(self, value, obj):
return self._deserialization_schema_selector_arg(value, obj)


def create_label_schema(schema):
class LabelSchema(Schema):
type = String()
value = schema()

return LabelSchema


class ExplicitPolyField(PolyFieldBase):
def __init__(
self,
class_to_schema_mapping,
create_label_schema=create_label_schema,
class_to_name_overrides=None,
many=False,
**metadata
):
super(ExplicitPolyField, self).__init__(many=many, **metadata)

if class_to_name_overrides is None:
class_to_name_overrides = {}

self._class_to_schema_mapping = class_to_schema_mapping
self._class_to_name = {
cls: cls.__name__
for cls in self._class_to_schema_mapping.keys()
}

self._class_to_name.update(class_to_name_overrides)

self._name_to_class = {
name: cls
for cls, name in self._class_to_name.items()
}
self.create_label_schema = create_label_schema

def serialization_schema_selector(self, base_object, parent_obj):
cls = type(base_object)
schema = self._class_to_schema_mapping[cls]
label_schema = self.create_label_schema(schema=schema)

return label_schema()

def serialization_value_modifier(self, base_object, parent_obj):
cls = type(base_object)
name = self._class_to_name[cls]

return {u'type': name, u'value': base_object}

def deserialization_schema_selector(self, object_dict, parent_object_dict):
name = object_dict[u'type']
cls = self._name_to_class[name]
schema = self._class_to_schema_mapping[cls]

return schema()

def deserialization_value_modifier(self, object_dict, parent_object_dict):
return object_dict[u'value']
261 changes: 259 additions & 2 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from collections import namedtuple
from marshmallow import fields, Schema
from marshmallow_polyfield.polyfield import PolyField
from marshmallow import decorators, fields, Schema
from marshmallow_polyfield.polyfield import PolyField, ExplicitPolyField
import pytest
from six import text_type
from tests.shapes import (
Rectangle,
Triangle,
Expand Down Expand Up @@ -99,3 +100,259 @@ def test_serializing_polyfield_by_parent_type(field):
rect_dict = field.serialize('shape', marshmallow_sticker)

assert rect_dict == {"length": 4, "width": 10, "color": "blue"}


def test_serializing_with_modification():
import marshmallow

def create_label_schema(schema):
class LabelSchema(marshmallow.Schema):
type = marshmallow.fields.String()
value = schema()

return LabelSchema

class_to_schema = {
text_type: marshmallow.fields.String,
int: marshmallow.fields.Integer,
}

name_to_class = {
u'str': text_type,
u'int': int,
}

class_to_name = {
cls: name
for name, cls in name_to_class.items()
}
class_to_name[text_type] = u'str'

def serialization_schema(base_object, parent_obj):
cls = type(base_object)
schema = class_to_schema[cls]

label_schema = create_label_schema(schema=schema)

return label_schema()

def serialization_value(base_object, parent_obj):
cls = type(base_object)
name = class_to_name[cls]

return {'type': name, 'value': base_object}

def deserialization_schema(object_dict, parent_object_dict):
name = object_dict['type']
cls = name_to_class[name]
schema = class_to_schema[cls]

return schema()

def deserialization_value(object_dict, parent_object_dict):
return object_dict['value']

class TopClass:
def __init__(self, polyfield):
self.polyfield = polyfield

def __eq__(self, other):
if type(self) != type(other):
return False

return self.polyfield == other.polyfield

class TopSchema(marshmallow.Schema):
polyfield = PolyField(
serialization_schema_selector=serialization_schema,
deserialization_schema_selector=deserialization_schema,
serialization_value_modifier=serialization_value,
deserialization_value_modifier=deserialization_value,
)

@marshmallow.decorators.post_load
def make_object(self, data, many=None, partial=None):
return TopClass(**data)

top_schema = TopSchema()

top_class_str_example = TopClass(polyfield=u'abc')
top_class_str_example_dumped = top_schema.dump(top_class_str_example)
print(top_class_str_example_dumped)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rather than printing I think you should validate the json that comes out with an assertion

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just exploratory code and I think it's fully replaced at this point by 'real' tests. It's in the WIP list to be removed.

top_class_str_example_loaded = top_schema.load(top_class_str_example_dumped)
assert top_class_str_example_loaded == top_class_str_example

print('---')
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Personally id rather you pull out what you need so you can make these two different tests rather than one test separated by the output

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like above, to be removed.


top_class_int_example = TopClass(polyfield=42)
top_class_int_example_dumped = top_schema.dump(top_class_int_example)
print(top_class_int_example_dumped)
top_class_int_example_loaded = top_schema.load(top_class_int_example_dumped)
assert top_class_int_example_loaded == top_class_int_example


def test_serializing_with_modification_ExplicitPolyField():
class TopClass:
def __init__(self, polyfield):
self.polyfield = polyfield

def __eq__(self, other):
if type(self) != type(other):
return False

return self.polyfield == other.polyfield

class TopSchema(Schema):
polyfield = ExplicitPolyField(
class_to_schema_mapping={
text_type: fields.String,
int: fields.Integer,
},
class_to_name_overrides={
text_type: u'str',
},
)

@decorators.post_load
def make_object(self, data, many=None, partial=None):
return TopClass(**data)

top_schema = TopSchema()

top_class_str_example = TopClass(polyfield=u'abc')
top_class_str_example_dumped = top_schema.dump(top_class_str_example)
print(top_class_str_example_dumped)
top_class_str_example_loaded = top_schema.load(top_class_str_example_dumped)
assert top_class_str_example_loaded == top_class_str_example

print('---')

top_class_int_example = TopClass(polyfield=42)
top_class_int_example_dumped = top_schema.dump(top_class_int_example)
print(top_class_int_example_dumped)
top_class_int_example_loaded = top_schema.load(top_class_int_example_dumped)
assert top_class_int_example_loaded == top_class_int_example


explicit_poly_field_with_overrides = ExplicitPolyField(
class_to_schema_mapping={
text_type: fields.String,
int: fields.Integer,
dict: fields.Dict,
},
class_to_name_overrides={
text_type: 'str',
},
)


explicit_poly_field_without_overrides = ExplicitPolyField(
class_to_schema_mapping={
int: fields.Integer,
dict: fields.Dict,
},
)


ExplicitPolyFieldExample = namedtuple(
'ExplicitPolyFieldExample',
[
'type_name',
'value',
'layer',
'field',
'polyfield',
],
)


def create_explicit_poly_field_example(type_name, value, field, polyfield):
return ExplicitPolyFieldExample(
type_name=type_name,
value=value,
layer={'type': type_name, 'value': value},
field=field,
polyfield=polyfield,
)


parametrize_explicit_poly_field_type_name_and_value = pytest.mark.parametrize(
['example'],
[
[create_explicit_poly_field_example(
type_name=u'str',
value=u'red',
field=fields.String,
polyfield=explicit_poly_field_with_overrides,
)],
[create_explicit_poly_field_example(
type_name=u'int',
value=42,
field=fields.Integer,
polyfield=explicit_poly_field_with_overrides,
)],
[create_explicit_poly_field_example(
type_name=u'dict',
value={u'puppy': 3.9},
field=fields.Dict,
polyfield=explicit_poly_field_with_overrides,
)],
[create_explicit_poly_field_example(
type_name=u'int',
value=42,
field=fields.Integer,
polyfield=explicit_poly_field_without_overrides,
)],
[create_explicit_poly_field_example(
type_name=u'dict',
value={u'puppy': 3.9},
field=fields.Dict,
polyfield=explicit_poly_field_without_overrides,
)],
],
)


@parametrize_explicit_poly_field_type_name_and_value
def test_serializing_explicit_poly_field(example):
Point = namedtuple('Point', ['x', 'y'])
p = Point(x=example.value, y=37)

assert example.polyfield.serialize('x', p) == example.layer


@parametrize_explicit_poly_field_type_name_and_value
def test_serializing_explicit_poly_field_type_name(example):
Point = namedtuple('Point', ['x', 'y'])
p = Point(x=example.value, y=37)

serialized = example.polyfield.serialize('x', p)
assert serialized['type'] == example.type_name


@parametrize_explicit_poly_field_type_name_and_value
def test_serializing_explicit_poly_field_value(example):
Point = namedtuple('Point', ['x', 'y'])
p = Point(x=example.value, y=37)

serialized = example.polyfield.serialize('x', p)
assert serialized['value'] is example.value


@parametrize_explicit_poly_field_type_name_and_value
def test_deserializing_explicit_poly_field_value(example):
assert example.polyfield.deserialize(example.layer) is example.value


@parametrize_explicit_poly_field_type_name_and_value
def test_deserializing_explicit_poly_field_field_type(example):
# TODO: Checking the type only does so much, really want to compare
# the fields but they don't implement == so we'll have to code
# that up to check it.
assert (
type(example.polyfield.deserialization_schema_selector(
example.layer,
{'x': example.layer},
))
is type(example.field())
) # noqa E721