Skip to content

Commit

Permalink
Re-prefetch related objects after updating (#8043)
Browse files Browse the repository at this point in the history
* Re-prefetch related objects after updating

* Fix flake8 format

* Use _prefetch_related_lookups and refine test cases

* Add more test cases and refine prefetch checking
  • Loading branch information
yuekui authored Jan 11, 2023
1 parent bfce663 commit 2b34aa4
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 32 deletions.
9 changes: 7 additions & 2 deletions rest_framework/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
We don't bind behaviour to http method handlers yet,
which allows mixin classes to be composed in interesting ways.
"""
from django.db.models.query import prefetch_related_objects

from rest_framework import status
from rest_framework.response import Response
from rest_framework.settings import api_settings
Expand Down Expand Up @@ -67,10 +69,13 @@ def update(self, request, *args, **kwargs):
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)

if getattr(instance, '_prefetched_objects_cache', None):
queryset = self.filter_queryset(self.get_queryset())
if queryset._prefetch_related_lookups:
# If 'prefetch_related' has been applied to a queryset, we need to
# forcibly invalidate the prefetch cache on the instance.
# forcibly invalidate the prefetch cache on the instance,
# and then re-prefetch related objects
instance._prefetched_objects_cache = {}
prefetch_related_objects([instance], *queryset._prefetch_related_lookups)

return Response(serializer.data)

Expand Down
94 changes: 64 additions & 30 deletions tests/test_prefetch_related.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from django.contrib.auth.models import Group, User
from django.db.models.query import Prefetch
from django.test import TestCase

from rest_framework import generics, serializers
Expand All @@ -8,51 +9,84 @@


class UserSerializer(serializers.ModelSerializer):
permissions = serializers.SerializerMethodField()

def get_permissions(self, obj):
ret = []
for g in obj.groups.all():
ret.extend([p.pk for p in g.permissions.all()])
return ret

class Meta:
model = User
fields = ('id', 'username', 'email', 'groups')
fields = ('id', 'username', 'email', 'groups', 'permissions')


class UserRetrieveUpdate(generics.RetrieveUpdateAPIView):
queryset = User.objects.exclude(username='exclude').prefetch_related(
Prefetch('groups', queryset=Group.objects.exclude(name='exclude')),
'groups__permissions',
)
serializer_class = UserSerializer


class UserUpdate(generics.UpdateAPIView):
queryset = User.objects.exclude(username='exclude').prefetch_related('groups')
class UserUpdateWithoutPrefetchRelated(generics.UpdateAPIView):
queryset = User.objects.exclude(username='exclude')
serializer_class = UserSerializer


class TestPrefetchRelatedUpdates(TestCase):
def setUp(self):
self.user = User.objects.create(username='tom', email='tom@example.com')
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)]
self.user.groups.set(self.groups)
self.user.groups.add(Group.objects.create(name='exclude'))
self.expected = {
'id': self.user.pk,
'username': 'tom',
'groups': [group.pk for group in self.groups],
'email': 'tom@example.com',
'permissions': [],
}
self.view = UserRetrieveUpdate.as_view()

def test_prefetch_related_updates(self):
view = UserUpdate.as_view()
pk = self.user.pk
groups_pk = self.groups[0].pk
request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json')
response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1
expected = {
'id': pk,
'username': 'new',
'groups': [1],
'email': 'tom@example.com'
}
assert response.data == expected
self.groups.append(Group.objects.create(name='c'))
request = factory.put(
'/', {'username': 'new', 'groups': [group.pk for group in self.groups]}, format='json'
)
self.expected['username'] = 'new'
self.expected['groups'] = [group.pk for group in self.groups]
response = self.view(request, pk=self.user.pk)
assert User.objects.get(pk=self.user.pk).groups.count() == 12
assert response.data == self.expected
# Update and fetch should get same result
request = factory.get('/')
response = self.view(request, pk=self.user.pk)
assert response.data == self.expected

def test_prefetch_related_excluding_instance_from_original_queryset(self):
"""
Regression test for https://github.com/encode/django-rest-framework/issues/4661
"""
view = UserUpdate.as_view()
pk = self.user.pk
groups_pk = self.groups[0].pk
request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json')
response = view(request, pk=pk)
assert User.objects.get(pk=pk).groups.count() == 1
expected = {
'id': pk,
'username': 'exclude',
'groups': [1],
'email': 'tom@example.com'
}
assert response.data == expected
request = factory.put(
'/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json'
)
response = self.view(request, pk=self.user.pk)
assert User.objects.get(pk=self.user.pk).groups.count() == 2
self.expected['username'] = 'exclude'
self.expected['groups'] = [self.groups[0].pk]
assert response.data == self.expected

def test_db_query_count(self):
request = factory.put(
'/', {'username': 'new'}, format='json'
)
with self.assertNumQueries(7):
self.view(request, pk=self.user.pk)

request = factory.put(
'/', {'username': 'new2'}, format='json'
)
with self.assertNumQueries(16):
UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk)

0 comments on commit 2b34aa4

Please sign in to comment.