Skip to content

Commit 5c63425

Browse files
committed
Polymorphic serializers refactor
1 parent 96cbab0 commit 5c63425

File tree

4 files changed

+158
-62
lines changed

4 files changed

+158
-62
lines changed

example/serializers.py

+9-49
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from datetime import datetime
2-
from django.db.models.query import QuerySet
3-
from rest_framework.utils.serializer_helpers import BindingDict
4-
from rest_framework_json_api import serializers, relations, utils
2+
from rest_framework_json_api import serializers, relations
53
from example import models
64

75

@@ -41,15 +39,15 @@ def __init__(self, *args, **kwargs):
4139
}
4240

4341
body_format = serializers.SerializerMethodField()
44-
# many related from model
42+
# Many related from model
4543
comments = relations.ResourceRelatedField(
46-
source='comment_set', many=True, read_only=True)
47-
# many related from serializer
44+
source='comment_set', many=True, read_only=True)
45+
# Many related from serializer
4846
suggested = relations.SerializerMethodResourceRelatedField(
49-
source='get_suggested', model=models.Entry, many=True, read_only=True)
50-
# single related from serializer
47+
source='get_suggested', model=models.Entry, many=True, read_only=True)
48+
# Single related from serializer
5149
featured = relations.SerializerMethodResourceRelatedField(
52-
source='get_featured', model=models.Entry, read_only=True)
50+
source='get_featured', model=models.Entry, read_only=True)
5351

5452
def get_suggested(self, obj):
5553
return models.Entry.objects.exclude(pk=obj.pk)
@@ -108,51 +106,13 @@ class Meta:
108106
exclude = ('polymorphic_ctype',)
109107

110108

111-
class ProjectSerializer(serializers.ModelSerializer):
112-
113-
polymorphic_serializers = [
114-
{'model': models.ArtProject, 'serializer': ArtProjectSerializer},
115-
{'model': models.ResearchProject, 'serializer': ResearchProjectSerializer},
116-
]
109+
class ProjectSerializer(serializers.PolymorphicModelSerializer):
110+
polymorphic_serializers = [ArtProjectSerializer, ResearchProjectSerializer]
117111

118112
class Meta:
119113
model = models.Project
120114
exclude = ('polymorphic_ctype',)
121115

122-
def _get_actual_serializer_from_instance(self, instance):
123-
for info in self.polymorphic_serializers:
124-
if isinstance(instance, info.get('model')):
125-
actual_serializer = info.get('serializer')
126-
return actual_serializer(instance, context=self.context)
127-
128-
@property
129-
def fields(self):
130-
_fields = BindingDict(self)
131-
for key, value in self.get_fields().items():
132-
_fields[key] = value
133-
return _fields
134-
135-
def get_fields(self):
136-
if self.instance is not None:
137-
if not isinstance(self.instance, QuerySet):
138-
return self._get_actual_serializer_from_instance(self.instance).get_fields()
139-
else:
140-
raise Exception("Cannot get fields from a polymorphic serializer given a queryset")
141-
return super(ProjectSerializer, self).get_fields()
142-
143-
def to_representation(self, instance):
144-
# Handle polymorphism
145-
return self._get_actual_serializer_from_instance(instance).to_representation(instance)
146-
147-
def to_internal_value(self, data):
148-
data_type = data.get('type')
149-
for info in self.polymorphic_serializers:
150-
actual_serializer = info['serializer']
151-
if data_type == utils.get_resource_type_from_serializer(actual_serializer):
152-
self.__class__ = actual_serializer
153-
return actual_serializer(data, context=self.context).to_internal_value(data)
154-
raise Exception("Could not deserialize")
155-
156116

