From 47f5dc4709375be69da7a00612de04ab75b3baad Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Mon, 6 Jan 2020 13:55:14 -0500 Subject: [PATCH 01/16] messy first pass at allow modification of the values --- marshmallow_polyfield/polyfield.py | 40 ++++++++++++- tests/polyclasses.py | 12 ++++ tests/test_deserialization.py | 6 ++ tests/test_polyfield_base.py | 6 ++ tests/test_serialization.py | 94 ++++++++++++++++++++++++++++++ 5 files changed, 155 insertions(+), 3 deletions(-) diff --git a/marshmallow_polyfield/polyfield.py b/marshmallow_polyfield/polyfield.py index 2549845..51b3c56 100644 --- a/marshmallow_polyfield/polyfield.py +++ b/marshmallow_polyfield/polyfield.py @@ -9,6 +9,8 @@ class PolyFieldBase(with_metaclass(abc.ABCMeta, Field)): def __init__(self, many=False, **metadata): super(PolyFieldBase, self).__init__(**metadata) self.many = many + self.serializer_modifies = False + self.deserializer_modifies = False def _deserialize(self, value, attr, parent, **kwargs): if not self.many: @@ -18,7 +20,12 @@ def _deserialize(self, value, attr, parent, **kwargs): for v in value: deserializer = None try: - deserializer = self.deserialization_schema_selector(v, parent) + if self.deserializer_modifies: + deserializer, v = ( + self.deserialization_modifier(v, parent) + ) + else: + deserializer = self.deserialization_schema_selector(v, parent) if isinstance(deserializer, type): deserializer = deserializer() if not isinstance(deserializer, (Field, Schema)): @@ -69,12 +76,18 @@ def _serialize(self, value, key, obj, **kwargs): if self.many: res = [] for v in value: - schema = self.serialization_schema_selector(v, obj) + if self.serializer_modifies: + schema, v = self.serialization_modifier(v, obj) + else: + schema = self.serialization_schema_selector(v, obj) schema.context.update(getattr(self, 'context', {})) res.append(schema.dump(v)) return res else: - schema = self.serialization_schema_selector(value, obj) + if self.serializer_modifies: + schema, value = self.serialization_modifier(value, obj) + else: + schema = self.serialization_schema_selector(value, obj) schema.context.update(getattr(self, 'context', {})) return schema.dump(value) except Exception as err: @@ -92,6 +105,14 @@ def serialization_schema_selector(self, value, obj): def deserialization_schema_selector(self, value, obj): raise NotImplementedError + @abc.abstractmethod + def serialization_modifier(self, value, obj): + raise NotImplementedError + + @abc.abstractmethod + def deserialization_modifier(self, value, obj): + raise NotImplementedError + class PolyField(PolyFieldBase): """ @@ -104,6 +125,8 @@ def __init__( self, serialization_schema_selector=None, deserialization_schema_selector=None, + serialization_modifier=None, + deserialization_modifier=None, many=False, **metadata ): @@ -119,9 +142,20 @@ def __init__( super(PolyField, self).__init__(many=many, **metadata) self._serialization_schema_selector_arg = serialization_schema_selector self._deserialization_schema_selector_arg = deserialization_schema_selector + self._serialization_modifier_arg = serialization_modifier + self._deserialization_modifier_arg = deserialization_modifier + # TODO: make above exclusive to each other + self.serializer_modifies = self._serialization_modifier_arg is not None + self.deserializer_modifies = self._deserialization_modifier_arg is not None def serialization_schema_selector(self, value, obj): return self._serialization_schema_selector_arg(value, obj) def deserialization_schema_selector(self, value, obj): return self._deserialization_schema_selector_arg(value, obj) + + def serialization_modifier(self, value, obj): + return self._serialization_modifier_arg(value, obj) + + def deserialization_modifier(self, value, obj): + return self._deserialization_modifier_arg(value, obj) diff --git a/tests/polyclasses.py b/tests/polyclasses.py index 6919e6a..9ea5a9a 100644 --- a/tests/polyclasses.py +++ b/tests/polyclasses.py @@ -26,6 +26,12 @@ def serialization_schema_selector(self, value, obj): def deserialization_schema_selector(self, value, obj): return shape_schema_deserialization_disambiguation(value, obj) + def serialization_modifier(self, value, obj): + return shape_schema_serialization_disambiguation(value, obj), value + + def deserialization_modifier(self, value, obj): + return shape_schema_deserialization_disambiguation(value, obj), value + class ShapePropertyPolyField(PolyFieldBase): def serialization_schema_selector(self, value, obj): @@ -33,3 +39,9 @@ def serialization_schema_selector(self, value, obj): def deserialization_schema_selector(self, value, obj): return shape_property_schema_deserialization_disambiguation(value, obj) + + def serialization_modifier(self, value, obj): + return shape_property_schema_serialization_disambiguation(value, obj) + + def deserialization_modifier(self, value, obj): + return shape_property_schema_deserialization_disambiguation(value, obj) diff --git a/tests/test_deserialization.py b/tests/test_deserialization.py index b119d56..461fd70 100644 --- a/tests/test_deserialization.py +++ b/tests/test_deserialization.py @@ -29,6 +29,12 @@ def serialization_schema_selector(self, value, obj): def deserialization_schema_selector(self, value, obj): return _bad_deserializer_disambiguation(value, obj) + def serialization_modifier(self, value, obj): + return _bad_deserializer_disambiguation(value, obj), value + + def deserialization_modifier(self, value, obj): + return _bad_deserializer_disambiguation(value, obj), value + class TestPolyField(object): diff --git a/tests/test_polyfield_base.py b/tests/test_polyfield_base.py index a2ff228..3b39c46 100644 --- a/tests/test_polyfield_base.py +++ b/tests/test_polyfield_base.py @@ -8,6 +8,12 @@ def serialization_schema_selector(self, value, obj): def deserialization_schema_selector(self, value, obj): super(TrivialExample, self).deserialization_schema_selector(value, obj) + def serialization_modifier(self, value, obj): + super(TrivialExample, self).serialization_schema_selector(value, obj), value + + def deserialization_modifier(self, value, obj): + super(TrivialExample, self).deserialization_schema_selector(value, obj), value + def test_polyfield_base(): te = TrivialExample() diff --git a/tests/test_serialization.py b/tests/test_serialization.py index dd6994f..958d75b 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -99,3 +99,97 @@ def test_serializing_polyfield_by_parent_type(field): rect_dict = field.serialize('shape', marshmallow_sticker) assert rect_dict == {"length": 4, "width": 10, "color": "blue"} + + +def test_serializing_with_modification(): + import marshmallow + + def create_label_schema(schema): + class LabelSchema(marshmallow.Schema): + type = marshmallow.fields.String() + value = schema() + + return LabelSchema + + class_to_schema = { + str: marshmallow.fields.String, + int: marshmallow.fields.Integer, + } + + name_to_class = { + 'str': str, + 'int': int, + } + + class_to_name = { + cls: name + for name, cls in name_to_class.items() + } + + def serialization_disambiguation(base_object, parent_obj): + cls = type(base_object) + schema = class_to_schema[cls] + name = class_to_name[cls] + + label_schema = create_label_schema(schema=schema) + + return label_schema(), {'type': name, 'value': base_object} + + def deserialization_disambiguation(object_dict, parent_object_dict): + name = object_dict['type'] + value = object_dict['value'] + cls = name_to_class[name] + schema = class_to_schema[cls] + + # for key, item in parent_object_dict.items(): + # if item is object_dict: + # break + # else: + # raise Exception('ack') + # + # parent_object_dict[key] = value + # print(parent_object_dict) + + # label_schema = create_label_schema( + # schema=class_to_schema[name_to_class[name]], + # type_name=name, + # instance=object_dict, + # ) + + return schema(), value + + class TopClass: + def __init__(self, polyfield): + self.polyfield = polyfield + + def __eq__(self, other): + if type(self) != type(other): + return False + + return self.polyfield == other.polyfield + + class TopSchema(marshmallow.Schema): + polyfield = PolyField( + serialization_modifier=serialization_disambiguation, + deserialization_modifier=deserialization_disambiguation, + ) + + @marshmallow.decorators.post_load + def make_object(self, data, many, partial): + return TopClass(**data) + + top_schema = TopSchema() + + top_class_str_example = TopClass(polyfield='abc') + top_class_str_example_dumped = top_schema.dump(top_class_str_example) + print(top_class_str_example_dumped) + top_class_str_example_loaded = top_schema.load(top_class_str_example_dumped) + assert top_class_str_example_loaded == top_class_str_example + + print('---') + + top_class_int_example = TopClass(polyfield=42) + top_class_int_example_dumped = top_schema.dump(top_class_int_example) + print(top_class_int_example_dumped) + top_class_int_example_loaded = top_schema.load(top_class_int_example_dumped) + assert top_class_int_example_loaded == top_class_int_example From 326f7cc98c8c4652f246dcde33ec196b541721f8 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Mon, 6 Jan 2020 15:47:04 -0500 Subject: [PATCH 02/16] use separate schema and value functions --- marshmallow_polyfield/polyfield.py | 51 +++++++++++------------------- tests/polyclasses.py | 12 ------- tests/test_deserialization.py | 6 ---- tests/test_polyfield_base.py | 6 ---- tests/test_serialization.py | 38 ++++++++++------------ 5 files changed, 34 insertions(+), 79 deletions(-) diff --git a/marshmallow_polyfield/polyfield.py b/marshmallow_polyfield/polyfield.py index 51b3c56..7b58695 100644 --- a/marshmallow_polyfield/polyfield.py +++ b/marshmallow_polyfield/polyfield.py @@ -9,8 +9,6 @@ class PolyFieldBase(with_metaclass(abc.ABCMeta, Field)): def __init__(self, many=False, **metadata): super(PolyFieldBase, self).__init__(**metadata) self.many = many - self.serializer_modifies = False - self.deserializer_modifies = False def _deserialize(self, value, attr, parent, **kwargs): if not self.many: @@ -20,12 +18,8 @@ def _deserialize(self, value, attr, parent, **kwargs): for v in value: deserializer = None try: - if self.deserializer_modifies: - deserializer, v = ( - self.deserialization_modifier(v, parent) - ) - else: - deserializer = self.deserialization_schema_selector(v, parent) + deserializer = self.deserialization_schema_selector(v, parent) + v = self.deserialization_value_modifier(v, parent) if isinstance(deserializer, type): deserializer = deserializer() if not isinstance(deserializer, (Field, Schema)): @@ -76,18 +70,14 @@ def _serialize(self, value, key, obj, **kwargs): if self.many: res = [] for v in value: - if self.serializer_modifies: - schema, v = self.serialization_modifier(v, obj) - else: - schema = self.serialization_schema_selector(v, obj) + schema = self.serialization_schema_selector(v, obj) + v = self.serialization_value_modifier(v, obj) schema.context.update(getattr(self, 'context', {})) res.append(schema.dump(v)) return res else: - if self.serializer_modifies: - schema, value = self.serialization_modifier(value, obj) - else: - schema = self.serialization_schema_selector(value, obj) + schema = self.serialization_schema_selector(value, obj) + value = self.serialization_value_modifier(value, obj) schema.context.update(getattr(self, 'context', {})) return schema.dump(value) except Exception as err: @@ -105,13 +95,11 @@ def serialization_schema_selector(self, value, obj): def deserialization_schema_selector(self, value, obj): raise NotImplementedError - @abc.abstractmethod - def serialization_modifier(self, value, obj): - raise NotImplementedError + def serialization_value_modifier(self, value, obj): + return value - @abc.abstractmethod - def deserialization_modifier(self, value, obj): - raise NotImplementedError + def deserialization_value_modifier(self, value, obj): + return value class PolyField(PolyFieldBase): @@ -125,8 +113,8 @@ def __init__( self, serialization_schema_selector=None, deserialization_schema_selector=None, - serialization_modifier=None, - deserialization_modifier=None, + serialization_value_modifier=lambda value, obj: value, + deserialization_value_modifier=lambda value, obj: value, many=False, **metadata ): @@ -142,11 +130,8 @@ def __init__( super(PolyField, self).__init__(many=many, **metadata) self._serialization_schema_selector_arg = serialization_schema_selector self._deserialization_schema_selector_arg = deserialization_schema_selector - self._serialization_modifier_arg = serialization_modifier - self._deserialization_modifier_arg = deserialization_modifier - # TODO: make above exclusive to each other - self.serializer_modifies = self._serialization_modifier_arg is not None - self.deserializer_modifies = self._deserialization_modifier_arg is not None + self._serialization_value_modifier_arg = serialization_value_modifier + self._deserialization_value_modifier_arg = deserialization_value_modifier def serialization_schema_selector(self, value, obj): return self._serialization_schema_selector_arg(value, obj) @@ -154,8 +139,8 @@ def serialization_schema_selector(self, value, obj): def deserialization_schema_selector(self, value, obj): return self._deserialization_schema_selector_arg(value, obj) - def serialization_modifier(self, value, obj): - return self._serialization_modifier_arg(value, obj) + def serialization_value_modifier(self, value, obj): + return self._serialization_value_modifier_arg(value, obj) - def deserialization_modifier(self, value, obj): - return self._deserialization_modifier_arg(value, obj) + def deserialization_value_modifier(self, value, obj): + return self._deserialization_value_modifier_arg(value, obj) diff --git a/tests/polyclasses.py b/tests/polyclasses.py index 9ea5a9a..6919e6a 100644 --- a/tests/polyclasses.py +++ b/tests/polyclasses.py @@ -26,12 +26,6 @@ def serialization_schema_selector(self, value, obj): def deserialization_schema_selector(self, value, obj): return shape_schema_deserialization_disambiguation(value, obj) - def serialization_modifier(self, value, obj): - return shape_schema_serialization_disambiguation(value, obj), value - - def deserialization_modifier(self, value, obj): - return shape_schema_deserialization_disambiguation(value, obj), value - class ShapePropertyPolyField(PolyFieldBase): def serialization_schema_selector(self, value, obj): @@ -39,9 +33,3 @@ def serialization_schema_selector(self, value, obj): def deserialization_schema_selector(self, value, obj): return shape_property_schema_deserialization_disambiguation(value, obj) - - def serialization_modifier(self, value, obj): - return shape_property_schema_serialization_disambiguation(value, obj) - - def deserialization_modifier(self, value, obj): - return shape_property_schema_deserialization_disambiguation(value, obj) diff --git a/tests/test_deserialization.py b/tests/test_deserialization.py index 461fd70..b119d56 100644 --- a/tests/test_deserialization.py +++ b/tests/test_deserialization.py @@ -29,12 +29,6 @@ def serialization_schema_selector(self, value, obj): def deserialization_schema_selector(self, value, obj): return _bad_deserializer_disambiguation(value, obj) - def serialization_modifier(self, value, obj): - return _bad_deserializer_disambiguation(value, obj), value - - def deserialization_modifier(self, value, obj): - return _bad_deserializer_disambiguation(value, obj), value - class TestPolyField(object): diff --git a/tests/test_polyfield_base.py b/tests/test_polyfield_base.py index 3b39c46..a2ff228 100644 --- a/tests/test_polyfield_base.py +++ b/tests/test_polyfield_base.py @@ -8,12 +8,6 @@ def serialization_schema_selector(self, value, obj): def deserialization_schema_selector(self, value, obj): super(TrivialExample, self).deserialization_schema_selector(value, obj) - def serialization_modifier(self, value, obj): - super(TrivialExample, self).serialization_schema_selector(value, obj), value - - def deserialization_modifier(self, value, obj): - super(TrivialExample, self).deserialization_schema_selector(value, obj), value - def test_polyfield_base(): te = TrivialExample() diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 958d75b..ae40c1c 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -126,37 +126,29 @@ class LabelSchema(marshmallow.Schema): for name, cls in name_to_class.items() } - def serialization_disambiguation(base_object, parent_obj): + def serialization_schema(base_object, parent_obj): cls = type(base_object) schema = class_to_schema[cls] - name = class_to_name[cls] label_schema = create_label_schema(schema=schema) - return label_schema(), {'type': name, 'value': base_object} + return label_schema() + + def serialization_value(base_object, parent_obj): + cls = type(base_object) + name = class_to_name[cls] + + return {'type': name, 'value': base_object} - def deserialization_disambiguation(object_dict, parent_object_dict): + def deserialization_schema(object_dict, parent_object_dict): name = object_dict['type'] - value = object_dict['value'] cls = name_to_class[name] schema = class_to_schema[cls] - # for key, item in parent_object_dict.items(): - # if item is object_dict: - # break - # else: - # raise Exception('ack') - # - # parent_object_dict[key] = value - # print(parent_object_dict) - - # label_schema = create_label_schema( - # schema=class_to_schema[name_to_class[name]], - # type_name=name, - # instance=object_dict, - # ) + return schema() - return schema(), value + def deserialization_value(object_dict, parent_object_dict): + return object_dict['value'] class TopClass: def __init__(self, polyfield): @@ -170,8 +162,10 @@ def __eq__(self, other): class TopSchema(marshmallow.Schema): polyfield = PolyField( - serialization_modifier=serialization_disambiguation, - deserialization_modifier=deserialization_disambiguation, + serialization_schema_selector=serialization_schema, + deserialization_schema_selector=deserialization_schema, + serialization_value_modifier=serialization_value, + deserialization_value_modifier=deserialization_value, ) @marshmallow.decorators.post_load From 986bd43ba18ab564f542bddb8235b1c30895b3d2 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Mon, 6 Jan 2020 22:54:56 -0500 Subject: [PATCH 03/16] add ExplicitPolyField --- marshmallow_polyfield/polyfield.py | 60 ++++++++++++++++++++++++++---- tests/test_serialization.py | 44 +++++++++++++++++++++- 2 files changed, 95 insertions(+), 9 deletions(-) diff --git a/marshmallow_polyfield/polyfield.py b/marshmallow_polyfield/polyfield.py index 7b58695..231f2f3 100644 --- a/marshmallow_polyfield/polyfield.py +++ b/marshmallow_polyfield/polyfield.py @@ -2,7 +2,7 @@ from six import raise_from, with_metaclass from marshmallow import Schema, ValidationError -from marshmallow.fields import Field +from marshmallow.fields import Field, String class PolyFieldBase(with_metaclass(abc.ABCMeta, Field)): @@ -130,8 +130,8 @@ def __init__( super(PolyField, self).__init__(many=many, **metadata) self._serialization_schema_selector_arg = serialization_schema_selector self._deserialization_schema_selector_arg = deserialization_schema_selector - self._serialization_value_modifier_arg = serialization_value_modifier - self._deserialization_value_modifier_arg = deserialization_value_modifier + self.serialization_value_modifier = serialization_value_modifier + self.deserialization_value_modifier = deserialization_value_modifier def serialization_schema_selector(self, value, obj): return self._serialization_schema_selector_arg(value, obj) @@ -139,8 +139,54 @@ def serialization_schema_selector(self, value, obj): def deserialization_schema_selector(self, value, obj): return self._deserialization_schema_selector_arg(value, obj) - def serialization_value_modifier(self, value, obj): - return self._serialization_value_modifier_arg(value, obj) - def deserialization_value_modifier(self, value, obj): - return self._deserialization_value_modifier_arg(value, obj) +def create_label_schema(schema): + class LabelSchema(Schema): + type = String() + value = schema() + + return LabelSchema + + +class ExplicitPolyField(PolyFieldBase): + def __init__( + self, + class_to_schema_mapping, + create_label_schema=create_label_schema, + many=False, + **metadata + ): + super(ExplicitPolyField, self).__init__(many=many, **metadata) + self._class_to_schema_mapping = class_to_schema_mapping + self._class_to_name = { + cls: cls.__name__ + for cls in self._class_to_schema_mapping.keys() + } + self._name_to_class = { + name: cls + for cls, name in self._class_to_name.items() + } + self.create_label_schema = create_label_schema + + def serialization_schema_selector(self, base_object, parent_obj): + cls = type(base_object) + schema = self._class_to_schema_mapping[cls] + label_schema = self.create_label_schema(schema=schema) + + return label_schema() + + def serialization_value_modifier(self, base_object, parent_obj): + cls = type(base_object) + name = self._class_to_name[cls] + + return {'type': name, 'value': base_object} + + def deserialization_schema_selector(self, object_dict, parent_object_dict): + name = object_dict['type'] + cls = self._name_to_class[name] + schema = self._class_to_schema_mapping[cls] + + return schema() + + def deserialization_value_modifier(self, object_dict, parent_object_dict): + return object_dict['value'] diff --git a/tests/test_serialization.py b/tests/test_serialization.py index ae40c1c..d6e93ee 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1,6 +1,6 @@ from collections import namedtuple -from marshmallow import fields, Schema -from marshmallow_polyfield.polyfield import PolyField +from marshmallow import decorators, fields, Schema +from marshmallow_polyfield.polyfield import PolyField, ExplicitPolyField import pytest from tests.shapes import ( Rectangle, @@ -187,3 +187,43 @@ def make_object(self, data, many, partial): print(top_class_int_example_dumped) top_class_int_example_loaded = top_schema.load(top_class_int_example_dumped) assert top_class_int_example_loaded == top_class_int_example + + +def test_serializing_with_modification_ExplicitPolyField(): + class TopClass: + def __init__(self, polyfield): + self.polyfield = polyfield + + def __eq__(self, other): + if type(self) != type(other): + return False + + return self.polyfield == other.polyfield + + class TopSchema(Schema): + polyfield = ExplicitPolyField( + class_to_schema_mapping={ + str: fields.String, + int: fields.Integer, + }, + ) + + @decorators.post_load + def make_object(self, data, many, partial): + return TopClass(**data) + + top_schema = TopSchema() + + top_class_str_example = TopClass(polyfield='abc') + top_class_str_example_dumped = top_schema.dump(top_class_str_example) + print(top_class_str_example_dumped) + top_class_str_example_loaded = top_schema.load(top_class_str_example_dumped) + assert top_class_str_example_loaded == top_class_str_example + + print('---') + + top_class_int_example = TopClass(polyfield=42) + top_class_int_example_dumped = top_schema.dump(top_class_int_example) + print(top_class_int_example_dumped) + top_class_int_example_loaded = top_schema.load(top_class_int_example_dumped) + assert top_class_int_example_loaded == top_class_int_example From a4fe15f08d514c31ff779a069541b090d6172ae6 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 7 Jan 2020 00:33:33 -0500 Subject: [PATCH 04/16] add some 'real' tests --- tests/test_serialization.py | 97 +++++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index d6e93ee..336407d 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -227,3 +227,100 @@ def make_object(self, data, many, partial): print(top_class_int_example_dumped) top_class_int_example_loaded = top_schema.load(top_class_int_example_dumped) assert top_class_int_example_loaded == top_class_int_example + + +explicit_poly_field = ExplicitPolyField( + class_to_schema_mapping={ + str: fields.String, + int: fields.Integer, + dict: fields.Dict, + }, +) + + +ExplicitPolyFieldExample = namedtuple( + 'ExplicitPolyFieldExample', + [ + 'type_name', + 'value', + 'layer', + 'field', + ], +) + + +def create_explicit_poly_field_example(type_name, value, field): + return ExplicitPolyFieldExample( + type_name=type_name, + value=value, + layer={'type': type_name, 'value': value}, + field=field, + ) + + +parametrize_explicit_poly_field_type_name_and_value = pytest.mark.parametrize( + ['example'], + [ + [create_explicit_poly_field_example( + type_name='str', + value='red', + field=fields.String, + )], + [create_explicit_poly_field_example( + type_name='int', + value=42, + field=fields.Integer, + )], + [create_explicit_poly_field_example( + type_name='dict', + value={'puppy': 3.9}, + field=fields.Dict, + )], + ], +) + + +@parametrize_explicit_poly_field_type_name_and_value +def test_serializing_explicit_poly_field(example): + Point = namedtuple('Point', ['x', 'y']) + p = Point(x=example.value, y=37) + + assert explicit_poly_field.serialize('x', p) == example.layer + + +@parametrize_explicit_poly_field_type_name_and_value +def test_serializing_explicit_poly_field_type_name(example): + Point = namedtuple('Point', ['x', 'y']) + p = Point(x=example.value, y=37) + + serialized = explicit_poly_field.serialize('x', p) + assert serialized['type'] == example.type_name + + +@parametrize_explicit_poly_field_type_name_and_value +def test_serializing_explicit_poly_field_type_name(example): + Point = namedtuple('Point', ['x', 'y']) + p = Point(x=example.value, y=37) + + serialized = explicit_poly_field.serialize('x', p) + assert serialized['value'] is example.value + + +@parametrize_explicit_poly_field_type_name_and_value +def test_deserializing_explicit_poly_field_value(example): + assert explicit_poly_field.deserialize(example.layer) is example.value + + +@parametrize_explicit_poly_field_type_name_and_value +def test_deserializing_explicit_poly_field_field_type(example): + # TODO: Checking the type only does so much, really want to compare + # the fields but they don't implement == so we'll have to code + # that up to check it. + assert ( + type(explicit_poly_field.deserialization_schema_selector( + example.layer, + {'x': example.layer}, + )) + is type(example.field()) + ) + From b0ab60d9ff144ab7ee7060d04c874829bdcb19af Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 7 Jan 2020 00:36:08 -0500 Subject: [PATCH 05/16] fix reused test name --- tests/test_serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 336407d..1928416 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -298,7 +298,7 @@ def test_serializing_explicit_poly_field_type_name(example): @parametrize_explicit_poly_field_type_name_and_value -def test_serializing_explicit_poly_field_type_name(example): +def test_serializing_explicit_poly_field_value(example): Point = namedtuple('Point', ['x', 'y']) p = Point(x=example.value, y=37) From 28994156537f9557c05ecc057fa4e4be836f8d56 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 7 Jan 2020 01:02:42 -0500 Subject: [PATCH 06/16] py2 fixups --- marshmallow_polyfield/polyfield.py | 14 ++++++++--- tests/test_serialization.py | 39 ++++++++++++++++++------------ 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/marshmallow_polyfield/polyfield.py b/marshmallow_polyfield/polyfield.py index 231f2f3..a2cecc2 100644 --- a/marshmallow_polyfield/polyfield.py +++ b/marshmallow_polyfield/polyfield.py @@ -153,15 +153,23 @@ def __init__( self, class_to_schema_mapping, create_label_schema=create_label_schema, + class_to_name_overrides=None, many=False, **metadata ): super(ExplicitPolyField, self).__init__(many=many, **metadata) + + if class_to_name_overrides is None: + class_to_name_overrides = {} + self._class_to_schema_mapping = class_to_schema_mapping self._class_to_name = { cls: cls.__name__ for cls in self._class_to_schema_mapping.keys() } + + self._class_to_name.update(class_to_name_overrides) + self._name_to_class = { name: cls for cls, name in self._class_to_name.items() @@ -179,14 +187,14 @@ def serialization_value_modifier(self, base_object, parent_obj): cls = type(base_object) name = self._class_to_name[cls] - return {'type': name, 'value': base_object} + return {u'type': name, u'value': base_object} def deserialization_schema_selector(self, object_dict, parent_object_dict): - name = object_dict['type'] + name = object_dict[u'type'] cls = self._name_to_class[name] schema = self._class_to_schema_mapping[cls] return schema() def deserialization_value_modifier(self, object_dict, parent_object_dict): - return object_dict['value'] + return object_dict[u'value'] diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 1928416..94cc8f2 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -2,6 +2,7 @@ from marshmallow import decorators, fields, Schema from marshmallow_polyfield.polyfield import PolyField, ExplicitPolyField import pytest +from six import text_type from tests.shapes import ( Rectangle, Triangle, @@ -112,19 +113,20 @@ class LabelSchema(marshmallow.Schema): return LabelSchema class_to_schema = { - str: marshmallow.fields.String, + text_type: marshmallow.fields.String, int: marshmallow.fields.Integer, } name_to_class = { - 'str': str, - 'int': int, + u'str': text_type, + u'int': int, } class_to_name = { cls: name for name, cls in name_to_class.items() } + class_to_name[text_type] = u'str' def serialization_schema(base_object, parent_obj): cls = type(base_object) @@ -169,12 +171,12 @@ class TopSchema(marshmallow.Schema): ) @marshmallow.decorators.post_load - def make_object(self, data, many, partial): + def make_object(self, data, many=None, partial=None): return TopClass(**data) top_schema = TopSchema() - top_class_str_example = TopClass(polyfield='abc') + top_class_str_example = TopClass(polyfield=u'abc') top_class_str_example_dumped = top_schema.dump(top_class_str_example) print(top_class_str_example_dumped) top_class_str_example_loaded = top_schema.load(top_class_str_example_dumped) @@ -203,18 +205,21 @@ def __eq__(self, other): class TopSchema(Schema): polyfield = ExplicitPolyField( class_to_schema_mapping={ - str: fields.String, + text_type: fields.String, int: fields.Integer, }, + class_to_name_overrides={ + text_type: u'str', + }, ) @decorators.post_load - def make_object(self, data, many, partial): + def make_object(self, data, many=None, partial=None): return TopClass(**data) top_schema = TopSchema() - top_class_str_example = TopClass(polyfield='abc') + top_class_str_example = TopClass(polyfield=u'abc') top_class_str_example_dumped = top_schema.dump(top_class_str_example) print(top_class_str_example_dumped) top_class_str_example_loaded = top_schema.load(top_class_str_example_dumped) @@ -231,10 +236,13 @@ def make_object(self, data, many, partial): explicit_poly_field = ExplicitPolyField( class_to_schema_mapping={ - str: fields.String, + text_type: fields.String, int: fields.Integer, dict: fields.Dict, }, + class_to_name_overrides={ + text_type: 'str', + }, ) @@ -262,18 +270,18 @@ def create_explicit_poly_field_example(type_name, value, field): ['example'], [ [create_explicit_poly_field_example( - type_name='str', - value='red', + type_name=u'str', + value=u'red', field=fields.String, )], [create_explicit_poly_field_example( - type_name='int', + type_name=u'int', value=42, field=fields.Integer, )], [create_explicit_poly_field_example( - type_name='dict', - value={'puppy': 3.9}, + type_name=u'dict', + value={u'puppy': 3.9}, field=fields.Dict, )], ], @@ -322,5 +330,4 @@ def test_deserializing_explicit_poly_field_field_type(example): {'x': example.layer}, )) is type(example.field()) - ) - + ) # noqa E721 From 70e57766591e5f2f529256aec10bae9890d7f70c Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Tue, 7 Jan 2020 01:31:05 -0500 Subject: [PATCH 07/16] add coverage of ExplicitPolyField without name overrides --- tests/test_serialization.py | 39 ++++++++++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 94cc8f2..a672a0a 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -234,7 +234,7 @@ def make_object(self, data, many=None, partial=None): assert top_class_int_example_loaded == top_class_int_example -explicit_poly_field = ExplicitPolyField( +explicit_poly_field_with_overrides = ExplicitPolyField( class_to_schema_mapping={ text_type: fields.String, int: fields.Integer, @@ -246,6 +246,14 @@ def make_object(self, data, many=None, partial=None): ) +explicit_poly_field_without_overrides = ExplicitPolyField( + class_to_schema_mapping={ + int: fields.Integer, + dict: fields.Dict, + }, +) + + ExplicitPolyFieldExample = namedtuple( 'ExplicitPolyFieldExample', [ @@ -253,16 +261,18 @@ def make_object(self, data, many=None, partial=None): 'value', 'layer', 'field', + 'polyfield', ], ) -def create_explicit_poly_field_example(type_name, value, field): +def create_explicit_poly_field_example(type_name, value, field, polyfield): return ExplicitPolyFieldExample( type_name=type_name, value=value, layer={'type': type_name, 'value': value}, field=field, + polyfield=polyfield, ) @@ -273,16 +283,31 @@ def create_explicit_poly_field_example(type_name, value, field): type_name=u'str', value=u'red', field=fields.String, + polyfield=explicit_poly_field_with_overrides, + )], + [create_explicit_poly_field_example( + type_name=u'int', + value=42, + field=fields.Integer, + polyfield=explicit_poly_field_with_overrides, + )], + [create_explicit_poly_field_example( + type_name=u'dict', + value={u'puppy': 3.9}, + field=fields.Dict, + polyfield=explicit_poly_field_with_overrides, )], [create_explicit_poly_field_example( type_name=u'int', value=42, field=fields.Integer, + polyfield=explicit_poly_field_without_overrides, )], [create_explicit_poly_field_example( type_name=u'dict', value={u'puppy': 3.9}, field=fields.Dict, + polyfield=explicit_poly_field_without_overrides, )], ], ) @@ -293,7 +318,7 @@ def test_serializing_explicit_poly_field(example): Point = namedtuple('Point', ['x', 'y']) p = Point(x=example.value, y=37) - assert explicit_poly_field.serialize('x', p) == example.layer + assert example.polyfield.serialize('x', p) == example.layer @parametrize_explicit_poly_field_type_name_and_value @@ -301,7 +326,7 @@ def test_serializing_explicit_poly_field_type_name(example): Point = namedtuple('Point', ['x', 'y']) p = Point(x=example.value, y=37) - serialized = explicit_poly_field.serialize('x', p) + serialized = example.polyfield.serialize('x', p) assert serialized['type'] == example.type_name @@ -310,13 +335,13 @@ def test_serializing_explicit_poly_field_value(example): Point = namedtuple('Point', ['x', 'y']) p = Point(x=example.value, y=37) - serialized = explicit_poly_field.serialize('x', p) + serialized = example.polyfield.serialize('x', p) assert serialized['value'] is example.value @parametrize_explicit_poly_field_type_name_and_value def test_deserializing_explicit_poly_field_value(example): - assert explicit_poly_field.deserialize(example.layer) is example.value + assert example.polyfield.deserialize(example.layer) is example.value @parametrize_explicit_poly_field_type_name_and_value @@ -325,7 +350,7 @@ def test_deserializing_explicit_poly_field_field_type(example): # the fields but they don't implement == so we'll have to code # that up to check it. assert ( - type(explicit_poly_field.deserialization_schema_selector( + type(example.polyfield.deserialization_schema_selector( example.layer, {'x': example.layer}, )) From 2c7d9164fb320c18d9d6496b87a8b9253d14a4ff Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Wed, 15 Jan 2020 21:32:52 -0500 Subject: [PATCH 08/16] add 'direct' tests of new [de]serialization_value_modifier() methods --- tests/polyclasses.py | 27 +++++++++++++++++++++++++++ tests/test_serialization.py | 25 ++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/tests/polyclasses.py b/tests/polyclasses.py index 6919e6a..16afac4 100644 --- a/tests/polyclasses.py +++ b/tests/polyclasses.py @@ -1,4 +1,5 @@ from marshmallow_polyfield import PolyFieldBase +from marshmallow import Schema, fields from tests.shapes import ( shape_schema_serialization_disambiguation, @@ -33,3 +34,29 @@ def serialization_schema_selector(self, value, obj): def deserialization_schema_selector(self, value, obj): return shape_property_schema_deserialization_disambiguation(value, obj) + + +class BadStringValueModifierSchema(Schema): + a = fields.String() + + +class BadStringValueModifierPolyField(PolyFieldBase): + def __init__(self, bad_string_value, many=False, **metadata): + super(BadStringValueModifierPolyField, self).__init__( + many=many, + **metadata + ) + + self.bad_string_value = bad_string_value + + def serialization_schema_selector(self, value, obj): + return BadStringValueModifierSchema() + + def deserialization_schema_selector(self, value, obj): + return BadStringValueModifierSchema() + + def serialization_value_modifier(self, value, obj): + return {'a': self.bad_string_value} + + def deserialization_value_modifier(self, value, obj): + return {'a': self.bad_string_value} diff --git a/tests/test_serialization.py b/tests/test_serialization.py index a672a0a..ec1a39c 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -9,7 +9,11 @@ shape_schema_serialization_disambiguation, shape_schema_deserialization_disambiguation, ) -from tests.polyclasses import ShapePolyField, with_all +from tests.polyclasses import ( + BadStringValueModifierPolyField, + ShapePolyField, + with_all, +) def with_both_shapes(func): @@ -234,6 +238,25 @@ def make_object(self, data, many=None, partial=None): assert top_class_int_example_loaded == top_class_int_example +def test_polyfield_serialization_value_modifier(): + bad_value = 'here is a specific string' + + field = BadStringValueModifierPolyField(bad_string_value=bad_value) + + Point = namedtuple('Point', ['x', 'y']) + p = Point(x='another different string', y=37) + + assert field.serialize('x', p)['a'] is bad_value + + +def test_polyfield_deserialization_value_modifier(): + bad_value = 'here is a specific string' + + field = BadStringValueModifierPolyField(bad_string_value=bad_value) + + assert field.deserialize('another different string')['a'] is bad_value + + explicit_poly_field_with_overrides = ExplicitPolyField( class_to_schema_mapping={ text_type: fields.String, From eb17a1e3b067974a692994bd8b3639b8c4fb4c84 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Fri, 17 Jan 2020 21:30:40 -0500 Subject: [PATCH 09/16] remove exploratory 'tests' --- tests/test_serialization.py | 134 +----------------------------------- 1 file changed, 1 insertion(+), 133 deletions(-) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index ec1a39c..91f3c0d 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1,5 +1,5 @@ from collections import namedtuple -from marshmallow import decorators, fields, Schema +from marshmallow import fields, Schema from marshmallow_polyfield.polyfield import PolyField, ExplicitPolyField import pytest from six import text_type @@ -106,138 +106,6 @@ def test_serializing_polyfield_by_parent_type(field): assert rect_dict == {"length": 4, "width": 10, "color": "blue"} -def test_serializing_with_modification(): - import marshmallow - - def create_label_schema(schema): - class LabelSchema(marshmallow.Schema): - type = marshmallow.fields.String() - value = schema() - - return LabelSchema - - class_to_schema = { - text_type: marshmallow.fields.String, - int: marshmallow.fields.Integer, - } - - name_to_class = { - u'str': text_type, - u'int': int, - } - - class_to_name = { - cls: name - for name, cls in name_to_class.items() - } - class_to_name[text_type] = u'str' - - def serialization_schema(base_object, parent_obj): - cls = type(base_object) - schema = class_to_schema[cls] - - label_schema = create_label_schema(schema=schema) - - return label_schema() - - def serialization_value(base_object, parent_obj): - cls = type(base_object) - name = class_to_name[cls] - - return {'type': name, 'value': base_object} - - def deserialization_schema(object_dict, parent_object_dict): - name = object_dict['type'] - cls = name_to_class[name] - schema = class_to_schema[cls] - - return schema() - - def deserialization_value(object_dict, parent_object_dict): - return object_dict['value'] - - class TopClass: - def __init__(self, polyfield): - self.polyfield = polyfield - - def __eq__(self, other): - if type(self) != type(other): - return False - - return self.polyfield == other.polyfield - - class TopSchema(marshmallow.Schema): - polyfield = PolyField( - serialization_schema_selector=serialization_schema, - deserialization_schema_selector=deserialization_schema, - serialization_value_modifier=serialization_value, - deserialization_value_modifier=deserialization_value, - ) - - @marshmallow.decorators.post_load - def make_object(self, data, many=None, partial=None): - return TopClass(**data) - - top_schema = TopSchema() - - top_class_str_example = TopClass(polyfield=u'abc') - top_class_str_example_dumped = top_schema.dump(top_class_str_example) - print(top_class_str_example_dumped) - top_class_str_example_loaded = top_schema.load(top_class_str_example_dumped) - assert top_class_str_example_loaded == top_class_str_example - - print('---') - - top_class_int_example = TopClass(polyfield=42) - top_class_int_example_dumped = top_schema.dump(top_class_int_example) - print(top_class_int_example_dumped) - top_class_int_example_loaded = top_schema.load(top_class_int_example_dumped) - assert top_class_int_example_loaded == top_class_int_example - - -def test_serializing_with_modification_ExplicitPolyField(): - class TopClass: - def __init__(self, polyfield): - self.polyfield = polyfield - - def __eq__(self, other): - if type(self) != type(other): - return False - - return self.polyfield == other.polyfield - - class TopSchema(Schema): - polyfield = ExplicitPolyField( - class_to_schema_mapping={ - text_type: fields.String, - int: fields.Integer, - }, - class_to_name_overrides={ - text_type: u'str', - }, - ) - - @decorators.post_load - def make_object(self, data, many=None, partial=None): - return TopClass(**data) - - top_schema = TopSchema() - - top_class_str_example = TopClass(polyfield=u'abc') - top_class_str_example_dumped = top_schema.dump(top_class_str_example) - print(top_class_str_example_dumped) - top_class_str_example_loaded = top_schema.load(top_class_str_example_dumped) - assert top_class_str_example_loaded == top_class_str_example - - print('---') - - top_class_int_example = TopClass(polyfield=42) - top_class_int_example_dumped = top_schema.dump(top_class_int_example) - print(top_class_int_example_dumped) - top_class_int_example_loaded = top_schema.load(top_class_int_example_dumped) - assert top_class_int_example_loaded == top_class_int_example - - def test_polyfield_serialization_value_modifier(): bad_value = 'here is a specific string' From f4b14b4fffc2b6b86271f93afdee960beca31676 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Fri, 17 Jan 2020 21:51:00 -0500 Subject: [PATCH 10/16] move deserialization tests to tests_deserialization --- tests/polyclasses.py | 87 +++++++++++++++++++++++++- tests/test_deserialization.py | 29 +++++++++ tests/test_serialization.py | 111 +--------------------------------- 3 files changed, 118 insertions(+), 109 deletions(-) diff --git a/tests/polyclasses.py b/tests/polyclasses.py index 16afac4..8caa623 100644 --- a/tests/polyclasses.py +++ b/tests/polyclasses.py @@ -1,5 +1,11 @@ -from marshmallow_polyfield import PolyFieldBase +from collections import namedtuple + from marshmallow import Schema, fields +import pytest +from six import text_type + +from marshmallow_polyfield import PolyFieldBase +from marshmallow_polyfield.polyfield import ExplicitPolyField from tests.shapes import ( shape_schema_serialization_disambiguation, @@ -60,3 +66,82 @@ def serialization_value_modifier(self, value, obj): def deserialization_value_modifier(self, value, obj): return {'a': self.bad_string_value} + + +explicit_poly_field_with_overrides = ExplicitPolyField( + class_to_schema_mapping={ + text_type: fields.String, + int: fields.Integer, + dict: fields.Dict, + }, + class_to_name_overrides={ + text_type: 'str', + }, +) + + +explicit_poly_field_without_overrides = ExplicitPolyField( + class_to_schema_mapping={ + int: fields.Integer, + dict: fields.Dict, + }, +) + + +ExplicitPolyFieldExample = namedtuple( + 'ExplicitPolyFieldExample', + [ + 'type_name', + 'value', + 'layer', + 'field', + 'polyfield', + ], +) + + +def create_explicit_poly_field_example(type_name, value, field, polyfield): + return ExplicitPolyFieldExample( + type_name=type_name, + value=value, + layer={'type': type_name, 'value': value}, + field=field, + polyfield=polyfield, + ) + + +parametrize_explicit_poly_field_type_name_and_value = pytest.mark.parametrize( + ['example'], + [ + [create_explicit_poly_field_example( + type_name=u'str', + value=u'red', + field=fields.String, + polyfield=explicit_poly_field_with_overrides, + )], + [create_explicit_poly_field_example( + type_name=u'int', + value=42, + field=fields.Integer, + polyfield=explicit_poly_field_with_overrides, + )], + [create_explicit_poly_field_example( + type_name=u'dict', + value={u'puppy': 3.9}, + field=fields.Dict, + polyfield=explicit_poly_field_with_overrides, + )], + [create_explicit_poly_field_example( + type_name=u'int', + value=42, + field=fields.Integer, + polyfield=explicit_poly_field_without_overrides, + )], + [create_explicit_poly_field_example( + type_name=u'dict', + value={u'puppy': 3.9}, + field=fields.Dict, + polyfield=explicit_poly_field_without_overrides, + )], + ], +) diff --git a/tests/test_deserialization.py b/tests/test_deserialization.py index b119d56..7ef2c0f 100644 --- a/tests/test_deserialization.py +++ b/tests/test_deserialization.py @@ -12,6 +12,8 @@ fuzzy_schema_deserialization_disambiguation, ) from tests.polyclasses import ( + BadStringValueModifierPolyField, + parametrize_explicit_poly_field_type_name_and_value, ShapePolyField, ShapePropertyPolyField, with_all @@ -291,3 +293,30 @@ def test_deserialize_polyfield(self, schema): 'type': 'rectangle'} ) assert data == original + + +def test_polyfield_deserialization_value_modifier(): + bad_value = 'here is a specific string' + + field = BadStringValueModifierPolyField(bad_string_value=bad_value) + + assert field.deserialize('another different string')['a'] is bad_value + + +@parametrize_explicit_poly_field_type_name_and_value +def test_deserializing_explicit_poly_field_value(example): + assert example.polyfield.deserialize(example.layer) is example.value + + +@parametrize_explicit_poly_field_type_name_and_value +def test_deserializing_explicit_poly_field_field_type(example): + # TODO: Checking the type only does so much, really want to compare + # the fields but they don't implement == so we'll have to code + # that up to check it. + assert ( + type(example.polyfield.deserialization_schema_selector( + example.layer, + {'x': example.layer}, + )) + is type(example.field()) + ) # noqa E721 diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 91f3c0d..ce4df72 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1,8 +1,8 @@ from collections import namedtuple from marshmallow import fields, Schema -from marshmallow_polyfield.polyfield import PolyField, ExplicitPolyField +from marshmallow_polyfield.polyfield import PolyField import pytest -from six import text_type + from tests.shapes import ( Rectangle, Triangle, @@ -11,6 +11,7 @@ ) from tests.polyclasses import ( BadStringValueModifierPolyField, + parametrize_explicit_poly_field_type_name_and_value, ShapePolyField, with_all, ) @@ -117,93 +118,6 @@ def test_polyfield_serialization_value_modifier(): assert field.serialize('x', p)['a'] is bad_value -def test_polyfield_deserialization_value_modifier(): - bad_value = 'here is a specific string' - - field = BadStringValueModifierPolyField(bad_string_value=bad_value) - - assert field.deserialize('another different string')['a'] is bad_value - - -explicit_poly_field_with_overrides = ExplicitPolyField( - class_to_schema_mapping={ - text_type: fields.String, - int: fields.Integer, - dict: fields.Dict, - }, - class_to_name_overrides={ - text_type: 'str', - }, -) - - -explicit_poly_field_without_overrides = ExplicitPolyField( - class_to_schema_mapping={ - int: fields.Integer, - dict: fields.Dict, - }, -) - - -ExplicitPolyFieldExample = namedtuple( - 'ExplicitPolyFieldExample', - [ - 'type_name', - 'value', - 'layer', - 'field', - 'polyfield', - ], -) - - -def create_explicit_poly_field_example(type_name, value, field, polyfield): - return ExplicitPolyFieldExample( - type_name=type_name, - value=value, - layer={'type': type_name, 'value': value}, - field=field, - polyfield=polyfield, - ) - - -parametrize_explicit_poly_field_type_name_and_value = pytest.mark.parametrize( - ['example'], - [ - [create_explicit_poly_field_example( - type_name=u'str', - value=u'red', - field=fields.String, - polyfield=explicit_poly_field_with_overrides, - )], - [create_explicit_poly_field_example( - type_name=u'int', - value=42, - field=fields.Integer, - polyfield=explicit_poly_field_with_overrides, - )], - [create_explicit_poly_field_example( - type_name=u'dict', - value={u'puppy': 3.9}, - field=fields.Dict, - polyfield=explicit_poly_field_with_overrides, - )], - [create_explicit_poly_field_example( - type_name=u'int', - value=42, - field=fields.Integer, - polyfield=explicit_poly_field_without_overrides, - )], - [create_explicit_poly_field_example( - type_name=u'dict', - value={u'puppy': 3.9}, - field=fields.Dict, - polyfield=explicit_poly_field_without_overrides, - )], - ], -) - - @parametrize_explicit_poly_field_type_name_and_value def test_serializing_explicit_poly_field(example): Point = namedtuple('Point', ['x', 'y']) @@ -228,22 +142,3 @@ def test_serializing_explicit_poly_field_value(example): serialized = example.polyfield.serialize('x', p) assert serialized['value'] is example.value - - -@parametrize_explicit_poly_field_type_name_and_value -def test_deserializing_explicit_poly_field_value(example): - assert example.polyfield.deserialize(example.layer) is example.value - - -@parametrize_explicit_poly_field_type_name_and_value -def test_deserializing_explicit_poly_field_field_type(example): - # TODO: Checking the type only does so much, really want to compare - # the fields but they don't implement == so we'll have to code - # that up to check it. - assert ( - type(example.polyfield.deserialization_schema_selector( - example.layer, - {'x': example.layer}, - )) - is type(example.field()) - ) # noqa E721 From aebb09888869d237a2842acc04f17c94224899eb Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Fri, 17 Jan 2020 22:11:35 -0500 Subject: [PATCH 11/16] expose ExplicitPolyField as marshmallow_polyfield.ExplicitPolyField --- marshmallow_polyfield/__init__.py | 8 ++++++-- tests/polyclasses.py | 3 +-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/marshmallow_polyfield/__init__.py b/marshmallow_polyfield/__init__.py index 92ff052..1600f5a 100644 --- a/marshmallow_polyfield/__init__.py +++ b/marshmallow_polyfield/__init__.py @@ -1,3 +1,7 @@ -from marshmallow_polyfield.polyfield import PolyField, PolyFieldBase +from marshmallow_polyfield.polyfield import ( + PolyField, + PolyFieldBase, + ExplicitPolyField, +) -__all__ = ['PolyField', 'PolyFieldBase'] +__all__ = ['PolyField', 'PolyFieldBase', 'ExplicitPolyField'] diff --git a/tests/polyclasses.py b/tests/polyclasses.py index 8caa623..36c387e 100644 --- a/tests/polyclasses.py +++ b/tests/polyclasses.py @@ -4,8 +4,7 @@ import pytest from six import text_type -from marshmallow_polyfield import PolyFieldBase -from marshmallow_polyfield.polyfield import ExplicitPolyField +from marshmallow_polyfield import PolyFieldBase, ExplicitPolyField from tests.shapes import ( shape_schema_serialization_disambiguation, From 08a5986ee73f96dbdeb3686505140898d4608520 Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Fri, 17 Jan 2020 22:34:33 -0500 Subject: [PATCH 12/16] Add docstrings for ExplicitPolyField --- marshmallow_polyfield/polyfield.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/marshmallow_polyfield/polyfield.py b/marshmallow_polyfield/polyfield.py index a2cecc2..3c32e22 100644 --- a/marshmallow_polyfield/polyfield.py +++ b/marshmallow_polyfield/polyfield.py @@ -149,6 +149,12 @@ class LabelSchema(Schema): class ExplicitPolyField(PolyFieldBase): + """ + Similar to PolyField except that disambiguation is done by creating and + consuming an extra layer with an explicit string used to disambiguate the + type. The layer defaults to the form of ``{'type': cls.__name__, + 'value': }``. + """ def __init__( self, class_to_schema_mapping, @@ -157,6 +163,14 @@ def __init__( many=False, **metadata ): + """ + :param class_to_schema_mapping: Classes as keys mapped to the schema + to be used for each. + :param create_label_schema: Callable returning a schema used to create + the extra serialized layer including the type name. + :param class_to_name_overrides: Classes as keys mapped to the name to + use as the serialized type name. Default is to use ``cls.__name__``. + """ super(ExplicitPolyField, self).__init__(many=many, **metadata) if class_to_name_overrides is None: From 3b221f80012c6a21cbc9be60d2ff1b54045d388c Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Fri, 17 Jan 2020 23:02:43 -0500 Subject: [PATCH 13/16] Add ExplicitPolyField example to README.rst --- README.rst | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/README.rst b/README.rst index 6921c5b..9c8688d 100644 --- a/README.rst +++ b/README.rst @@ -103,3 +103,69 @@ Once setup the schema should act like any other schema. If it does not then plea data.get('main'), data.get('others') ) + +ExplicitPolyField +----------------- + +The ``ExplicitPolyField`` class adds an additional layer to the serialized data to embed a string used to disambiguate the type of the serialized data. +This avoids any uncertainty when faced with similarly serialized classes. +A mapping from classes to be supported to the schemas used to process them must be provided. +By default the serialized type names are taken from ``cls.__name__`` but this can be overridden. + +.. code:: python + + import json + + from marshmallow import Schema, decorators, fields + from marshmallow_polyfield import ExplicitPolyField + from six import text_type + + + class TopClass: + def __init__(self, polyfield_list): + self.polyfield_list = polyfield_list + + def __eq__(self, other): + if type(self) != type(other): + return False + + return self.polyfield_list == other.polyfield_list + + class TopSchema(Schema): + polyfield_list = fields.List(ExplicitPolyField( + class_to_schema_mapping={ + text_type: fields.String, + int: fields.Integer, + }, + class_to_name_overrides={ + text_type: u'my string name', + }, + )) + + @decorators.post_load + def make_object(self, data, many=None, partial=None): + return TopClass(**data) + + top_schema = TopSchema() + + top_class_example = TopClass(polyfield_list=[u'epf', 37]) + top_class_example_dumped = top_schema.dump(top_class_example) + top_class_example_loaded = top_schema.load(top_class_example_dumped) + + assert top_class_example_loaded == top_class_example + print(json.dumps(top_class_example_dumped, indent=4)) + +.. code:: json + + { + "polyfield_list": [ + { + "type": "my string name", + "value": "epf" + }, + { + "type": "int", + "value": 37 + } + ] + } From 55e5157ac9727cdc13e16470e44154adf2522e6c Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Fri, 17 Jan 2020 23:46:58 -0500 Subject: [PATCH 14/16] Raise exception if final name mapping reuses names --- marshmallow_polyfield/polyfield.py | 28 ++++++++++++++++++++++++++-- tests/test_polyfield_base.py | 26 +++++++++++++++++++++++++- 2 files changed, 51 insertions(+), 3 deletions(-) diff --git a/marshmallow_polyfield/polyfield.py b/marshmallow_polyfield/polyfield.py index 3c32e22..c2d51ec 100644 --- a/marshmallow_polyfield/polyfield.py +++ b/marshmallow_polyfield/polyfield.py @@ -1,4 +1,6 @@ import abc +import itertools + from six import raise_from, with_metaclass from marshmallow import Schema, ValidationError @@ -148,6 +150,10 @@ class LabelSchema(Schema): return LabelSchema +class ExplicitNamesNotUniqueError(Exception): + pass + + class ExplicitPolyField(PolyFieldBase): """ Similar to PolyField except that disambiguation is done by creating and @@ -184,10 +190,28 @@ def __init__( self._class_to_name.update(class_to_name_overrides) + name_to_classes = { + name: [cls for cls, name in class_name_pairs] + for name, class_name_pairs in itertools.groupby( + sorted(self._class_to_name.items(), key=lambda x: x[1]), + key=lambda x: x[1], + ) + } + + reused_names = { + name: classes + for name, classes in name_to_classes.items() + if len(classes) > 1 + } + + if len(reused_names) > 0: + raise ExplicitNamesNotUniqueError(repr(reused_names)) + self._name_to_class = { - name: cls - for cls, name in self._class_to_name.items() + name: classes[0] + for name, classes in name_to_classes.items() } + self.create_label_schema = create_label_schema def serialization_schema_selector(self, base_object, parent_obj): diff --git a/tests/test_polyfield_base.py b/tests/test_polyfield_base.py index a2ff228..9feb59a 100644 --- a/tests/test_polyfield_base.py +++ b/tests/test_polyfield_base.py @@ -1,4 +1,12 @@ -from marshmallow_polyfield.polyfield import PolyFieldBase +import re + +from marshmallow import fields +from marshmallow_polyfield.polyfield import ( + ExplicitPolyField, + ExplicitNamesNotUniqueError, + PolyFieldBase, +) +import pytest class TrivialExample(PolyFieldBase): @@ -24,3 +32,19 @@ def test_polyfield_base(): pass else: assert False, 'expected to raise' + + +def test_explicit_polyfield_raises_for_nonunique_names(): + same_name = 'same name' + + with pytest.raises( + ExplicitNamesNotUniqueError, + match=re.escape("{'same name': [, ]}"), + ): + ExplicitPolyField( + class_to_schema_mapping={ + str: fields.String, + int: fields.Integer, + }, + class_to_name_overrides={str: same_name, int: same_name}, + ) From 7d3a5e550ba1fd34ab330ba6d2c8fd4b2d4ab25d Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Fri, 17 Jan 2020 23:51:46 -0500 Subject: [PATCH 15/16] remove some py2 stuff introduced in this branch --- marshmallow_polyfield/polyfield.py | 2 +- tests/polyclasses.py | 10 +++------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/marshmallow_polyfield/polyfield.py b/marshmallow_polyfield/polyfield.py index a0dc2a2..3338eb4 100644 --- a/marshmallow_polyfield/polyfield.py +++ b/marshmallow_polyfield/polyfield.py @@ -173,7 +173,7 @@ def __init__( :param class_to_name_overrides: Classes as keys mapped to the name to use as the serialized type name. Default is to use ``cls.__name__``. """ - super(ExplicitPolyField, self).__init__(many=many, **metadata) + super().__init__(many=many, **metadata) if class_to_name_overrides is None: class_to_name_overrides = {} diff --git a/tests/polyclasses.py b/tests/polyclasses.py index 36c387e..a75694e 100644 --- a/tests/polyclasses.py +++ b/tests/polyclasses.py @@ -2,7 +2,6 @@ from marshmallow import Schema, fields import pytest -from six import text_type from marshmallow_polyfield import PolyFieldBase, ExplicitPolyField @@ -47,10 +46,7 @@ class BadStringValueModifierSchema(Schema): class BadStringValueModifierPolyField(PolyFieldBase): def __init__(self, bad_string_value, many=False, **metadata): - super(BadStringValueModifierPolyField, self).__init__( - many=many, - **metadata - ) + super().__init__(many=many, **metadata) self.bad_string_value = bad_string_value @@ -69,12 +65,12 @@ def deserialization_value_modifier(self, value, obj): explicit_poly_field_with_overrides = ExplicitPolyField( class_to_schema_mapping={ - text_type: fields.String, + str: fields.String, int: fields.Integer, dict: fields.Dict, }, class_to_name_overrides={ - text_type: 'str', + str: 'str', }, ) From aa6c2a6442ae60353cedb362be985e3679fc98de Mon Sep 17 00:00:00 2001 From: Kyle Altendorf Date: Fri, 17 Jan 2020 23:54:09 -0500 Subject: [PATCH 16/16] sort classes for consistent ordering in exception --- marshmallow_polyfield/polyfield.py | 5 ++++- tests/test_polyfield_base.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/marshmallow_polyfield/polyfield.py b/marshmallow_polyfield/polyfield.py index 3338eb4..df39aa2 100644 --- a/marshmallow_polyfield/polyfield.py +++ b/marshmallow_polyfield/polyfield.py @@ -187,7 +187,10 @@ def __init__( self._class_to_name.update(class_to_name_overrides) name_to_classes = { - name: [cls for cls, name in class_name_pairs] + name: sorted( + (cls for cls, name in class_name_pairs), + key=lambda cls: cls.__name__, + ) for name, class_name_pairs in itertools.groupby( sorted(self._class_to_name.items(), key=lambda x: x[1]), key=lambda x: x[1], diff --git a/tests/test_polyfield_base.py b/tests/test_polyfield_base.py index bd72eae..143627e 100644 --- a/tests/test_polyfield_base.py +++ b/tests/test_polyfield_base.py @@ -39,7 +39,7 @@ def test_explicit_polyfield_raises_for_nonunique_names(): with pytest.raises( ExplicitNamesNotUniqueError, - match=re.escape("{'same name': [, ]}"), + match=re.escape("{'same name': [, ]}"), ): ExplicitPolyField( class_to_schema_mapping={