Skip to content

Commit f3427b1

Browse files
Declare Prefetch as generic and do specialization in the plugin (#2786)
1 parent 6dbd81a commit f3427b1

File tree

5 files changed

+144
-35
lines changed

5 files changed

+144
-35
lines changed

django-stubs/db/models/query.pyi

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ _Model = TypeVar("_Model", bound=Model, covariant=True)
1515
_Row = TypeVar("_Row", covariant=True, default=_Model) # ONLY use together with _Model
1616
_TupleT = TypeVar("_TupleT", bound=tuple[Any, ...], covariant=True)
1717

18+
# The type of the prefetched model
19+
_PrefetchedQuerySetT = TypeVar("_PrefetchedQuerySetT", bound=QuerySet[Model], covariant=True, default=QuerySet[Model])
20+
# The attribute name to use to store the prefetched list[_PrefetchedQuerySetT]
21+
# This will be specialized to a `LiteralString` in the plugin for further processing and validation
22+
_ToAttrT = TypeVar("_ToAttrT", bound=str, covariant=True, default=str)
23+
1824
MAX_GET_RESULTS: int
1925
REPR_OUTPUT_SIZE: int
2026

@@ -234,21 +240,23 @@ class RawQuerySet(Iterable[_Model], Sized):
234240
# Deprecated alias of QuerySet, for compatibility only.
235241
_QuerySet: TypeAlias = QuerySet # noqa: PYI047
236242

237-
class Prefetch:
243+
class Prefetch(Generic[_PrefetchedQuerySetT, _ToAttrT]):
238244
prefetch_through: str
239245
prefetch_to: str
240-
queryset: QuerySet | None
241-
to_attr: str | None
242-
def __init__(self, lookup: str, queryset: QuerySet | None = None, to_attr: str | None = None) -> None: ...
246+
queryset: _PrefetchedQuerySetT | None
247+
to_attr: _ToAttrT | None
248+
def __init__(
249+
self, lookup: str, queryset: _PrefetchedQuerySetT | None = None, to_attr: _ToAttrT | None = None
250+
) -> None: ...
243251
def __getstate__(self) -> dict[str, Any]: ...
244252
def add_prefix(self, prefix: str) -> None: ...
245253
def get_current_prefetch_to(self, level: int) -> str: ...
246254
def get_current_to_attr(self, level: int) -> tuple[str, str]: ...
247255
@deprecated(
248256
"get_current_queryset() is deprecated and will be removed in Django 6.0. Use get_current_querysets() instead."
249257
)
250-
def get_current_queryset(self, level: int) -> QuerySet | None: ...
251-
def get_current_querysets(self, level: int) -> list[QuerySet] | None: ...
258+
def get_current_queryset(self, level: int) -> _PrefetchedQuerySetT | None: ...
259+
def get_current_querysets(self, level: int) -> list[_PrefetchedQuerySetT] | None: ...
252260

253261
def prefetch_related_objects(model_instances: Iterable[_Model], *related_lookups: str | Prefetch) -> None: ...
254262
async def aprefetch_related_objects(model_instances: Iterable[_Model], *related_lookups: str | Prefetch) -> None: ...

ext/django_stubs_ext/patch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from django.db.models.lookups import Lookup
2323
from django.db.models.manager import BaseManager
2424
from django.db.models.options import Options
25-
from django.db.models.query import BaseIterable, ModelIterable, QuerySet, RawQuerySet
25+
from django.db.models.query import BaseIterable, ModelIterable, Prefetch, QuerySet, RawQuerySet
2626
from django.forms.formsets import BaseFormSet
2727
from django.forms.models import BaseModelForm, BaseModelFormSet, ModelChoiceField, ModelFormOptions
2828
from django.utils.connection import BaseConnectionHandler, ConnectionProxy
@@ -97,6 +97,7 @@ def __repr__(self) -> str:
9797
MPGeneric(BaseIterable),
9898
MPGeneric(ForwardManyToOneDescriptor),
9999
MPGeneric(ReverseOneToOneDescriptor),
100+
MPGeneric(Prefetch),
100101
]
101102

