Skip to content

Commit 681b5aa

Browse files
ograycodeleo-naeka
authored andcommitted
Adds the following features:
- support for post and patch request on polymorphic model endpoints. - makes polymorphic serializers give child fields instead of its own.
1 parent 8c73d95 commit 681b5aa

File tree

6 files changed

+163
-15
lines changed

6 files changed

+163
-15
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# -*- coding: utf-8 -*-
2+
# Generated by Django 1.9.6 on 2016-05-13 08:57
3+
from __future__ import unicode_literals
4+
5+
from django.db import migrations, models
6+
import django.db.models.deletion
7+
8+
9+
class Migration(migrations.Migration):
10+
11+
dependencies = [
12+
('contenttypes', '0002_remove_content_type_name'),
13+
('example', '0001_initial'),
14+
]
15+
16+
operations = [
17+
migrations.CreateModel(
18+
name='Company',
19+
fields=[
20+
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
21+
('name', models.CharField(max_length=100)),
22+
],
23+
),
24+
migrations.CreateModel(
25+
name='Project',
26+
fields=[
27+
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
28+
('topic', models.CharField(max_length=30)),
29+
],
30+
options={
31+
'abstract': False,
32+
},
33+
),
34+
migrations.CreateModel(
35+
name='ArtProject',
36+
fields=[
37+
('project_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='example.Project')),
38+
('artist', models.CharField(max_length=30)),
39+
],
40+
options={
41+
'abstract': False,
42+
},
43+
bases=('example.project',),
44+
),
45+
migrations.CreateModel(
46+
name='ResearchProject',
47+
fields=[
48+
('project_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='example.Project')),
49+
('supervisor', models.CharField(max_length=30)),
50+
],
51+
options={
52+
'abstract': False,
53+
},
54+
bases=('example.project',),
55+
),
56+
migrations.AddField(
57+
model_name='project',
58+
name='polymorphic_ctype',
59+
field=models.ForeignKey(editable=False, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='polymorphic_example.project_set+', to='contenttypes.ContentType'),
60+
),
61+
migrations.AddField(
62+
model_name='company',
63+
name='current_project',
64+
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='companies', to='example.Project'),
65+
),
66+
migrations.AddField(
67+
model_name='company',
68+
name='future_projects',
69+
field=models.ManyToManyField(to='example.Project'),
70+
),
71+
]

example/serializers.py

+42-11
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from datetime import datetime
2-
from rest_framework_json_api import serializers, relations
2+
from django.db.models.query import QuerySet
3+
from rest_framework.utils.serializer_helpers import BindingDict
4+
from rest_framework_json_api import serializers, relations, utils
35
from example import models
46

57

@@ -44,13 +46,13 @@ def __init__(self, *args, **kwargs):
4446
source='comment_set', many=True, read_only=True)
4547
# many related from serializer
4648
suggested = relations.SerializerMethodResourceRelatedField(
47-
source='get_suggested', model=Entry, many=True, read_only=True)
49+
source='get_suggested', model=models.Entry, many=True, read_only=True)
4850
# single related from serializer
4951
featured = relations.SerializerMethodResourceRelatedField(
50-
source='get_featured', model=Entry, read_only=True)
52+
source='get_featured', model=models.Entry, read_only=True)
5153

5254
def get_suggested(self, obj):
53-
return models.Entry.objects.exclude(pk=obj.pk).first()
55+
return models.Entry.objects.exclude(pk=obj.pk)
5456

5557
def get_featured(self, obj):
5658
return models.Entry.objects.exclude(pk=obj.pk).first()
@@ -108,19 +110,48 @@ class Meta:
108110

109111
class ProjectSerializer(serializers.ModelSerializer):
110112

113+
polymorphic_serializers = [
114+
{'model': models.ArtProject, 'serializer': ArtProjectSerializer},
115+
{'model': models.ResearchProject, 'serializer': ResearchProjectSerializer},
116+
]
117+
111118
class Meta:
112119
model = models.Project
113120
exclude = ('polymorphic_ctype',)
114121

122+
def _get_actual_serializer_from_instance(self, instance):
123+
for info in self.polymorphic_serializers:
124+
if isinstance(instance, info.get('model')):
125+
actual_serializer = info.get('serializer')
126+
return actual_serializer(instance, context=self.context)
127+
128+
@property
129+
def fields(self):
130+
_fields = BindingDict(self)
131+
for key, value in self.get_fields().items():
132+
_fields[key] = value
133+
return _fields
134+
135+
def get_fields(self):
136+
if self.instance is not None:
137+
if not isinstance(self.instance, QuerySet):
138+
return self._get_actual_serializer_from_instance(self.instance).get_fields()
139+
else:
140+
raise Exception("Cannot get fields from a polymorphic serializer given a queryset")
141+
return super(ProjectSerializer, self).get_fields()
142+
115143
def to_representation(self, instance):
116144
# Handle polymorphism
117-
if isinstance(instance, models.ArtProject):
118-
return ArtProjectSerializer(
119-
instance, context=self.context).to_representation(instance)
120-
elif isinstance(instance, models.ResearchProject):
121-
return ResearchProjectSerializer(
122-
instance, context=self.context).to_representation(instance)
123-
return super(ProjectSerializer, self).to_representation(instance)
145+
return self._get_actual_serializer_from_instance(instance).to_representation(instance)
146+
147+
def to_internal_value(self, data):
148+
data_type = data.get('type')
149+
for info in self.polymorphic_serializers:
150+
actual_serializer = info['serializer']
151+
if data_type == utils.get_resource_type_from_serializer(actual_serializer):
152+
self.__class__ = actual_serializer
153+
return actual_serializer(data, context=self.context).to_internal_value(data)
154+
raise Exception("Could not deserialize")
124155

