diff --git a/.gitignore b/.gitignore index 5c4c8515..ff0958f1 100644 --- a/.gitignore +++ b/.gitignore @@ -34,6 +34,8 @@ pip-delete-this-directory.txt # Tox .tox/ +.cache/ +.python-version # VirtualEnv .venv/ diff --git a/AUTHORS b/AUTHORS index 419f2e90..dfa060c0 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,7 +1,9 @@ Adam Wróbel Christian Zosel Greg Aker +Jamie Bliss Jerel Unruh +Léo S. Matt Layman Oliver Sauder Yaniv Peer diff --git a/CHANGELOG.md b/CHANGELOG.md index 5c157642..b78f0110 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ v2.3.0 +* Added support for polymorphic models * When `JSON_API_FORMAT_KEYS` is False (the default) do not translate request attributes and relations to snake\_case format. This conversion was unexpected and there was no way to turn it off. diff --git a/docs/usage.md b/docs/usage.md index 55eb1996..f4faeea8 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -425,6 +425,59 @@ field_name_mapping = { ``` +### Working with polymorphic resources + +Polymorphic resources allow you to use specialized subclasses without requiring +special endpoints to expose the specialized versions. For example, if you had a +`Project` that could be either an `ArtProject` or a `ResearchProject`, you can +have both kinds at the same URL. + +DJA tests its polymorphic support against [django-polymorphic](https://django-polymorphic.readthedocs.io/en/stable/). +The polymorphic feature should also work with other popular libraries like +django-polymodels or django-typed-models. + +#### Writing polymorphic resources + +A polymorphic endpoint can be set up if associated with a polymorphic serializer. +A polymorphic serializer takes care of (de)serializing the correct instances types and can be defined like this: + +```python +class ProjectSerializer(serializers.PolymorphicModelSerializer): + polymorphic_serializers = [ArtProjectSerializer, ResearchProjectSerializer] + + class Meta: + model = models.Project +``` + +It must inherit from `serializers.PolymorphicModelSerializer` and define the `polymorphic_serializers` list. +This attribute defines the accepted resource types. + + +Polymorphic relations can also be handled with `relations.PolymorphicResourceRelatedField` like this: + +```python +class CompanySerializer(serializers.ModelSerializer): + current_project = relations.PolymorphicResourceRelatedField( + ProjectSerializer, queryset=models.Project.objects.all()) + future_projects = relations.PolymorphicResourceRelatedField( + ProjectSerializer, queryset=models.Project.objects.all(), many=True) + + class Meta: + model = models.Company +``` + +They must be explicitly declared with the `polymorphic_serializer` (first positional argument) correctly defined. +It must be a subclass of `serializers.PolymorphicModelSerializer`. + +
+ Note: + Polymorphic resources are not compatible with + + resource_name + + defined on the view. +
+ ### Meta You may add metadata to the rendered json in two different ways: `meta_fields` and `get_root_meta`. diff --git a/example/factories/__init__.py b/example/factories/__init__.py index de9a02fa..a7485500 100644 --- a/example/factories/__init__.py +++ b/example/factories/__init__.py @@ -2,7 +2,9 @@ import factory from faker import Factory as FakerFactory -from example.models import Blog, Author, AuthorBio, Entry, Comment, TaggedItem +from example.models import ( + Blog, Author, AuthorBio, Entry, Comment, TaggedItem, ArtProject, ResearchProject, Company +) faker = FakerFactory.create() faker.seed(983843) @@ -68,3 +70,35 @@ class Meta: content_object = factory.SubFactory(EntryFactory) tag = factory.LazyAttribute(lambda x: faker.word()) + + +class ArtProjectFactory(factory.django.DjangoModelFactory): + class Meta: + model = ArtProject + + topic = factory.LazyAttribute(lambda x: faker.catch_phrase()) + artist = factory.LazyAttribute(lambda x: faker.name()) + + +class ResearchProjectFactory(factory.django.DjangoModelFactory): + class Meta: + model = ResearchProject + + topic = factory.LazyAttribute(lambda x: faker.catch_phrase()) + supervisor = factory.LazyAttribute(lambda x: faker.name()) + + +class CompanyFactory(factory.django.DjangoModelFactory): + class Meta: + model = Company + + name = factory.LazyAttribute(lambda x: faker.company()) + current_project = factory.SubFactory(ArtProjectFactory) + + @factory.post_generation + def future_projects(self, create, extracted, **kwargs): + if not create: + return + if extracted: + for project in extracted: + self.future_projects.add(project) diff --git a/example/migrations/0003_polymorphics.py b/example/migrations/0003_polymorphics.py new file mode 100644 index 00000000..9020176b --- /dev/null +++ b/example/migrations/0003_polymorphics.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.11.1 on 2017-05-17 14:49 +from __future__ import unicode_literals + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('contenttypes', '0002_remove_content_type_name'), + ('example', '0002_taggeditem'), + ] + + operations = [ + migrations.CreateModel( + name='Company', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.CharField(max_length=100)), + ], + ), + migrations.CreateModel( + name='Project', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('topic', models.CharField(max_length=30)), + ], + options={ + 'abstract': False, + }, + ), + migrations.AlterField( + model_name='comment', + name='entry', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='comments', to='example.Entry'), + ), + migrations.CreateModel( + name='ArtProject', + fields=[ + ('project_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='example.Project')), + ('artist', models.CharField(max_length=30)), + ], + options={ + 'abstract': False, + }, + bases=('example.project',), + ), + migrations.CreateModel( + name='ResearchProject', + fields=[ + ('project_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='example.Project')), + ('supervisor', models.CharField(max_length=30)), + ], + options={ + 'abstract': False, + }, + bases=('example.project',), + ), + migrations.AddField( + model_name='project', + name='polymorphic_ctype', + field=models.ForeignKey(editable=False, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='polymorphic_example.project_set+', to='contenttypes.ContentType'), + ), + migrations.AddField( + model_name='company', + name='current_project', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='companies', to='example.Project'), + ), + migrations.AddField( + model_name='company', + name='future_projects', + field=models.ManyToManyField(to='example.Project'), + ), + ] diff --git a/example/models.py b/example/models.py index 6442b0e4..f7e8ac7d 100644 --- a/example/models.py +++ b/example/models.py @@ -6,6 +6,7 @@ from django.contrib.contenttypes.fields import GenericRelation from django.db import models from django.utils.encoding import python_2_unicode_compatible +from polymorphic.models import PolymorphicModel class BaseModel(models.Model): @@ -86,3 +87,25 @@ class Comment(BaseModel): def __str__(self): return self.body + + +class Project(PolymorphicModel): + topic = models.CharField(max_length=30) + + +class ArtProject(Project): + artist = models.CharField(max_length=30) + + +class ResearchProject(Project): + supervisor = models.CharField(max_length=30) + + +@python_2_unicode_compatible +class Company(models.Model): + name = models.CharField(max_length=100) + current_project = models.ForeignKey(Project, related_name='companies') + future_projects = models.ManyToManyField(Project) + + def __str__(self): + return self.name diff --git a/example/serializers.py b/example/serializers.py index 0dfc49b4..d0bc5a35 100644 --- a/example/serializers.py +++ b/example/serializers.py @@ -1,6 +1,12 @@ from datetime import datetime + +import rest_framework from rest_framework_json_api import serializers, relations -from example.models import Blog, Entry, Author, AuthorBio, Comment, TaggedItem +from packaging import version +from example.models import ( + Blog, Entry, Author, AuthorBio, Comment, TaggedItem, Project, ArtProject, ResearchProject, + Company, +) class TaggedItemSerializer(serializers.ModelSerializer): @@ -115,3 +121,40 @@ class Meta: model = Comment exclude = ('created_at', 'modified_at',) # fields = ('entry', 'body', 'author',) + + +class ArtProjectSerializer(serializers.ModelSerializer): + class Meta: + model = ArtProject + exclude = ('polymorphic_ctype',) + + +class ResearchProjectSerializer(serializers.ModelSerializer): + class Meta: + model = ResearchProject + exclude = ('polymorphic_ctype',) + + +class ProjectSerializer(serializers.PolymorphicModelSerializer): + polymorphic_serializers = [ArtProjectSerializer, ResearchProjectSerializer] + + class Meta: + model = Project + exclude = ('polymorphic_ctype',) + + +class CompanySerializer(serializers.ModelSerializer): + current_project = relations.PolymorphicResourceRelatedField( + ProjectSerializer, queryset=Project.objects.all()) + future_projects = relations.PolymorphicResourceRelatedField( + ProjectSerializer, queryset=Project.objects.all(), many=True) + + included_serializers = { + 'current_project': ProjectSerializer, + 'future_projects': ProjectSerializer, + } + + class Meta: + model = Company + if version.parse(rest_framework.VERSION) >= version.parse('3.3'): + fields = '__all__' diff --git a/example/settings/dev.py b/example/settings/dev.py index 3cc1d6e1..c5a1f742 100644 --- a/example/settings/dev.py +++ b/example/settings/dev.py @@ -23,6 +23,7 @@ 'django.contrib.auth', 'django.contrib.admin', 'rest_framework', + 'polymorphic', 'example', ] diff --git a/example/tests/conftest.py b/example/tests/conftest.py index acdc8543..9db8edc1 100644 --- a/example/tests/conftest.py +++ b/example/tests/conftest.py @@ -3,7 +3,7 @@ from example.factories import ( BlogFactory, AuthorFactory, AuthorBioFactory, EntryFactory, CommentFactory, - TaggedItemFactory + TaggedItemFactory, ArtProjectFactory, ResearchProjectFactory, CompanyFactory, ) register(BlogFactory) @@ -12,6 +12,9 @@ register(EntryFactory) register(CommentFactory) register(TaggedItemFactory) +register(ArtProjectFactory) +register(ResearchProjectFactory) +register(CompanyFactory) @pytest.fixture @@ -33,3 +36,14 @@ def multiple_entries(blog_factory, author_factory, entry_factory, comment_factor comment_factory(entry=entries[0]) comment_factory(entry=entries[1]) return entries + + +@pytest.fixture +def single_company(art_project_factory, research_project_factory, company_factory): + company = company_factory(future_projects=(research_project_factory(), art_project_factory())) + return company + + +@pytest.fixture +def single_art_project(art_project_factory): + return art_project_factory() diff --git a/example/tests/integration/test_polymorphism.py b/example/tests/integration/test_polymorphism.py new file mode 100644 index 00000000..8612319a --- /dev/null +++ b/example/tests/integration/test_polymorphism.py @@ -0,0 +1,152 @@ +import pytest +import random +import json +from django.core.urlresolvers import reverse + +from example.tests.utils import load_json + +pytestmark = pytest.mark.django_db + + +def test_polymorphism_on_detail(single_art_project, client): + response = client.get(reverse("project-detail", kwargs={'pk': single_art_project.pk})) + content = load_json(response.content) + assert content["data"]["type"] == "artProjects" + + +def test_polymorphism_on_detail_relations(single_company, client): + response = client.get(reverse("company-detail", kwargs={'pk': single_company.pk})) + content = load_json(response.content) + assert content["data"]["relationships"]["currentProject"]["data"]["type"] == "artProjects" + assert ( + set([rel["type"] for rel in content["data"]["relationships"]["futureProjects"]["data"]]) == + set(["researchProjects", "artProjects"]) + ) + + +def test_polymorphism_on_included_relations(single_company, client): + response = client.get(reverse("company-detail", kwargs={'pk': single_company.pk}) + + '?include=current_project,future_projects') + content = load_json(response.content) + assert content["data"]["relationships"]["currentProject"]["data"]["type"] == "artProjects" + assert ( + set([rel["type"] for rel in content["data"]["relationships"]["futureProjects"]["data"]]) == + set(["researchProjects", "artProjects"]) + ) + assert set([x.get('type') for x in content.get('included')]) == set([ + 'artProjects', 'artProjects', 'researchProjects']), 'Detail included types are incorrect' + # Ensure that the child fields are present. + assert content.get('included')[0].get('attributes').get('artist') is not None + assert content.get('included')[1].get('attributes').get('artist') is not None + assert content.get('included')[2].get('attributes').get('supervisor') is not None + + +def test_polymorphism_on_polymorphic_model_detail_patch(single_art_project, client): + url = reverse("project-detail", kwargs={'pk': single_art_project.pk}) + response = client.get(url) + content = load_json(response.content) + test_topic = 'test-{}'.format(random.randint(0, 999999)) + test_artist = 'test-{}'.format(random.randint(0, 999999)) + content['data']['attributes']['topic'] = test_topic + content['data']['attributes']['artist'] = test_artist + response = client.patch(url, data=json.dumps(content), content_type='application/vnd.api+json') + new_content = load_json(response.content) + assert new_content['data']['type'] == "artProjects" + assert new_content['data']['attributes']['topic'] == test_topic + assert new_content['data']['attributes']['artist'] == test_artist + + +def test_polymorphism_on_polymorphic_model_list_post(client): + test_topic = 'New test topic {}'.format(random.randint(0, 999999)) + test_artist = 'test-{}'.format(random.randint(0, 999999)) + url = reverse('project-list') + data = { + 'data': { + 'type': 'artProjects', + 'attributes': { + 'topic': test_topic, + 'artist': test_artist + } + } + } + response = client.post(url, data=json.dumps(data), content_type='application/vnd.api+json') + content = load_json(response.content) + assert content['data']['id'] is not None + assert content['data']['type'] == "artProjects" + assert content['data']['attributes']['topic'] == test_topic + assert content['data']['attributes']['artist'] == test_artist + + +def test_invalid_type_on_polymorphic_model(client): + test_topic = 'New test topic {}'.format(random.randint(0, 999999)) + test_artist = 'test-{}'.format(random.randint(0, 999999)) + url = reverse('project-list') + data = { + 'data': { + 'type': 'invalidProjects', + 'attributes': { + 'topic': test_topic, + 'artist': test_artist + } + } + } + response = client.post(url, data=json.dumps(data), content_type='application/vnd.api+json') + assert response.status_code == 409 + content = load_json(response.content) + assert len(content["errors"]) is 1 + assert content["errors"][0]["status"] == "409" + try: + assert content["errors"][0]["detail"] == \ + "The resource object's type (invalidProjects) is not the type that constitute the " \ + "collection represented by the endpoint (one of [researchProjects, artProjects])." + except AssertionError: + # Available type list order isn't enforced + assert content["errors"][0]["detail"] == \ + "The resource object's type (invalidProjects) is not the type that constitute the " \ + "collection represented by the endpoint (one of [artProjects, researchProjects])." + + +def test_polymorphism_relations_update(single_company, research_project_factory, client): + response = client.get(reverse("company-detail", kwargs={'pk': single_company.pk})) + content = load_json(response.content) + assert content["data"]["relationships"]["currentProject"]["data"]["type"] == "artProjects" + + research_project = research_project_factory() + content["data"]["relationships"]["currentProject"]["data"] = { + "type": "researchProjects", + "id": research_project.pk + } + response = client.put(reverse("company-detail", kwargs={'pk': single_company.pk}), + data=json.dumps(content), content_type='application/vnd.api+json') + assert response.status_code == 200 + content = load_json(response.content) + assert content["data"]["relationships"]["currentProject"]["data"]["type"] == "researchProjects" + assert int(content["data"]["relationships"]["currentProject"]["data"]["id"]) == \ + research_project.pk + + +def test_invalid_type_on_polymorphic_relation(single_company, research_project_factory, client): + response = client.get(reverse("company-detail", kwargs={'pk': single_company.pk})) + content = load_json(response.content) + assert content["data"]["relationships"]["currentProject"]["data"]["type"] == "artProjects" + + research_project = research_project_factory() + content["data"]["relationships"]["currentProject"]["data"] = { + "type": "invalidProjects", + "id": research_project.pk + } + response = client.put(reverse("company-detail", kwargs={'pk': single_company.pk}), + data=json.dumps(content), content_type='application/vnd.api+json') + assert response.status_code == 409 + content = load_json(response.content) + assert len(content["errors"]) is 1 + assert content["errors"][0]["status"] == "409" + try: + assert content["errors"][0]["detail"] == \ + "Incorrect relation type. Expected one of [researchProjects, artProjects], " \ + "received invalidProjects." + except AssertionError: + # Available type list order isn't enforced + assert content["errors"][0]["detail"] == \ + "Incorrect relation type. Expected one of [artProjects, researchProjects], " \ + "received invalidProjects." diff --git a/example/urls.py b/example/urls.py index f48135c7..4443960f 100644 --- a/example/urls.py +++ b/example/urls.py @@ -1,7 +1,8 @@ from django.conf.urls import include, url from rest_framework import routers -from example.views import BlogViewSet, EntryViewSet, AuthorViewSet, CommentViewSet +from example.views import ( + BlogViewSet, EntryViewSet, AuthorViewSet, CommentViewSet, CompanyViewset, ProjectViewset) router = routers.DefaultRouter(trailing_slash=False) @@ -9,6 +10,8 @@ router.register(r'entries', EntryViewSet) router.register(r'authors', AuthorViewSet) router.register(r'comments', CommentViewSet) +router.register(r'companies', CompanyViewset) +router.register(r'projects', ProjectViewset) urlpatterns = [ url(r'^', include(router.urls)), diff --git a/example/urls_test.py b/example/urls_test.py index 2d569c16..e6555ec8 100644 --- a/example/urls_test.py +++ b/example/urls_test.py @@ -3,7 +3,8 @@ from example.views import ( BlogViewSet, EntryViewSet, AuthorViewSet, CommentViewSet, EntryRelationshipView, - BlogRelationshipView, CommentRelationshipView, AuthorRelationshipView + BlogRelationshipView, CommentRelationshipView, AuthorRelationshipView, + CompanyViewset, ProjectViewset, ) from .api.resources.identity import Identity, GenericIdentity @@ -13,6 +14,8 @@ router.register(r'entries', EntryViewSet) router.register(r'authors', AuthorViewSet) router.register(r'comments', CommentViewSet) +router.register(r'companies', CompanyViewset) +router.register(r'projects', ProjectViewset) # for the old tests router.register(r'identities', Identity) diff --git a/example/views.py b/example/views.py index 67cb7f67..4f28d8c7 100644 --- a/example/views.py +++ b/example/views.py @@ -5,9 +5,11 @@ import rest_framework_json_api.parsers import rest_framework_json_api.renderers from rest_framework_json_api.views import ModelViewSet, RelationshipView -from example.models import Blog, Entry, Author, Comment +from example.models import Blog, Entry, Author, Comment, Company, Project from example.serializers import ( - BlogSerializer, EntrySerializer, AuthorSerializer, CommentSerializer) + BlogSerializer, EntrySerializer, AuthorSerializer, CommentSerializer, CompanySerializer, + ProjectSerializer, +) from rest_framework_json_api.utils import format_drf_errors @@ -71,6 +73,16 @@ class CommentViewSet(ModelViewSet): serializer_class = CommentSerializer +class CompanyViewset(ModelViewSet): + queryset = Company.objects.all() + serializer_class = CompanySerializer + + +class ProjectViewset(ModelViewSet): + queryset = Project.objects.all() + serializer_class = ProjectSerializer + + class EntryRelationshipView(RelationshipView): queryset = Entry.objects.all() diff --git a/requirements-development.txt b/requirements-development.txt index 30ae754a..65fdf9e1 100644 --- a/requirements-development.txt +++ b/requirements-development.txt @@ -6,5 +6,6 @@ Faker recommonmark Sphinx sphinx_rtd_theme +django-polymorphic tox mock diff --git a/rest_framework_json_api/parsers.py b/rest_framework_json_api/parsers.py index 2f74f495..68ec45d2 100644 --- a/rest_framework_json_api/parsers.py +++ b/rest_framework_json_api/parsers.py @@ -1,11 +1,11 @@ """ Parsers """ +from django.conf import settings +from django.utils import six from rest_framework import parsers from rest_framework.exceptions import ParseError -from django.conf import settings - from . import utils, renderers, exceptions @@ -106,21 +106,31 @@ def parse(self, stream, media_type=None, parser_context=None): request = parser_context.get('request') # Check for inconsistencies - resource_name = utils.get_resource_name(parser_context) - if data.get('type') != resource_name and request.method in ('PUT', 'POST', 'PATCH'): - raise exceptions.Conflict( - "The resource object's type ({data_type}) is not the type " - "that constitute the collection represented by the endpoint " - "({resource_type}).".format( - data_type=data.get('type'), - resource_type=resource_name - ) - ) + if request.method in ('PUT', 'POST', 'PATCH'): + resource_name = utils.get_resource_name( + parser_context, expand_polymorphic_types=True) + if isinstance(resource_name, six.string_types): + if data.get('type') != resource_name: + raise exceptions.Conflict( + "The resource object's type ({data_type}) is not the type that " + "constitute the collection represented by the endpoint " + "({resource_type}).".format( + data_type=data.get('type'), + resource_type=resource_name)) + else: + if data.get('type') not in resource_name: + raise exceptions.Conflict( + "The resource object's type ({data_type}) is not the type that " + "constitute the collection represented by the endpoint " + "(one of [{resource_types}]).".format( + data_type=data.get('type'), + resource_types=", ".join(resource_name))) if not data.get('id') and request.method in ('PATCH', 'PUT'): raise ParseError("The resource identifier object must contain an 'id' member") # Construct the return data parsed_data = {'id': data.get('id')} if 'id' in data else {} + parsed_data['type'] = data.get('type') parsed_data.update(self.parse_attributes(data)) parsed_data.update(self.parse_relationships(data)) parsed_data.update(self.parse_metadata(result)) diff --git a/rest_framework_json_api/relations.py b/rest_framework_json_api/relations.py index b488b986..a697815e 100644 --- a/rest_framework_json_api/relations.py +++ b/rest_framework_json_api/relations.py @@ -21,6 +21,7 @@ class ResourceRelatedField(PrimaryKeyRelatedField): + _skip_polymorphic_optimization = True self_link_view_name = None related_link_view_name = None related_link_lookup_field = 'pk' @@ -168,7 +169,7 @@ def to_representation(self, value): pk = value.pk resource_type = self.get_resource_type_from_included_serializer() - if resource_type is None: + if resource_type is None or not self._skip_polymorphic_optimization: resource_type = get_resource_type_from_instance(value) return OrderedDict([('type', resource_type), ('id', str(pk))]) @@ -224,6 +225,48 @@ def get_choices(self, cutoff=None): ]) +class PolymorphicResourceRelatedField(ResourceRelatedField): + """ + Inform DRF that the relation must be considered polymorphic. + Takes a `polymorphic_serializer` as the first positional argument to + retrieve then validate the accepted types set. + """ + + _skip_polymorphic_optimization = False + default_error_messages = dict(ResourceRelatedField.default_error_messages, **{ + 'incorrect_relation_type': _('Incorrect relation type. Expected one of [{relation_type}], ' + 'received {received_type}.'), + }) + + def __init__(self, polymorphic_serializer, *args, **kwargs): + self.polymorphic_serializer = polymorphic_serializer + super(PolymorphicResourceRelatedField, self).__init__(*args, **kwargs) + + def to_internal_value(self, data): + if isinstance(data, six.text_type): + try: + data = json.loads(data) + except ValueError: + # show a useful error if they send a `pk` instead of resource object + self.fail('incorrect_type', data_type=type(data).__name__) + if not isinstance(data, dict): + self.fail('incorrect_type', data_type=type(data).__name__) + + if 'type' not in data: + self.fail('missing_type') + + if 'id' not in data: + self.fail('missing_id') + + expected_relation_types = self.polymorphic_serializer.get_polymorphic_types() + + if data['type'] not in expected_relation_types: + self.conflict('incorrect_relation_type', relation_type=", ".join( + expected_relation_types), received_type=data['type']) + + return super(ResourceRelatedField, self).to_internal_value(data['id']) + + class SerializerMethodResourceRelatedField(ResourceRelatedField): """ Allows us to use serializer method RelatedFields diff --git a/rest_framework_json_api/renderers.py b/rest_framework_json_api/renderers.py index 61a9238c..4c325525 100644 --- a/rest_framework_json_api/renderers.py +++ b/rest_framework_json_api/renderers.py @@ -12,7 +12,7 @@ from rest_framework.serializers import BaseSerializer, Serializer, ListSerializer from rest_framework.settings import api_settings -from . import utils +from rest_framework_json_api import utils class JSONRenderer(renderers.JSONRenderer): @@ -344,8 +344,6 @@ def extract_included(cls, fields, resource, resource_instance, included_resource relation_type = utils.get_resource_type_from_serializer(serializer) relation_queryset = list(relation_instance) - # Get the serializer fields - serializer_fields = utils.get_serializer_fields(serializer) if serializer_data: for position in range(len(serializer_data)): serializer_resource = serializer_data[position] @@ -354,12 +352,18 @@ def extract_included(cls, fields, resource, resource_instance, included_resource relation_type or utils.get_resource_type_from_instance(nested_resource_instance) ) + serializer_fields = utils.get_serializer_fields( + serializer.__class__( + nested_resource_instance, context=serializer.context + ) + ) included_data.append( cls.build_json_resource_obj( serializer_fields, serializer_resource, nested_resource_instance, - resource_type + resource_type, + getattr(serializer, '_poly_force_type_resolution', False) ) ) included_data.extend( @@ -381,7 +385,8 @@ def extract_included(cls, fields, resource, resource_instance, included_resource included_data.append( cls.build_json_resource_obj( serializer_fields, serializer_data, - relation_instance, relation_type) + relation_instance, relation_type, + getattr(field, '_poly_force_type_resolution', False)) ) included_data.extend( cls.extract_included( @@ -423,7 +428,11 @@ def extract_root_meta(cls, serializer, resource): return data @classmethod - def build_json_resource_obj(cls, fields, resource, resource_instance, resource_name): + def build_json_resource_obj(cls, fields, resource, resource_instance, resource_name, + force_type_resolution=False): + # Determine type from the instance if the underlying model is polymorphic + if force_type_resolution: + resource_name = utils.get_resource_type_from_instance(resource_instance) resource_data = [ ('type', resource_name), ('id', encoding.force_text(resource_instance.pk) if resource_instance else None), @@ -506,6 +515,9 @@ def render(self, data, accepted_media_type=None, renderer_context=None): # Get the serializer fields fields = utils.get_serializer_fields(serializer) + # Determine if resource name must be resolved on each instance (polymorphic serializer) + force_type_resolution = getattr(serializer, '_poly_force_type_resolution', False) + # Extract root meta for any type of serializer json_api_meta.update(self.extract_root_meta(serializer, serializer_data)) @@ -517,7 +529,7 @@ def render(self, data, accepted_media_type=None, renderer_context=None): resource_instance = serializer.instance[position] # Get current instance json_resource_obj = self.build_json_resource_obj( - fields, resource, resource_instance, resource_name + fields, resource, resource_instance, resource_name, force_type_resolution ) meta = self.extract_meta(serializer, resource) if meta: @@ -532,7 +544,7 @@ def render(self, data, accepted_media_type=None, renderer_context=None): else: resource_instance = serializer.instance json_api_data = self.build_json_resource_obj( - fields, serializer_data, resource_instance, resource_name + fields, serializer_data, resource_instance, resource_name, force_type_resolution ) meta = self.extract_meta(serializer, serializer_data) diff --git a/rest_framework_json_api/serializers.py b/rest_framework_json_api/serializers.py index b2a69c9d..d4808bec 100644 --- a/rest_framework_json_api/serializers.py +++ b/rest_framework_json_api/serializers.py @@ -1,9 +1,13 @@ import inflection + +from django.db.models.query import QuerySet from django.utils.translation import ugettext_lazy as _ +from django.utils import six from rest_framework.exceptions import ParseError from rest_framework.serializers import * # noqa: F403 from rest_framework_json_api.relations import ResourceRelatedField +from rest_framework_json_api.exceptions import Conflict from rest_framework_json_api.utils import ( get_resource_type_from_model, get_resource_type_from_instance, get_resource_type_from_serializer, get_included_serializers, get_included_resources) @@ -163,3 +167,140 @@ def get_field_names(self, declared_fields, info): declared[field_name] = field fields = super(ModelSerializer, self).get_field_names(declared, info) return list(fields) + list(getattr(self.Meta, 'meta_fields', list())) + + +class PolymorphicSerializerMetaclass(SerializerMetaclass): + """ + This metaclass ensures that the `polymorphic_serializers` is correctly defined on a + `PolymorphicSerializer` class and make a cache of model/serializer/type mappings. + """ + + def __new__(cls, name, bases, attrs): + new_class = super(PolymorphicSerializerMetaclass, cls).__new__(cls, name, bases, attrs) + + # Ensure initialization is only performed for subclasses of PolymorphicModelSerializer + # (excluding PolymorphicModelSerializer class itself). + parents = [b for b in bases if isinstance(b, PolymorphicSerializerMetaclass)] + if not parents: + return new_class + + polymorphic_serializers = getattr(new_class, 'polymorphic_serializers', None) + if not polymorphic_serializers: + raise NotImplementedError( + "A PolymorphicModelSerializer must define a `polymorphic_serializers` attribute.") + serializer_to_model = { + serializer: serializer.Meta.model for serializer in polymorphic_serializers} + model_to_serializer = { + serializer.Meta.model: serializer for serializer in polymorphic_serializers} + type_to_serializer = { + get_resource_type_from_serializer(serializer): serializer for + serializer in polymorphic_serializers} + new_class._poly_serializer_model_map = serializer_to_model + new_class._poly_model_serializer_map = model_to_serializer + new_class._poly_type_serializer_map = type_to_serializer + new_class._poly_force_type_resolution = True + + # Flag each linked polymorphic serializer to force type resolution based on instance + for serializer in polymorphic_serializers: + serializer._poly_force_type_resolution = True + + return new_class + + +@six.add_metaclass(PolymorphicSerializerMetaclass) +class PolymorphicModelSerializer(ModelSerializer): + """ + A serializer for polymorphic models. + Useful for "lazy" parent models. Leaves should be represented with a regular serializer. + """ + def get_fields(self): + """ + Return an exhaustive list of the polymorphic serializer fields. + """ + if self.instance is not None: + if not isinstance(self.instance, QuerySet): + serializer_class = self.get_polymorphic_serializer_for_instance(self.instance) + return serializer_class(self.instance, context=self.context).get_fields() + else: + raise Exception("Cannot get fields from a polymorphic serializer given a queryset") + return super(PolymorphicModelSerializer, self).get_fields() + + @classmethod + def get_polymorphic_serializer_for_instance(cls, instance): + """ + Return the polymorphic serializer associated with the given instance/model. + Raise `NotImplementedError` if no serializer is found for the given model. This usually + means that a serializer is missing in the class's `polymorphic_serializers` attribute. + """ + try: + return cls._poly_model_serializer_map[instance._meta.model] + except KeyError: + raise NotImplementedError( + "No polymorphic serializer has been found for model {}".format( + instance._meta.model.__name__)) + + @classmethod + def get_polymorphic_model_for_serializer(cls, serializer): + """ + Return the polymorphic model associated with the given serializer. + Raise `NotImplementedError` if no model is found for the given serializer. This usually + means that a serializer is missing in the class's `polymorphic_serializers` attribute. + """ + try: + return cls._poly_serializer_model_map[serializer] + except KeyError: + raise NotImplementedError( + "No polymorphic model has been found for serializer {}".format(serializer.__name__)) + + @classmethod + def get_polymorphic_serializer_for_type(cls, obj_type): + """ + Return the polymorphic serializer associated with the given type. + Raise `NotImplementedError` if no serializer is found for the given type. This usually + means that a serializer is missing in the class's `polymorphic_serializers` attribute. + """ + try: + return cls._poly_type_serializer_map[obj_type] + except KeyError: + raise NotImplementedError( + "No polymorphic serializer has been found for type {}".format(obj_type)) + + @classmethod + def get_polymorphic_model_for_type(cls, obj_type): + """ + Return the polymorphic model associated with the given type. + Raise `NotImplementedError` if no model is found for the given type. This usually + means that a serializer is missing in the class's `polymorphic_serializers` attribute. + """ + return cls.get_polymorphic_model_for_serializer( + cls.get_polymorphic_serializer_for_type(obj_type)) + + @classmethod + def get_polymorphic_types(cls): + """ + Return the list of accepted types. + """ + return cls._poly_type_serializer_map.keys() + + def to_representation(self, instance): + """ + Retrieve the appropriate polymorphic serializer and use this to handle representation. + """ + serializer_class = self.get_polymorphic_serializer_for_instance(instance) + return serializer_class(instance, context=self.context).to_representation(instance) + + def to_internal_value(self, data): + """ + Ensure that the given type is one of the expected polymorphic types, then retrieve the + appropriate polymorphic serializer and use this to handle internal value. + """ + received_type = data.get('type') + expected_types = self.get_polymorphic_types() + if received_type not in expected_types: + raise Conflict( + 'Incorrect relation type. Expected on of [{expected_types}], ' + 'received {received_type}.'.format( + expected_types=', '.join(expected_types), received_type=received_type)) + serializer_class = self.get_polymorphic_serializer_for_type(received_type) + self.__class__ = serializer_class + return serializer_class(data, context=self.context).to_internal_value(data) diff --git a/rest_framework_json_api/utils.py b/rest_framework_json_api/utils.py index f2eefa2b..a8cf40e2 100644 --- a/rest_framework_json_api/utils.py +++ b/rest_framework_json_api/utils.py @@ -29,32 +29,35 @@ if django.VERSION >= (1, 9): from django.db.models.fields.related_descriptors import ( - ManyToManyDescriptor, ReverseManyToOneDescriptor + ManyToManyDescriptor, ReverseManyToOneDescriptor # noqa: F401 ) ReverseManyRelatedObjectsDescriptor = object() else: - from django.db.models.fields.related import ManyRelatedObjectsDescriptor as ManyToManyDescriptor + from django.db.models.fields.related import ( # noqa: F401 + ManyRelatedObjectsDescriptor as ManyToManyDescriptor + ) from django.db.models.fields.related import ( ForeignRelatedObjectsDescriptor as ReverseManyToOneDescriptor ) - from django.db.models.fields.related import ReverseManyRelatedObjectsDescriptor + from django.db.models.fields.related import ReverseManyRelatedObjectsDescriptor # noqa: F401 # Generic relation descriptor from django.contrib.contenttypes. if 'django.contrib.contenttypes' not in settings.INSTALLED_APPS: # pragma: no cover # Target application does not use contenttypes. Importing would cause errors. ReverseGenericManyToOneDescriptor = object() elif django.VERSION >= (1, 9): - from django.contrib.contenttypes.fields import ReverseGenericManyToOneDescriptor + from django.contrib.contenttypes.fields import ReverseGenericManyToOneDescriptor # noqa: F401 else: - from django.contrib.contenttypes.fields import ( - ReverseGenericRelatedObjectsDescriptor as ReverseGenericManyToOneDescriptor + from django.contrib.contenttypes.fields import ( # noqa: F401 + ReverseGenericRelatedObjectsDescriptor as ReverseGenericManyToOneDescriptor # noqa: F401 ) -def get_resource_name(context): +def get_resource_name(context, expand_polymorphic_types=False): """ Return the name of a resource. """ + from rest_framework_json_api.serializers import PolymorphicModelSerializer view = context.get('view') # Sanity check to make sure we have a view. @@ -76,7 +79,10 @@ def get_resource_name(context): except AttributeError: try: serializer = view.get_serializer_class() - return get_resource_type_from_serializer(serializer) + if expand_polymorphic_types and issubclass(serializer, PolymorphicModelSerializer): + return serializer.get_polymorphic_types() + else: + return get_resource_type_from_serializer(serializer) except AttributeError: try: resource_name = get_resource_type_from_model(view.model) @@ -245,6 +251,11 @@ def get_related_resource_type(relation): relation_model = parent_model_relation.rel.model else: relation_model = parent_model_relation.field.related_model + elif hasattr(parent_model_relation, 'field'): + try: + relation_model = parent_model_relation.field.remote_field.model + except AttributeError: + relation_model = parent_model_relation.field.related.model else: return get_related_resource_type(parent_model_relation) diff --git a/setup.py b/setup.py index 75f05064..016d3df0 100755 --- a/setup.py +++ b/setup.py @@ -107,6 +107,8 @@ def get_package_data(package): 'pytest-factoryboy', 'pytest-django', 'pytest>=2.8,<3', + 'django-polymorphic', + 'packaging', ] + mock, zip_safe=False, ) diff --git a/tox.ini b/tox.ini index 7db9e4cc..3b32b700 100644 --- a/tox.ini +++ b/tox.ini @@ -1,18 +1,22 @@ [tox] envlist = - py{27,33,34,35}-django18-drf{31,32,33,34}, - py{27,34,35}-django19-drf{33,34}, - py{27,34,35}-django110-drf{34}, + py{27,33,34,35,36}-django18-drf{31,32,33,34}, + py{27,34,35,36}-django19-drf{33,34}, + py{27,34,35,36}-django110-drf34, + py{27,34,35,36}-django111-drf{34,35,36}, [testenv] deps = django18: Django>=1.8,<1.9 django19: Django>=1.9,<1.10 django110: Django>=1.10,<1.11 + django111: Django>=1.11,<1.12 drf31: djangorestframework>=3.1,<3.2 drf32: djangorestframework>=3.2,<3.3 drf33: djangorestframework>=3.3,<3.4 drf34: djangorestframework>=3.4,<3.5 + drf35: djangorestframework>=3.5,<3.6 + drf36: djangorestframework>=3.6,<3.7 setenv = PYTHONPATH = {toxinidir}