Skip to content

Commit

Permalink
Remove 'override_filters' contextmanager
Browse files Browse the repository at this point in the history
  • Loading branch information
Ryan P Kilby committed Jul 15, 2018
1 parent aa8a01c commit f7a4884
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 74 deletions.
66 changes: 23 additions & 43 deletions rest_framework_filters/filterset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import copy
from collections import OrderedDict
from contextlib import contextmanager

from django.db.models import Subquery
from django.db.models.constants import LOOKUP_SEP
Expand Down Expand Up @@ -86,10 +85,6 @@ class SubsetDisabledMixin:
def get_filter_subset(cls, params, rel=None):
return cls.base_filters

@contextmanager
def override_filters(self):
yield


class FilterSet(rest_framework.FilterSet, metaclass=FilterSetMetaclass):

Expand All @@ -100,7 +95,7 @@ def __init__(self, data=None, queryset=None, *, relationship=None, **kwargs):

self.relationship = relationship
self.related_filtersets = self.get_related_filtersets()
self.request_filters = self.get_request_filters()
self.filters = self.get_request_filters()

@classmethod
def get_fields(cls):
Expand Down Expand Up @@ -198,13 +193,10 @@ def get_request_filters(self):
# build the compiled set of all filters
requested_filters = OrderedDict()
for filter_name, f in self.filters.items():
exclude_name = '%s!' % filter_name

# Add plain lookup filters if match. ie, `username__icontains`
if related(self, filter_name) in self.data:
requested_filters[filter_name] = f
requested_filters[filter_name] = f

# include exclusion keys
# exclusion params
exclude_name = '%s!' % filter_name
if related(self, exclude_name) in self.data:
# deepcopy the *base* filter to prevent copying of model & parent
f_copy = copy.deepcopy(self.base_filters[filter_name])
Expand Down Expand Up @@ -237,21 +229,10 @@ def get_related_filtersets(self):

return related_filtersets

@contextmanager
def override_filters(self):
if not self.is_bound:
yield
else:
orig_filters = self.filters
self.filters = self.request_filters
yield
self.filters = orig_filters

def filter_queryset(self, queryset):
with self.override_filters():
queryset = super(FilterSet, self).filter_queryset(queryset)
queryset = self.filter_related_filtersets(queryset)
return queryset
queryset = super(FilterSet, self).filter_queryset(queryset)
queryset = self.filter_related_filtersets(queryset)
return queryset

def filter_related_filtersets(self, queryset):
"""
Expand All @@ -273,20 +254,19 @@ def filter_related_filtersets(self, queryset):
return queryset

def get_form_class(self):
with self.override_filters():
class Form(super(FilterSet, self).get_form_class()):
def add_prefix(form, field_name):
field_name = related(self, field_name)
return super(Form, form).add_prefix(field_name)

def clean(form):
cleaned_data = super(Form, form).clean()

# when prefixing the errors, use the related filter name,
# which is relative to the parent filterset, not the root.
for related_filterset in self.related_filtersets.values():
for key, error in related_filterset.form.errors.items():
self.form.errors[related(related_filterset, key)] = error

return cleaned_data
return Form
class Form(super(FilterSet, self).get_form_class()):
def add_prefix(form, field_name):
field_name = related(self, field_name)
return super(Form, form).add_prefix(field_name)

def clean(form):
cleaned_data = super(Form, form).clean()

# when prefixing the errors, use the related filter name,
# which is relative to the parent filterset, not the root.
for related_filterset in self.related_filtersets.values():
for key, error in related_filterset.form.errors.items():
self.form.errors[related(related_filterset, key)] = error

return cleaned_data
return Form
4 changes: 2 additions & 2 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ class Meta:
class SimpleViewSet(views.FilterFieldsUserViewSet):
filterset_class = SimpleFilterSet

self.assertEqual(list(SimpleFilterSet({'username!': ''}).form.fields), ['username!'])
self.assertEqual(list(SimpleFilterSet({'username!': ''}).form.fields), ['username', 'username!'])
self.render(SimpleViewSet)
self.assertEqual(list(SimpleFilterSet({'username!': ''}).form.fields), ['username!'])
self.assertEqual(list(SimpleFilterSet({'username!': ''}).form.fields), ['username', 'username!'])

def test_patch_for_rendering(self):
view = views.FilterClassUserViewSet(action_map={})
Expand Down
34 changes: 5 additions & 29 deletions tests/test_filterset.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,28 +411,6 @@ def test_subset_disabled_form(self):
self.assertEqual(list(F({'author': ''}).form.fields), ['author'])


class OverrideFiltersTests(TestCase):

def test_bound(self):
f = PostFilter({})

with f.override_filters():
self.assertEqual(len(f.filters), 0)

def test_not_bound(self):
f = PostFilter(None)

with f.override_filters():
self.assertEqual(len(f.filters), 0)

def test_subset_disabled(self):
f = PostFilter.disable_subset()(None)

with f.override_filters():
# The number of filters varies by Django version
self.assertGreater(len(f.filters), 30)


class FilterExclusionTests(TestCase):

@classmethod
Expand All @@ -456,9 +434,8 @@ def test_exclude_property(self):
}

filterset = TagFilter(GET, queryset=Tag.objects.all())
requested_filters = filterset.request_filters

self.assertTrue(requested_filters['name__contains!'].exclude)
self.assertTrue(filterset.filters['name__contains!'].exclude)

def test_filter_and_exclude(self):
"""
Expand All @@ -470,20 +447,19 @@ def test_filter_and_exclude(self):
}

filterset = TagFilter(GET, queryset=Tag.objects.all())
requested_filters = filterset.request_filters

self.assertFalse(requested_filters['name__contains'].exclude)
self.assertTrue(requested_filters['name__contains!'].exclude)
self.assertFalse(filterset.filters['name__contains'].exclude)
self.assertTrue(filterset.filters['name__contains!'].exclude)

def test_related_exclude(self):
GET = {
'tags__name__contains!': 'Tag',
}

filterset = PostFilter(GET, queryset=Post.objects.all())
requested_filters = filterset.related_filtersets['tags'].request_filters
filterset = filterset.related_filtersets['tags']

self.assertTrue(requested_filters['name__contains!'].exclude)
self.assertTrue(filterset.filters['name__contains!'].exclude)

def test_exclusion_results(self):
GET = {
Expand Down

0 comments on commit f7a4884

Please sign in to comment.