157117
class CompanySerializer(serializers.ModelSerializer):
158118
included_serializers = {

example/tests/integration/test_polymorphism.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,14 @@ def test_polymorphism_on_included_relations(single_company, client):
2929
assert content["data"]["relationships"]["currentProject"]["data"]["type"] == "artProjects"
3030
assert [rel["type"] for rel in content["data"]["relationships"]["futureProjects"]["data"]] == [
3131
"researchProjects", "artProjects"]
32-
assert [x.get('type') for x in content.get('included')] == ['artProjects', 'artProjects', 'researchProjects'], \
33-
'Detail included types are incorrect'
32+
assert [x.get('type') for x in content.get('included')] == [
33+
'artProjects', 'artProjects', 'researchProjects'], 'Detail included types are incorrect'
3434
# Ensure that the child fields are present.
3535
assert content.get('included')[0].get('attributes').get('artist') is not None
3636
assert content.get('included')[1].get('attributes').get('artist') is not None
3737
assert content.get('included')[2].get('attributes').get('supervisor') is not None
3838

39+
3940
def test_polymorphism_on_polymorphic_model_detail_patch(single_art_project, client):
4041
url = reverse("project-detail", kwargs={'pk': single_art_project.pk})
4142
response = client.get(url)
@@ -50,6 +51,7 @@ def test_polymorphism_on_polymorphic_model_detail_patch(single_art_project, clie
5051
assert new_content['data']['attributes']['topic'] == test_topic
5152
assert new_content['data']['attributes']['artist'] == test_artist
5253

54+
5355
def test_polymorphism_on_polymorphic_model_list_post(client):
5456
test_topic = 'New test topic {}'.format(random.randint(0, 999999))
5557
test_artist = 'test-{}'.format(random.randint(0, 999999))

rest_framework_json_api/serializers.py

+138-8
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
11
import inflection
2+
3+
from django.db.models.query import QuerySet
24
from django.utils.translation import ugettext_lazy as _
5+
from django.utils import six
36
from rest_framework.exceptions import ParseError
47
from rest_framework.serializers import *
58

69
from rest_framework_json_api.relations import ResourceRelatedField
10+
from rest_framework_json_api.exceptions import Conflict
711
from rest_framework_json_api.utils import (
812
get_resource_type_from_model, get_resource_type_from_instance,
913
get_resource_type_from_serializer, get_included_serializers, get_included_resources)
1014

1115

1216
class ResourceIdentifierObjectSerializer(BaseSerializer):
1317
default_error_messages = {
14-
'incorrect_model_type': _('Incorrect model type. Expected {model_type}, received {received_type}.'),
18+
'incorrect_model_type': _('Incorrect model type. Expected {model_type}, '
19+
'received {received_type}.'),
1520
'does_not_exist': _('Invalid pk "{pk_value}" - object does not exist.'),
1621
'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'),
1722
}
@@ -21,7 +26,8 @@ class ResourceIdentifierObjectSerializer(BaseSerializer):
2126
def __init__(self, *args, **kwargs):
2227
self.model_class = kwargs.pop('model_class', self.model_class)
2328
if 'instance' not in kwargs and not self.model_class:
24-
raise RuntimeError('ResourceIdentifierObjectsSerializer must be initialized with a model class.')
29+
raise RuntimeError(
30+
'ResourceIdentifierObjectsSerializer must be initialized with a model class.')
2531
super(ResourceIdentifierObjectSerializer, self).__init__(*args, **kwargs)
2632

2733
def to_representation(self, instance):
@@ -32,7 +38,8 @@ def to_representation(self, instance):
3238

3339
def to_internal_value(self, data):
3440
if data['type'] != get_resource_type_from_model(self.model_class):
35-
self.fail('incorrect_model_type', model_type=self.model_class, received_type=data['type'])
41+
self.fail(
42+
'incorrect_model_type', model_type=self.model_class, received_type=data['type'])
3643
pk = data['id']
3744
try:
3845
return self.model_class.objects.get(pk=pk)
@@ -48,15 +55,18 @@ def __init__(self, *args, **kwargs):
4855
request = context.get('request') if context else None
4956

5057
if request:
51-
sparse_fieldset_query_param = 'fields[{}]'.format(get_resource_type_from_serializer(self))
58+
sparse_fieldset_query_param = 'fields[{}]'.format(
59+
get_resource_type_from_serializer(self))
5260
try:
53-
param_name = next(key for key in request.query_params if sparse_fieldset_query_param in key)
61+
param_name = next(
62+
key for key in request.query_params if sparse_fieldset_query_param in key)
5463
except StopIteration:
5564
pass
5665
else:
5766
fieldset = request.query_params.get(param_name).split(',')
58-
# iterate over a *copy* of self.fields' underlying OrderedDict, because we may modify the
59-
# original during the iteration. self.fields is a `rest_framework.utils.serializer_helpers.BindingDict`
67+
# Iterate over a *copy* of self.fields' underlying OrderedDict, because we may
68+
# modify the original during the iteration.
69+
# self.fields is a `rest_framework.utils.serializer_helpers.BindingDict`
6070
for field_name, field in self.fields.fields.copy().items():
6171
if field_name == api_settings.URL_FIELD_NAME: # leave self link there
6272
continue
@@ -100,7 +110,8 @@ def validate_path(serializer_class, field_path, path):
100110
super(IncludedResourcesValidationMixin, self).__init__(*args, **kwargs)
101111

102112

103-
class HyperlinkedModelSerializer(IncludedResourcesValidationMixin, SparseFieldsetsMixin, HyperlinkedModelSerializer):
113+
class HyperlinkedModelSerializer(IncludedResourcesValidationMixin, SparseFieldsetsMixin,
114+
HyperlinkedModelSerializer):
104115
"""
105116
A type of `ModelSerializer` that uses hyperlinked relationships instead
106117
of primary key relationships. Specifically:
@@ -151,3 +162,122 @@ def get_field_names(self, declared_fields, info):
151162
declared[field_name] = field
152163
fields = super(ModelSerializer, self).get_field_names(declared, info)
153164
return list(fields) + list(getattr(self.Meta, 'meta_fields', list()))
165+
166+
167+
class PolymorphicSerializerMetaclass(SerializerMetaclass):
168+
"""
169+
This metaclass ensures that the `polymorphic_serializers` is correctly defined on a
170+
`PolymorphicSerializer` class and make a cache of model/serializer/type mappings.
171+
"""
172+
173+
def __new__(cls, name, bases, attrs):
174+
new_class = super(PolymorphicSerializerMetaclass, cls).__new__(cls, name, bases, attrs)
175+
176+
# Ensure initialization is only performed for subclasses of PolymorphicModelSerializer
177+
# (excluding PolymorphicModelSerializer class itself).
178+
parents = [b for b in bases if isinstance(b, PolymorphicSerializerMetaclass)]
179+
if not parents:
180+
return new_class
181+
182+
polymorphic_serializers = getattr(new_class, 'polymorphic_serializers', None)
183+
if not polymorphic_serializers:
184+
raise NotImplementedError(
185+
"A PolymorphicModelSerializer must define a `polymorphic_serializers` attribute.")
186+
serializer_to_model = {
187+
serializer: serializer.Meta.model for serializer in polymorphic_serializers}
188+
model_to_serializer = {
189+
serializer.Meta.model: serializer for serializer in polymorphic_serializers}
190+
type_to_model = {
191+
get_resource_type_from_model(model): model for model in model_to_serializer.keys()}
192+
setattr(new_class, '_poly_serializer_model_map', serializer_to_model)
193+
setattr(new_class, '_poly_model_serializer_map', model_to_serializer)
194+
setattr(new_class, '_poly_type_model_map', type_to_model)
195+
return new_class
196+
197+
198+
@six.add_metaclass(PolymorphicSerializerMetaclass)
199+
class PolymorphicModelSerializer(ModelSerializer):
200+
"""
201+
A serializer for polymorphic models.
202+
Useful for "lazy" parent models. Leaves should be represented with a regular serializer.
203+
"""
204+
def get_fields(self):
205+
"""
206+
Return an exhaustive list of the polymorphic serializer fields.
207+
"""
208+
if self.instance is not None:
209+
if not isinstance(self.instance, QuerySet):
210+
serializer_class = self.get_polymorphic_serializer_for_instance(self.instance)
211+
return serializer_class(self.instance, context=self.context).get_fields()
212+
else:
213+
raise Exception("Cannot get fields from a polymorphic serializer given a queryset")
214+
return super(PolymorphicModelSerializer, self).get_fields()
215+
216+
def get_polymorphic_serializer_for_instance(self, instance):
217+
"""
218+
Return the polymorphic serializer associated with the given instance/model.
219+
Raise `NotImplementedError` if no serializer is found for the given model. This usually
220+
means that a serializer is missing in the class's `polymorphic_serializers` attribute.
221+
"""
222+
try:
223+
return self._poly_model_serializer_map[instance._meta.model]
224+
except KeyError:
225+
raise NotImplementedError(
226+
"No polymorphic serializer has been found for model {}".format(
227+
instance._meta.model.__name__))
228+
229+
def get_polymorphic_model_for_serializer(self, serializer):
230+
"""
231+
Return the polymorphic model associated with the given serializer.
232+
Raise `NotImplementedError` if no model is found for the given serializer. This usually
233+
means that a serializer is missing in the class's `polymorphic_serializers` attribute.
234+
"""
235+
try:
236+
return self._poly_serializer_model_map[serializer]
237+
except KeyError:
238+
raise NotImplementedError(
239+
"No polymorphic model has been found for serializer {}".format(serializer.__name__))
240+
241+
def get_polymorphic_model_for_type(self, obj_type):
242+
"""
243+
Return the polymorphic model associated with the given type.
244+
Raise `NotImplementedError` if no model is found for the given type. This usually
245+
means that a serializer is missing in the class's `polymorphic_serializers` attribute.
246+
"""
247+
try:
248+
return self._poly_type_model_map[obj_type]
249+
except KeyError:
250+
raise NotImplementedError(
251+
"No polymorphic model has been found for type {}".format(obj_type))
252+
253+
def get_polymorphic_serializer_for_type(self, obj_type):
254+
"""
255+
Return the polymorphic serializer associated with the given type.
256+
Raise `NotImplementedError` if no serializer is found for the given type. This usually
257+
means that a serializer is missing in the class's `polymorphic_serializers` attribute.
258+
"""
259+
return self.get_polymorphic_serializer_for_instance(
260+
self.get_polymorphic_model_for_type(obj_type))
261+
262+
def to_representation(self, instance):
263+
"""
264+
Retrieve the appropriate polymorphic serializer and use this to handle representation.
265+
"""
266+
serializer_class = self.get_polymorphic_serializer_for_instance(instance)
267+
return serializer_class(instance, context=self.context).to_representation(instance)
268+
269+
def to_internal_value(self, data):
270+
"""
271+
Ensure that the given type is one of the expected polymorphic types, then retrieve the
272+
appropriate polymorphic serializer and use this to handle internal value.
273+
"""
274+
received_type = data.get('type')
275+
expected_types = self._poly_type_model_map.keys()
276+
if received_type not in expected_types:
277+
raise Conflict(
278+
'Incorrect relation type. Expected on of {expected_types}, '
279+
'received {received_type}.'.format(
280+
expected_types=', '.join(expected_types), received_type=received_type))
281+
serializer_class = self.get_polymorphic_serializer_for_type(received_type)
282+
self.__class__ = serializer_class
283+
return serializer_class(data, context=self.context).to_internal_value(data)

rest_framework_json_api/utils.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def get_serializer_fields(serializer):
9292
pass
9393
return fields
9494

95+
9596
def format_keys(obj, format_type=None):
9697
"""
9798
Takes either a dict or list and returns it with camelized keys only if
@@ -147,12 +148,15 @@ def format_value(value, format_type=None):
147148

148149

149150
def format_relation_name(value, format_type=None):
150-
warnings.warn("The 'format_relation_name' function has been renamed 'format_resource_type' and the settings are now 'JSON_API_FORMAT_TYPES' and 'JSON_API_PLURALIZE_TYPES'")
151+
warnings.warn(
152+
"The 'format_relation_name' function has been renamed 'format_resource_type' and "
153+
"the settings are now 'JSON_API_FORMAT_TYPES' and 'JSON_API_PLURALIZE_TYPES'")
151154
if format_type is None:
152155
format_type = getattr(settings, 'JSON_API_FORMAT_RELATION_KEYS', None)
153156
pluralize = getattr(settings, 'JSON_API_PLURALIZE_RELATION_TYPE', None)
154157
return format_resource_type(value, format_type, pluralize)
155158

159+
156160
def format_resource_type(value, format_type=None, pluralize=None):
157161
if format_type is None:
158162
format_type = getattr(settings, 'JSON_API_FORMAT_TYPES', False)
@@ -243,8 +247,8 @@ def get_resource_type_from_serializer(serializer):
243247
json_api_meta = getattr(serializer, 'JSONAPIMeta', None)
244248
meta = getattr(serializer, 'Meta', None)
245249
if hasattr(serializer, 'polymorphic_serializers'):
246-
return [get_resource_type_from_serializer(s['serializer']) for s in serializer.polymorphic_serializers]
247-
if hasattr(json_api_meta, 'resource_name'):
250+
return [get_resource_type_from_serializer(s) for s in serializer.polymorphic_serializers]
251+
elif hasattr(json_api_meta, 'resource_name'):
248252
return json_api_meta.resource_name
249253
elif hasattr(meta, 'resource_name'):
250254
return meta.resource_name

0 commit comments

Comments
 (0)