diff --git a/README.rst b/README.rst index e88bbdf..fdb3a22 100644 --- a/README.rst +++ b/README.rst @@ -590,6 +590,35 @@ errors would be raised like so: { +Complex JSON Filtering with Boolean Logic +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The ``ComplexJSONFilterBackend`` backend allows a user to filter using a JSON definition instead of an encoded string. Pass an encoded representation of a json object that has a top-level `or` or `and` key, mapped to an array of clauses to the `json_filters` option. These clauses can either be other `or` or `and` clauses or a mapping of query params to their values. For example to query for all resources where (title does not contain "Why") AND (title starts with "Who" OR title starts with "What"): + +.. code-block:: python + + filters = { + "and": [ + { + "or": [ + { + "title__startswith": "Who" + }, + { + "title__startswith": "What" + }, + ] + }, + { + "title__icontains!": "Why" + }, + ] + } + querystring = "json_filters={filters}".format( + filters=quote(json.dumps(filters)) + ) + + Migrating to 1.0 ---------------- diff --git a/rest_framework_filters/backends.py b/rest_framework_filters/backends.py index a8476a7..03bf984 100644 --- a/rest_framework_filters/backends.py +++ b/rest_framework_filters/backends.py @@ -1,5 +1,7 @@ +import json from contextlib import contextmanager +from django.db.models import QuerySet from django.http import QueryDict from django_filters import compat from django_filters.rest_framework import backends @@ -8,6 +10,8 @@ from .complex_ops import combine_complex_queryset, decode_complex_ops from .filterset import FilterSet +COMPLEX_JSON_OPERATORS = {"and": QuerySet.__and__, "or": QuerySet.__or__} + class RestFrameworkFilterBackend(backends.DjangoFilterBackend): filterset_base = FilterSet @@ -96,3 +100,60 @@ def get_filtered_querysets(self, querystrings, request, queryset, view): if errors: raise ValidationError(errors) return querysets + + +class ComplexJSONFilterBackend(RestFrameworkFilterBackend): + complex_filter_param = "json_filters" + + def filter_queryset(self, request, queryset, view): + res = super().filter_queryset(request, queryset, view) + if self.complex_filter_param not in request.query_params: + return res + + encoded_querystring = request.query_params[self.complex_filter_param] + try: + complex_ops = json.loads(encoded_querystring) + return self.combine_filtered_querysets(complex_ops, request, res, view) + except ValidationError as exc: + raise ValidationError({self.complex_filter_param: exc.detail}) + except json.decoder.JSONDecodeError: + raise ValidationError({self.complex_filter_param: "unable to parse json."}) + + def combine_filtered_querysets(self, complex_filter, request, queryset, view): + """ + Function used recursively to filter the complex filter boolean logic + Args: + complex_filter: the json complex filter + request: request + queryset: starting queryset, unfiltered + view: the view + + Returns: + queryset + """ + operator = None + combined_queryset = None + for symbol, complex_operator in COMPLEX_JSON_OPERATORS.items(): + if operator is None and symbol in complex_filter: + operator = complex_operator + for sub_filter in complex_filter[symbol]: + filtered_queryset = self.combine_filtered_querysets(sub_filter, request, queryset, view) + if combined_queryset is None: + combined_queryset = filtered_queryset + else: + combined_queryset = complex_operator(combined_queryset, filtered_queryset) + if operator: + return combined_queryset + + return self.get_filtered_queryset( + "&".join(["{k}={v}".format(k=k, v=v) for k, v in complex_filter.items()]), request, queryset, view + ) + + def get_filtered_queryset(self, querystring, request, queryset, view): + original_GET = request._request.GET + request._request.GET = QueryDict(querystring) + try: + res = super().filter_queryset(request, queryset, view) + finally: + request._request.GET = original_GET + return res diff --git a/tests/test_backends.py b/tests/test_backends.py index 9083df6..ad42bcb 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -1,3 +1,4 @@ +import json from urllib.parse import quote, urlencode import django_filters @@ -480,3 +481,114 @@ def test_pagination_compatibility(self): [r['username'] for r in response.data['results']], ['user3'], ) + + +class ComplexJSONFilterBackendTests(APITestCase): + + @classmethod + def setUpTestData(cls): + models.User.objects.create(username="user1", email="user1@example.com") + models.User.objects.create(username="user2", email="user2@example.com") + models.User.objects.create(username="user3", email="user3@example.org") + models.User.objects.create(username="user4", email="user4@example.org") + + def test_valid(self): + readable = json.dumps({ + "or": [ + { + "username": "user1" + }, + { + "email__contains": "example.org" + } + ] + }) + response = self.client.get('/ffjsoncomplex-users/?json_filters=' + quote(readable), content_type='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertListEqual( + [r['username'] for r in response.data], + ['user1', 'user3', 'user4'] + ) + + def test_invalid(self): + readable = json.dumps({ + "or": [ + { + "username": "user1" + }, + { + "email__contains": "example.org" + } + ] + })[0:10] + response = self.client.get('/ffjsoncomplex-users/?json_filters=' + quote(readable), content_type='json') + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + 'json_filters': "unable to parse json.", + }) + + def test_invalid_filterset_errors(self): + readable = json.dumps({ + "or": [ + { + "id": "foo" + }, + { + "id": "bar" + } + ] + }) + response = self.client.get('/ffjsoncomplex-users/?json_filters=' + quote(readable), content_type='json') + + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertDictEqual(response.json(), { + 'json_filters': { + 'id': ["Enter a number."], + }, + }) + + def test_pagination_compatibility(self): + """ + Ensure that complex-filtering does not interfere with additional query param processing. + """ + readable = json.dumps({ + "or": [ + { + "email__contains": "example.org" + } + ] + }) + + # sanity check w/o pagination + response = self.client.get('/ffjsoncomplex-users/?json_filters=' + quote(readable), content_type='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertListEqual( + [r['username'] for r in response.data], + ['user3', 'user4'] + ) + + # sanity check w/o complex-filtering + response = self.client.get('/ffjsoncomplex-users/?page_size=1', content_type='json') + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn('results', response.data) + self.assertListEqual( + [r['username'] for r in response.data['results']], + ['user1'] + ) + + # pagination + complex-filtering + response = self.client.get( + '/ffjsoncomplex-users/?page_size=1&json_filters=' + quote(readable), + content_type='json' + ) + + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn('results', response.data) + self.assertListEqual( + [r['username'] for r in response.data['results']], + ['user3'] + ) diff --git a/tests/testapp/urls.py b/tests/testapp/urls.py index 3d3429c..24b1e0d 100644 --- a/tests/testapp/urls.py +++ b/tests/testapp/urls.py @@ -10,6 +10,9 @@ router.register('ffcomplex-users', views.ComplexFilterFieldsUserViewSet, basename='ffcomplex-users') +router.register(r'ffjsoncomplex-users', + views.ComplexJSONFilterFieldsUserViewSet, + basename='ffjsoncomplex-users') router.register('users', views.UserViewSet) router.register('notes', views.NoteViewSet) diff --git a/tests/testapp/views.py b/tests/testapp/views.py index 1d9e0e2..03dc668 100644 --- a/tests/testapp/views.py +++ b/tests/testapp/views.py @@ -52,6 +52,19 @@ class pagination_class(pagination.PageNumberPagination): page_size_query_param = 'page_size' +class ComplexJSONFilterFieldsUserViewSet(FilterFieldsUserViewSet): + queryset = User.objects.order_by('pk') + filter_backends = (backends.ComplexJSONFilterBackend, ) + filterset_fields = { + 'id': '__all__', + 'username': '__all__', + 'email': '__all__', + } + + class pagination_class(pagination.PageNumberPagination): + page_size_query_param = 'page_size' + + class UserViewSet(viewsets.ModelViewSet): queryset = User.objects.all() serializer_class = UserSerializer