Skip to content

Commit

Permalink
add support for GenericRelation field
Browse files Browse the repository at this point in the history
  • Loading branch information
roman-certn committed Jan 11, 2024
1 parent 14040ca commit a13c2d0
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 3 deletions.
35 changes: 35 additions & 0 deletions django_readers/qs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from django.contrib.contenttypes.fields import ReverseGenericManyToOneDescriptor
from django.db.models import Prefetch, QuerySet
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields.related_descriptors import (
Expand Down Expand Up @@ -167,6 +168,32 @@ def prefetch_reverse_relationship(
)


def prefetch_reverse_generic_relationship(
name, related_field, related_queryset, prepare_related_queryset=noop, to_attr=None
):
"""
Efficiently prefetch a reverse generic relationship: one where the field on the "parent"
queryset is a `GenericRelation` field. We need to include this field in the query.
"""
return pipe(
include_fields(name),
prefetch_related(
Prefetch(
name,
pipe(
include_fields(
"pk",
related_field.content_type_field_name,
related_field.object_id_field_name,
),
prepare_related_queryset,
)(related_queryset),
to_attr,
)
),
)


def prefetch_many_to_many_relationship(
name, related_queryset, prepare_related_queryset=noop, to_attr=None
):
Expand Down Expand Up @@ -246,5 +273,13 @@ def prepare(queryset):
prepare_related_queryset,
to_attr,
)(queryset)
if type(related_descriptor) is ReverseGenericManyToOneDescriptor:
return prefetch_reverse_generic_relationship(
name,
related_descriptor.rel.field,
related_descriptor.field.related_model.objects.all(),
prepare_related_queryset,
to_attr,
)(queryset)

return prepare
14 changes: 14 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
from django.contrib.contenttypes.fields import GenericRelation
from django.db import models


class LogEntry(models.Model):
content_type = models.ForeignKey(
to="contenttypes.ContentType",
on_delete=models.CASCADE,
related_name="+",
)
object_pk = models.CharField(max_length=255)
event = models.CharField(max_length=100)


class Group(models.Model):
name = models.CharField(max_length=100)

Expand All @@ -15,6 +26,9 @@ class Widget(models.Model):
value = models.PositiveIntegerField(default=0)
other = models.CharField(max_length=100, null=True)
owner = models.ForeignKey(Owner, null=True, on_delete=models.SET_NULL)
logs = GenericRelation(
LogEntry, content_type_field="content_type", object_id_field="object_pk"
)


class Thing(models.Model):
Expand Down
28 changes: 25 additions & 3 deletions tests/test_rest_framework.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ImproperlyConfigured
from django.test import TestCase
from django_readers import pairs, qs
Expand All @@ -10,7 +11,7 @@
from rest_framework import serializers
from rest_framework.generics import ListAPIView, RetrieveAPIView
from rest_framework.test import APIRequestFactory
from tests.models import Category, Group, Owner, Widget
from tests.models import Category, Group, LogEntry, Owner, Widget
from textwrap import dedent


Expand All @@ -28,6 +29,7 @@ class WidgetListView(SpecMixin, ListAPIView):
},
]
},
{"logs": ["event"]},
]


Expand All @@ -53,17 +55,32 @@ class CategoryDetailView(SpecMixin, RetrieveAPIView):

class RESTFrameworkTestCase(TestCase):
def test_list(self):
Widget.objects.create(
widget = Widget.objects.create(
name="test widget",
owner=Owner.objects.create(
name="test owner", group=Group.objects.create(name="test group")
),
)
LogEntry.objects.create(
content_type=ContentType.objects.get_for_model(widget),
object_pk=widget.id,
event="CREATED",
)
LogEntry.objects.create(
content_type=ContentType.objects.get_for_model(widget),
object_pk=widget.id,
event="UPDATED",
)
LogEntry.objects.create(
content_type=ContentType.objects.get_for_model(widget),
object_pk=widget.id,
event="DELETED",
)

request = APIRequestFactory().get("/")
view = WidgetListView.as_view()

with self.assertNumQueries(3):
with self.assertNumQueries(4):
response = view(request)

self.assertEqual(
Expand All @@ -77,6 +94,11 @@ def test_list(self):
"name": "test group",
},
},
"logs": [
{"event": "CREATED"},
{"event": "UPDATED"},
{"event": "DELETED"},
],
}
],
)
Expand Down

0 comments on commit a13c2d0

Please sign in to comment.