102103

mypy_django_plugin/main.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,9 @@ def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], MypyTy
143143
if info.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME):
144144
return querysets.determine_proper_manager_type
145145

146+
if info.has_base(fullnames.PREFETCH_CLASS_FULLNAME):
147+
return partial(querysets.specialize_prefetch_type, django_context=self.django_context)
148+
146149
return None
147150

148151
@cached_property

mypy_django_plugin/transformers/querysets.py

Lines changed: 72 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from mypy.checker import TypeChecker
88
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, CallExpr, Expression
99
from mypy.plugin import FunctionContext, MethodContext
10-
from mypy.types import AnyType, Instance, TupleType, TypedDictType, TypeOfAny, get_proper_type
10+
from mypy.types import AnyType, Instance, LiteralType, TupleType, TypedDictType, TypeOfAny, get_proper_type
1111
from mypy.types import Type as MypyType
1212

1313
from mypy_django_plugin.django.context import DjangoContext, LookupsAreUnsupported
@@ -318,20 +318,37 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan
318318
return default_return_type.copy_modified(args=[django_model.typ, row_type])
319319

320320

321-
def _infer_prefetch_element_model_type(queryset_expr: Expression | None, api: TypeChecker) -> Instance | None:
321+
def _infer_prefetch_queryset_type(queryset_expr: Expression, api: TypeChecker) -> Instance | None:
322322
"""Infer the model Instance from `Prefetch(queryset=...)`"""
323-
if queryset_expr is None:
324-
# TODO: Infer the model type from the lookup in `Prefetch(lookup=..., to_attr=...)`
325-
return None
326323
try:
327324
qs_type = get_proper_type(api.expr_checker.accept(queryset_expr))
328325
except Exception:
329326
return None
330327
if isinstance(qs_type, Instance):
331-
return helpers.extract_model_type_from_queryset(qs_type, api)
328+
return qs_type
332329
return None
333330

334331

