Skip to content

Commit

Permalink
Add compatibility to form prefixing
Browse files Browse the repository at this point in the history
  • Loading branch information
Ryan P Kilby committed Jul 13, 2018
1 parent 7277b1d commit f57073e
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 3 deletions.
19 changes: 16 additions & 3 deletions rest_framework_filters/filterset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,19 @@ def related(filterset, filter_name):
return LOOKUP_SEP.join([filterset.relationship, filter_name])


def prefixed(filterset, filter_name):
"""
Return a filter name, using the filterset relationship and form prefix if present.
Note: This could invoke `Form.add_prefix`, but the result of `get_form_class()`
should be cached in order to prevent unnecessary duplication of the form class.
"""
filter_name = related(filterset, filter_name)
if not filterset.form_prefix:
return filter_name
return '%s-%s' % (filterset.form_prefix, filter_name)


class FilterSetMetaclass(filterset.FilterSetMetaclass):
def __new__(cls, name, bases, attrs):
new_class = super(FilterSetMetaclass, cls).__new__(cls, name, bases, attrs)
Expand Down Expand Up @@ -112,11 +125,11 @@ def get_request_filters(self):
exclude_name = '%s!' % filter_name

# Add plain lookup filters if match. ie, `username__icontains`
if related(self, filter_name) in self.data:
if prefixed(self, filter_name) in self.data:
requested_filters[filter_name] = f

# include exclusion keys
if related(self, exclude_name) in self.data:
if prefixed(self, exclude_name) in self.data:
# deepcopy the *base* filter to prevent copying of model & parent
f_copy = copy.deepcopy(self.base_filters[filter_name])
f_copy.parent = f.parent
Expand All @@ -131,7 +144,7 @@ def get_related_filtersets(self):
related_filtersets = OrderedDict()

for related_name in self.related_filters:
prefix = '%s%s' % (related(self, related_name), LOOKUP_SEP)
prefix = '%s%s' % (prefixed(self, related_name), LOOKUP_SEP)
if not any(value.startswith(prefix) for value in self.data):
continue

Expand Down
36 changes: 36 additions & 0 deletions tests/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,42 @@ class PostFilter(FilterSet):
self.assertEqual(f.__name__, 'LocalTagFilter')


class FormPrefixTests(TestCase):
"""
Compatibility with form prefixing is a non-requirement, but a nice to have.
"""

@classmethod
def setUpTestData(cls):
user1 = User.objects.create(username="user1", email="user1@example.org")
user2 = User.objects.create(username="user2", email="user2@example.org")

Post.objects.create(author=user1, title="Test post 1")
Post.objects.create(author=user1, title="Test post 2")
Post.objects.create(author=user2, title="Test post 3")

def test_filter(self):
f1 = PostFilter({'title__endswith': '1'})
f2 = PostFilter({'prefix-title__endswith': '1'}, prefix='prefix')

self.assertQuerysetEqual(f1.qs, [1], lambda p: p.pk)
self.assertQuerysetEqual(f2.qs, [1], lambda p: p.pk)

def test_related_filter(self):
f1 = PostFilter({'author__username': 'user2'})
f2 = PostFilter({'prefix-author__username': 'user2'}, prefix='prefix')

self.assertQuerysetEqual(f1.qs, [3], lambda p: p.pk)
self.assertQuerysetEqual(f2.qs, [3], lambda p: p.pk)

def test_validation_errors(self):
f1 = PostFilter({'author__last_login__date': '2018'})
f2 = PostFilter({'prefix-author__last_login__date': '2018'}, prefix='prefix')

self.assertEqual(f1.errors, {'author__last_login__date': ['Enter a valid date.']})
self.assertEqual(f2.errors, {'author__last_login__date': ['Enter a valid date.']})


class MiscTests(TestCase):
def test_multiwidget_incompatibility(self):
Person.objects.create(name='A')
Expand Down
7 changes: 7 additions & 0 deletions tests/test_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ class Meta:
form = F().form
self.assertEqual(list(form.fields), ['title', 'author'])

def test_form_prefix(self):
f = PostFilter({'prefix-author__username': 'bob'}, prefix='prefix')
self.assertEqual(f.form.prefix, 'prefix')

f = f.related_filtersets['author']
self.assertEqual(f.form.prefix, 'prefix')

def test_validation_errors(self):
f = PostFilter({
'publish_date__year': 'foo',
Expand Down

0 comments on commit f57073e

Please sign in to comment.