diff --git a/django_readers/specs.py b/django_readers/specs.py index 29e3615..b2fc340 100644 --- a/django_readers/specs.py +++ b/django_readers/specs.py @@ -1,5 +1,5 @@ from django_readers import pairs -from django_readers.utils import queries_disabled +from django_readers.utils import with_prepared_checker, with_queries_disabled def process_item(item): @@ -16,7 +16,9 @@ def process_item(item): def process(spec): - return queries_disabled(pairs.combine(*(process_item(item) for item in spec))) + return with_prepared_checker( + with_queries_disabled(pairs.combine(*(process_item(item) for item in spec))) + ) def relationship(name, relationship_spec, to_attr=None): diff --git a/django_readers/utils.py b/django_readers/utils.py index 6953d50..70a4607 100644 --- a/django_readers/utils.py +++ b/django_readers/utils.py @@ -44,12 +44,30 @@ def none_safe_get_attr(obj): return none_safe_get_attr -def queries_disabled(pair): +def with_queries_disabled(pair): prepare, project = pair decorator = zen_queries.queries_disabled() if zen_queries else lambda fn: fn return decorator(prepare), decorator(project) +def with_prepared_checker(pair): + prepare, project = pair + + is_prepared = False + + def wrapped_prepare(qs): + nonlocal is_prepared + is_prepared = True + return prepare(qs) + + def wrapped_project(qs): + if not is_prepared: + raise Exception("QuerySet must be prepared before projection") + return project(qs) + + return (wrapped_prepare, wrapped_project) + + class SpecVisitor: def visit(self, spec): return [self.visit_item(item) for item in spec] diff --git a/tests/test_rest_framework.py b/tests/test_rest_framework.py index 872ebf7..d2b1403 100644 --- a/tests/test_rest_framework.py +++ b/tests/test_rest_framework.py @@ -110,6 +110,29 @@ def test_detail(self): }, ) + def test_override_getqueryset_must_call_prepare(self): + Widget.objects.create() + + class WidgetListView(SpecMixin, ListAPIView): + spec = [ + "name", + ] + + def get_queryset(self): + queryset = Widget.objects.all() + # Should call self.prepare(queryset) here + return queryset + + request = APIRequestFactory().get("/") + view = WidgetListView.as_view() + + with self.assertRaises(Exception) as cm: + view(request) + + self.assertEqual( + str(cm.exception), "QuerySet must be prepared before projection" + ) + class SpecToSerializerClassTestCase(TestCase): def test_basic_spec(self):