332+
def specialize_prefetch_type(ctx: FunctionContext, django_context: DjangoContext) -> MypyType:
333+
"""Function hook for `Prefetch(...)` to specialize its `to_attr` generic parameters."""
334+
default = get_proper_type(ctx.default_return_type)
335+
if not isinstance(default, Instance):
336+
return ctx.default_return_type
337+
338+
api = helpers.get_typechecker_api(ctx)
339+
340+
# Guaranteed to exist because the TypeVar has a default
341+
to_attr_type = default.args[1]
342+
# We specialize the `to_attr` str type to a Literal[str] so that it can be used later
343+
# to annotate the correct attribute name on the model instances and do further validations
344+
# See `extract_prefetch_related_annotations` below.
345+
if to_attr_expr := helpers.get_call_argument_by_name(ctx, "to_attr"):
346+
if to_attr_value := helpers.resolve_string_attribute_value(to_attr_expr, django_context):
347+
to_attr_type = LiteralType(value=to_attr_value, fallback=api.named_generic_type("builtins.str", []))
348+
349+
return default.copy_modified(args=[default.args[0], to_attr_type])
350+
351+
335352
def extract_prefetch_related_annotations(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
336353
"""
337354
Extract annotated attributes via `prefetch_related(Prefetch(..., to_attr=...))`
@@ -353,18 +370,57 @@ def extract_prefetch_related_annotations(ctx: MethodContext, django_context: Dja
353370

354371
for expr, typ in zip(ctx.args[0], ctx.arg_types[0], strict=False):
355372
typ = get_proper_type(typ)
356-
if not (
357-
isinstance(typ, Instance)
358-
and typ.type.fullname == fullnames.PREFETCH_CLASS_FULLNAME
359-
and isinstance(expr, CallExpr)
360-
and (to_attr_expr := helpers.get_class_init_argument_by_name(expr, "to_attr"))
361-
and (to_attr_value := helpers.resolve_string_attribute_value(to_attr_expr, django_context))
362-
):
373+
if not (isinstance(typ, Instance) and typ.type.has_base(fullnames.PREFETCH_CLASS_FULLNAME)):
363374
continue
364375

365-
# Determine model type from the `queryset` attr
366-
queryset_expr = helpers.get_class_init_argument_by_name(expr, "queryset")
367-
elem_model = _infer_prefetch_element_model_type(queryset_expr, api)
376+
# 1) Try to get to_attr from specialized type arg
377+
to_attr_value: str | None = None
378+
if (
379+
len(typ.args) >= 2
380+
and isinstance((to_attr_t := get_proper_type(typ.args[1])), LiteralType)
381+
and isinstance(to_attr_t.value, str)
382+
):
383+
to_attr_value = to_attr_t.value
384+
385+
# Fallback: parse inline call expression
386+
if to_attr_value is None:
387+
if not (
388+
isinstance(expr, CallExpr)
389+
and (to_attr_expr := helpers.get_class_init_argument_by_name(expr, "to_attr"))
390+
and (to_attr_value := helpers.resolve_string_attribute_value(to_attr_expr, django_context))
391+
):
392+
continue
393+
394+
# 2) Determine element model type from specialized type arg
395+
elem_model: Instance | None = None
396+
if len(typ.args) >= 1 and isinstance((queryset_type := get_proper_type(typ.args[0])), Instance):
397+
elem_model = helpers.extract_model_type_from_queryset(queryset_type, api)
398+
399+
# Fallback: parse inline call expression
400+
if elem_model is None and isinstance(expr, CallExpr):
401+
queryset_expr = helpers.get_class_init_argument_by_name(expr, "queryset")
402+
if queryset_expr is not None and isinstance(
403+
(inferred_queryset_type := _infer_prefetch_queryset_type(queryset_expr, api)), Instance
404+
):
405+
elem_model = helpers.extract_model_type_from_queryset(inferred_queryset_type, api)
406+
else:
407+
# Resolve model type using the "lookup" required first argument and
408+
# the model associated with the current queryset.
409+
lookup_expr = helpers.get_class_init_argument_by_name(expr, "lookup")
410+
if lookup_expr is None:
411+
continue
412+
413+
if lookup_value := helpers.resolve_string_attribute_value(lookup_expr, django_context):
414+
if django_model := helpers.get_model_info_from_qs_ctx(ctx, django_context):
415+
try:
416+
observed_model_cls = django_context.resolve_lookup_into_field(
417+
django_model.cls, lookup_value
418+
)[1]
419+
if model_info := helpers.lookup_class_typeinfo(api, observed_model_cls):
420+
elem_model = Instance(model_info, [])
421+
except (FieldError, LookupsAreUnsupported):
422+
pass
423+
368424
value_type = api.named_generic_type(
369425
"builtins.list",
370426
[elem_model if elem_model is not None else AnyType(TypeOfAny.special_form)],

tests/typecheck/managers/querysets/test_prefetch_related.yml

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,28 @@
11
- case: prefetch_related_to_attr
22
main: |
33
from myapp.models import Article, Tag
4-
from django.db.models import Prefetch
4+
from django.db.models import Prefetch, F, QuerySet
55
from django.db import models
6+
from typing_extensions import Literal, TypedDict
7+
from django_stubs_ext import WithAnnotations
68
7-
# Noop (to_attr not provided)
9+
### Noop (to_attr not provided)
810
reveal_type(Article.objects.prefetch_related(Prefetch("tags")).all()) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article, myapp.models.Article]"
911
reveal_type(Article.objects.prefetch_related(Prefetch("tags", Tag.objects.all())).all()) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article, myapp.models.Article]"
12+
reveal_type(Prefetch("tags")) # N: Revealed type is "django.db.models.query.Prefetch[django.db.models.query.QuerySet[django.db.models.base.Model, django.db.models.base.Model], builtins.str]"
13+
reveal_type(Prefetch("tags", Tag.objects.all())) # N: Revealed type is "django.db.models.query.Prefetch[django.db.models.query.QuerySet[myapp.models.Tag, myapp.models.Tag], builtins.str]"
14+
15+
# Prefetch created in a function with no to_attr
16+
def get_prefetch_no_to_attr() -> Prefetch[QuerySet[Tag]]:
17+
return Prefetch("tags", Tag.objects.all())
18+
19+
reveal_type(Article.objects.prefetch_related(get_prefetch_no_to_attr()).all()) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article, myapp.models.Article]"
20+
21+
# Prefetch created in a function with no to_attr and no queryset
22+
def get_prefetch_no_to_attr_no_qs() -> Prefetch:
23+
return Prefetch("tags")
24+
25+
reveal_type(Article.objects.prefetch_related(get_prefetch_no_to_attr_no_qs()).all()) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article, myapp.models.Article]"
1026
1127
# On the QuerySet
1228
article_qs = Article.objects.all().prefetch_related(Prefetch("tags", Tag.objects.all(), to_attr="every_tags"))
@@ -59,29 +75,54 @@
5975
)
6076
reveal_type(mixed_plain) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article@AnnotatedWith[TypedDict({'arts3': builtins.list[myapp.models.Article]})], myapp.models.Article@AnnotatedWith[TypedDict({'arts3': builtins.list[myapp.models.Article]})]]"
6177
62-
63-
## Not Supported
64-
65-
# Prefetch with `to_attr` arg but without the `queryset` arg
66-
# TODO: We should be able to resolve a more accurate type using existing lookup `resolve_lookup_expected_type` machinery
67-
reveal_type(Article.objects.prefetch_related(models.Prefetch("tags", to_attr="just_tags")).get().just_tags) # N: Revealed type is "builtins.list[Any]"
68-
6978
# Intermediary variable -- function scope
7079
def foo() -> None:
7180
tag_prefetch = Prefetch("tags", Tag.objects.all(), to_attr="every_tags")
72-
reveal_type(Article.objects.prefetch_related(tag_prefetch).all()) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article, myapp.models.Article]"
81+
reveal_type(Article.objects.prefetch_related(tag_prefetch).all()) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article@AnnotatedWith[TypedDict({'every_tags': builtins.list[myapp.models.Tag]})], myapp.models.Article@AnnotatedWith[TypedDict({'every_tags': builtins.list[myapp.models.Tag]})]]"
7382
7483
# Intermediary variable -- module scope
7584
tag_prefetch = Prefetch("tags", Tag.objects.all(), to_attr="every_tags")
76-
reveal_type(Article.objects.prefetch_related(tag_prefetch).all()) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article, myapp.models.Article]"
85+
reveal_type(Article.objects.prefetch_related(tag_prefetch).all()) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article@AnnotatedWith[TypedDict({'every_tags': builtins.list[myapp.models.Tag]})], myapp.models.Article@AnnotatedWith[TypedDict({'every_tags': builtins.list[myapp.models.Tag]})]]"
86+
87+
# Prefetch created in a function
88+
def get_invalid_prefetch() -> Prefetch[QuerySet[Tag], Literal["every_tags"]]:
89+
return Prefetch("tags", Tag.objects.all(), to_attr="foo") # E: Incompatible return value type (got "Prefetch[QuerySet[Tag, Tag], Literal['foo']]", expected "Prefetch[QuerySet[Tag, Tag], Literal['every_tags']]") [return-value] # E: Argument "to_attr" to "Prefetch" has incompatible type "Literal['foo']"; expected "Literal['every_tags'] | None" [arg-type]
90+
91+
def get_prefetch() -> Prefetch[QuerySet[Tag], Literal["every_tags"]]:
92+
return Prefetch("tags", Tag.objects.all(), to_attr="every_tags")
93+
94+
reveal_type(Article.objects.prefetch_related(get_prefetch()).all()) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article@AnnotatedWith[TypedDict({'every_tags': builtins.list[myapp.models.Tag]})], myapp.models.Article@AnnotatedWith[TypedDict({'every_tags': builtins.list[myapp.models.Tag]})]]"
95+
96+
# Prefetch created in a function with to_attr as an intermediary variable
97+
def get_prefetch_with_var() -> Prefetch[QuerySet[Tag], Literal["every_tags"]]:
98+
to_attr = "every_tags" # TODO: RM error next line
99+
return Prefetch("tags", Tag.objects.all(), to_attr) # E: Argument 3 to "Prefetch" has incompatible type "str"; expected "Literal['every_tags'] | None" [arg-type]
100+
101+
reveal_type(Article.objects.prefetch_related(get_prefetch_with_var()).all()) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article@AnnotatedWith[TypedDict({'every_tags': builtins.list[myapp.models.Tag]})], myapp.models.Article@AnnotatedWith[TypedDict({'every_tags': builtins.list[myapp.models.Tag]})]]"
77102
78103
# Mixed inline `Prefetch` and variable `Prefetch` in one call
79104
mixed_qs = Article.objects.prefetch_related(
80105
tag_prefetch,
81106
Prefetch("article_set", Article.objects.all(), to_attr="arts2"),
82107
)
83-
reveal_type(mixed_qs) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article@AnnotatedWith[TypedDict({'arts2': builtins.list[myapp.models.Article]})], myapp.models.Article@AnnotatedWith[TypedDict({'arts2': builtins.list[myapp.models.Article]})]]"
108+
reveal_type(mixed_qs) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article@AnnotatedWith[TypedDict({'every_tags': builtins.list[myapp.models.Tag], 'arts2': builtins.list[myapp.models.Article]})], myapp.models.Article@AnnotatedWith[TypedDict({'every_tags': builtins.list[myapp.models.Tag], 'arts2': builtins.list[myapp.models.Article]})]]"
109+
110+
# Prefetch with `to_attr` arg but without the `queryset` arg
111+
reveal_type(Article.objects.prefetch_related(models.Prefetch("tags", to_attr="just_tags")).get().just_tags) # N: Revealed type is "builtins.list[myapp.models.Tag]"
112+
113+
# Prefetch with annotated `queryset`
114+
reveal_type(Article.objects.prefetch_related( # N: Revealed type is "builtins.list[myapp.models.Tag@AnnotatedWith[TypedDict({'foo': Any})]]"
115+
models.Prefetch("tags", Tag.objects.annotate(foo=F("id")).all(), to_attr="just_tags")).get().just_tags
116+
)
117+
118+
class MyDict(TypedDict):
119+
foo: str
120+
121+
def get_prefetch_with_annotations() -> Prefetch[QuerySet[WithAnnotations[Tag, MyDict]], Literal["every_tags"]]:
122+
return Prefetch("tags", Tag.objects.annotate(foo=F("id")), "every_tags")
84123
124+
reveal_type(Article.objects.prefetch_related(get_prefetch_with_annotations()).all()) # N: Revealed type is "django.db.models.query.QuerySet[myapp.models.Article@AnnotatedWith[TypedDict({'every_tags': builtins.list[myapp.models.Tag@AnnotatedWith[TypedDict('main.MyDict', {'foo': builtins.str})]]})], myapp.models.Article@AnnotatedWith[TypedDict({'every_tags': builtins.list[myapp.models.Tag@AnnotatedWith[TypedDict('main.MyDict', {'foo': builtins.str})]]})]]"
125+
reveal_type(Article.objects.prefetch_related(get_prefetch_with_annotations()).get().every_tags[0].foo) # N: Revealed type is "builtins.str"
85126
installed_apps:
86127
- myapp
87128
files:

0 commit comments

Comments
 (0)