125156

126157
class CompanySerializer(serializers.ModelSerializer):

example/tests/integration/test_polymorphism.py

+40
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import pytest
2+
import random
3+
import json
24
from django.core.urlresolvers import reverse
35

46
from example.tests.utils import load_json
@@ -29,3 +31,41 @@ def test_polymorphism_on_included_relations(single_company, client):
2931
"researchProjects", "artProjects"]
3032
assert [x.get('type') for x in content.get('included')] == ['artProjects', 'artProjects', 'researchProjects'], \
3133
'Detail included types are incorrect'
34+
# Ensure that the child fields are present.
35+
assert content.get('included')[0].get('attributes').get('artist') is not None
36+
assert content.get('included')[1].get('attributes').get('artist') is not None
37+
assert content.get('included')[2].get('attributes').get('supervisor') is not None
38+
39+
def test_polymorphism_on_polymorphic_model_detail_patch(single_art_project, client):
40+
url = reverse("project-detail", kwargs={'pk': single_art_project.pk})
41+
response = client.get(url)
42+
content = load_json(response.content)
43+
test_topic = 'test-{}'.format(random.randint(0, 999999))
44+
test_artist = 'test-{}'.format(random.randint(0, 999999))
45+
content['data']['attributes']['topic'] = test_topic
46+
content['data']['attributes']['artist'] = test_artist
47+
response = client.patch(url, data=json.dumps(content), content_type='application/vnd.api+json')
48+
new_content = load_json(response.content)
49+
assert new_content["data"]["type"] == "artProjects"
50+
assert new_content['data']['attributes']['topic'] == test_topic
51+
assert new_content['data']['attributes']['artist'] == test_artist
52+
53+
def test_polymorphism_on_polymorphic_model_list_post(client):
54+
test_topic = 'New test topic {}'.format(random.randint(0, 999999))
55+
test_artist = 'test-{}'.format(random.randint(0, 999999))
56+
url = reverse('project-list')
57+
data = {
58+
'data': {
59+
'type': 'artProjects',
60+
'attributes': {
61+
'topic': test_topic,
62+
'artist': test_artist
63+
}
64+
}
65+
}
66+
response = client.post(url, data=json.dumps(data), content_type='application/vnd.api+json')
67+
content = load_json(response.content)
68+
assert content['data']['id'] is not None
69+
assert content["data"]["type"] == "artProjects"
70+
assert content['data']['attributes']['topic'] == test_topic
71+
assert content['data']['attributes']['artist'] == test_artist

rest_framework_json_api/parsers.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
Parsers
33
"""
4+
import six
45
from rest_framework import parsers
56
from rest_framework.exceptions import ParseError
67

@@ -80,7 +81,11 @@ def parse(self, stream, media_type=None, parser_context=None):
8081

8182
# Check for inconsistencies
8283
resource_name = utils.get_resource_name(parser_context)
83-
if data.get('type') != resource_name and request.method in ('PUT', 'POST', 'PATCH'):
84+
if isinstance(resource_name, six.string_types):
85+
doesnt_match = data.get('type') != resource_name
86+
else:
87+
doesnt_match = data.get('type') not in resource_name
88+
if doesnt_match and request.method in ('PUT', 'POST', 'PATCH'):
8489
raise exceptions.Conflict(
8590
"The resource object's type ({data_type}) is not the type "
8691
"that constitute the collection represented by the endpoint ({resource_type}).".format(
@@ -92,7 +97,7 @@ def parse(self, stream, media_type=None, parser_context=None):
9297
raise ParseError("The resource identifier object must contain an 'id' member")
9398

9499
# Construct the return data
95-
parsed_data = {'id': data.get('id')}
100+
parsed_data = {'id': data.get('id'), 'type': data.get('type')}
96101
parsed_data.update(self.parse_attributes(data))
97102
parsed_data.update(self.parse_relationships(data))
98103
parsed_data.update(self.parse_metadata(result))

rest_framework_json_api/renderers.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,6 @@ def extract_included(fields, resource, resource_instance, included_resources):
311311
relation_type = utils.get_resource_type_from_serializer(serializer)
312312
relation_queryset = list(relation_instance)
313313

314-
# Get the serializer fields
315-
serializer_fields = utils.get_serializer_fields(serializer)
316314
if serializer_data:
317315
for position in range(len(serializer_data)):
318316
serializer_resource = serializer_data[position]
@@ -321,6 +319,7 @@ def extract_included(fields, resource, resource_instance, included_resources):
321319
relation_type or
322320
utils.get_resource_type_from_instance(nested_resource_instance)
323321
)
322+
serializer_fields = utils.get_serializer_fields(serializer.__class__(nested_resource_instance, context=serializer.context))
324323
included_data.append(
325324
JSONRenderer.build_json_resource_obj(
326325
serializer_fields, serializer_resource, nested_resource_instance, resource_type

rest_framework_json_api/utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,8 @@ def get_resource_type_from_manager(manager):
242242
def get_resource_type_from_serializer(serializer):
243243
json_api_meta = getattr(serializer, 'JSONAPIMeta', None)
244244
meta = getattr(serializer, 'Meta', None)
245+
if hasattr(serializer, 'polymorphic_serializers'):
246+
return [get_resource_type_from_serializer(s['serializer']) for s in serializer.polymorphic_serializers]
245247
if hasattr(json_api_meta, 'resource_name'):
246248
return json_api_meta.resource_name
247249
elif hasattr(meta, 'resource_name'):

0 commit comments

Comments
 (0)