From d9e11964b94c24cd536dff7d0ce75974a01c7b49 Mon Sep 17 00:00:00 2001 From: Tom Date: Fri, 27 Nov 2015 19:28:53 +0000 Subject: [PATCH] Issue #155 - support serializers with custom primary keys --- example/tests/test_serializers.py | 1 - example/views.py | 1 - rest_framework_json_api/renderers.py | 4 +-- rest_framework_json_api/serializers.py | 7 +++-- rest_framework_json_api/utils.py | 37 +++++++++++++++++++++----- 5 files changed, 38 insertions(+), 12 deletions(-) diff --git a/example/tests/test_serializers.py b/example/tests/test_serializers.py index 6712ec7e..de8031bf 100644 --- a/example/tests/test_serializers.py +++ b/example/tests/test_serializers.py @@ -70,4 +70,3 @@ def test_deserialize_many(self): self.assertTrue(serializer.is_valid(), msg=serializer.errors) print(serializer.data) - diff --git a/example/views.py b/example/views.py index 59ca1a05..6a3fb505 100644 --- a/example/views.py +++ b/example/views.py @@ -41,4 +41,3 @@ class CommentRelationshipView(RelationshipView): class AuthorRelationshipView(RelationshipView): queryset = Author.objects.all() self_link_view_name = 'author-relationships' - diff --git a/rest_framework_json_api/renderers.py b/rest_framework_json_api/renderers.py index 16f6fc13..4ebb026e 100644 --- a/rest_framework_json_api/renderers.py +++ b/rest_framework_json_api/renderers.py @@ -105,8 +105,8 @@ def render(self, data, accepted_media_type=None, renderer_context=None): for position in range(len(serializer_data)): resource = serializer_data[position] # Get current resource resource_instance = resource_serializer.instance[position] # Get current instance - json_api_data.append( - utils.build_json_resource_obj(fields, resource, resource_instance, resource_name)) + resource_obj = utils.build_json_resource_obj(fields, resource, resource_instance, resource_name) + json_api_data.append(resource_obj) included = utils.extract_included(fields, resource, resource_instance, included_resources) if included: json_api_included.extend(included) diff --git a/rest_framework_json_api/serializers.py b/rest_framework_json_api/serializers.py index 8fd78292..5e08e4f4 100644 --- a/rest_framework_json_api/serializers.py +++ b/rest_framework_json_api/serializers.py @@ -15,9 +15,11 @@ class ResourceIdentifierObjectSerializer(BaseSerializer): } model_class = None + pk_field = 'pk' def __init__(self, *args, **kwargs): self.model_class = kwargs.pop('model_class', self.model_class) + self.pk_field = kwargs.pop('pk_field', self.pk_field) if 'instance' not in kwargs and not self.model_class: raise RuntimeError('ResourceIdentifierObjectsSerializer must be initialized with a model class.') super(ResourceIdentifierObjectSerializer, self).__init__(*args, **kwargs) @@ -25,15 +27,16 @@ def __init__(self, *args, **kwargs): def to_representation(self, instance): return { 'type': format_relation_name(get_resource_type_from_instance(instance)), - 'id': str(instance.pk) + 'id': str(getattr(instance, self.pk_field)) } def to_internal_value(self, data): if data['type'] != format_relation_name(self.model_class.__name__): self.fail('incorrect_model_type', model_type=self.model_class, received_type=data['type']) pk = data['id'] + lookup = {self.pk_field: pk} try: - return self.model_class.objects.get(pk=pk) + return self.model_class.objects.get(**lookup) except ObjectDoesNotExist: self.fail('does_not_exist', pk_value=pk) except (TypeError, ValueError): diff --git a/rest_framework_json_api/utils.py b/rest_framework_json_api/utils.py index 84d2511d..35dd45e6 100644 --- a/rest_framework_json_api/utils.py +++ b/rest_framework_json_api/utils.py @@ -149,9 +149,17 @@ def format_relation_name(value, format_type=None): def build_json_resource_obj(fields, resource, resource_instance, resource_name): + if resource_instance is None: + pk = None + else: + # Check if the primary key exists in the resource by getting the primary keys attribute name. + pk_attr = resource_instance._meta.pk.name + pk = resource[pk_attr] if pk_attr in resource else resource_instance.pk + pk = encoding.force_text(pk) + resource_data = [ ('type', resource_name), - ('id', encoding.force_text(resource_instance.pk) if resource_instance else None), + ('id', pk), ('attributes', extract_attributes(fields, resource)), ] relationships = extract_relationships(fields, resource, resource_instance) @@ -286,6 +294,15 @@ def extract_relationships(fields, resource, resource_instance): else: continue + # Take a model and return it's primary key, calling the primary key fields 'get_attribute' + # function or the models .pk property. + def get_pk(obj): + pk_attr = obj._meta.pk.name + if hasattr(field, 'fields') and pk_attr in field.fields: + return field.fields[pk_attr].get_attribute(obj) + else: + return obj.pk + relation_type = get_related_resource_type(field) if isinstance(field, HyperlinkedIdentityField): @@ -298,7 +315,7 @@ def extract_relationships(fields, resource, resource_instance): for related_object in relation_queryset: relation_data.append( - OrderedDict([('type', relation_type), ('id', encoding.force_text(related_object.pk))]) + OrderedDict([('type', relation_type), ('id', encoding.force_text(get_pk(related_object)))]) ) data.update({field_name: { @@ -326,7 +343,7 @@ def extract_relationships(fields, resource, resource_instance): continue if isinstance(field, (PrimaryKeyRelatedField, HyperlinkedRelatedField)): - relation_id = relation_instance_or_manager.pk if resource.get(field_name) else None + relation_id = get_pk(relation_instance_or_manager) if resource.get(field_name) else None relation_data = { 'data': ( @@ -369,7 +386,7 @@ def extract_relationships(fields, resource, resource_instance): related_object_type = get_instance_or_manager_resource_type(related_object) relation_data.append(OrderedDict([ ('type', related_object_type), - ('id', encoding.force_text(related_object.pk)) + ('id', encoding.force_text(get_pk(related_object))) ])) data.update({ field_name: { @@ -389,10 +406,18 @@ def extract_relationships(fields, resource, resource_instance): if isinstance(serializer_data, list): for position in range(len(serializer_data)): nested_resource_instance = resource_instance_queryset[position] + nested_resource_data = serializer_data[position] nested_resource_instance_type = get_resource_type_from_instance(nested_resource_instance) + + instance_pk_name = nested_resource_instance._meta.pk.name + if instance_pk_name in nested_resource_data: + pk = nested_resource_data[instance_pk_name] + else: + pk = nested_resource_instance.pk + relation_data.append(OrderedDict([ ('type', nested_resource_instance_type), - ('id', encoding.force_text(nested_resource_instance.pk)) + ('id', encoding.force_text(pk)) ])) data.update({field_name: {'data': relation_data}}) @@ -407,7 +432,7 @@ def extract_relationships(fields, resource, resource_instance): 'data': ( OrderedDict([ ('type', relation_type), - ('id', encoding.force_text(relation_instance_or_manager.pk)) + ('id', get_pk(relation_instance_or_manager)) ]) if resource.get(field_name) else None) } })