Skip to content

Commit 8c89bd7

Browse files
committed
move querysets to new API
1 parent 640afe3 commit 8c89bd7

File tree

3 files changed

+83
-92
lines changed

3 files changed

+83
-92
lines changed

mypy_django_plugin/main.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import configparser
2-
from functools import partial
32
from typing import Callable, Dict, List, Optional, Tuple
43

54
from django.db.models.fields.related import RelatedField
@@ -13,7 +12,6 @@
1312

1413
from mypy_django_plugin.django.context import DjangoContext
1514
from mypy_django_plugin.lib import fullnames, helpers
16-
from mypy_django_plugin.transformers import querysets
1715
from mypy_django_plugin.transformers2.fields import FieldContructorCallback
1816
from mypy_django_plugin.transformers2.forms import (
1917
FormCallback, GetFormCallback, GetFormClassCallback,
@@ -26,6 +24,9 @@
2624
from mypy_django_plugin.transformers2.orm_lookups import (
2725
QuerySetFilterTypecheckCallback,
2826
)
27+
from mypy_django_plugin.transformers2.querysets import (
28+
QuerySetValuesCallback, QuerySetValuesListCallback,
29+
)
2930
from mypy_django_plugin.transformers2.related_managers import (
3031
GetRelatedManagerCallback,
3132
)
@@ -176,7 +177,6 @@ def get_function_hook(self, fullname: str
176177

177178
if self.django_context.is_model_subclass(info):
178179
return ModelInitCallback(self)
179-
# return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context)
180180
return None
181181

182182
def get_method_hook(self, fullname: str
@@ -195,12 +195,14 @@ def get_method_hook(self, fullname: str
195195
if method_name == 'values':
196196
info = self._get_typeinfo_or_none(class_fullname)
197197
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
198-
return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context)
198+
return QuerySetValuesCallback(self)
199+
# return partial(querysets.extract_proper_type_queryset_values, django_context=self.django_context)
199200

200201
if method_name == 'values_list':
201202
info = self._get_typeinfo_or_none(class_fullname)
202203
if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME):
203-
return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context)
204+
return QuerySetValuesListCallback(self)
205+
# return partial(querysets.extract_proper_type_queryset_values_list, django_context=self.django_context)
204206

205207
if method_name == 'get_field':
206208
info = self._get_typeinfo_or_none(class_fullname)

mypy_django_plugin/transformers/__init__.py

Whitespace-only changes.
Lines changed: 76 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
from collections import OrderedDict
1+
import collections
22
from typing import List, Optional, Sequence, Type
33

44
from django.core.exceptions import FieldError
55
from django.db.models.base import Model
66
from django.db.models.fields.related import RelatedField
77
from django.db.models.fields.reverse_related import ForeignObjectRel
88
from mypy.nodes import Expression, NameExpr
9-
from mypy.plugin import FunctionContext, MethodContext
9+
from mypy.plugin import MethodContext
1010
from mypy.types import AnyType, Instance
1111
from mypy.types import Type as MypyType
1212
from mypy.types import TypeOfAny
@@ -17,27 +17,6 @@
1717
from mypy_django_plugin.lib import chk_helpers, fullnames, helpers
1818

1919

20-
def _extract_model_type_from_queryset(queryset_type: Instance) -> Optional[Instance]:
21-
for base_type in [queryset_type, *queryset_type.type.bases]:
22-
if (len(base_type.args)
23-
and isinstance(base_type.args[0], Instance)
24-
and base_type.args[0].type.has_base(fullnames.MODEL_CLASS_FULLNAME)):
25-
return base_type.args[0]
26-
return None
27-
28-
29-
def determine_proper_manager_type(ctx: FunctionContext) -> MypyType:
30-
default_return_type = ctx.default_return_type
31-
assert isinstance(default_return_type, Instance)
32-
33-
outer_model_info = chk_helpers.get_typechecker_api(ctx).scope.active_class()
34-
if (outer_model_info is None
35-
or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME)):
36-
return default_return_type
37-
38-
return helpers.reparametrize_instance(default_return_type, [Instance(outer_model_info, [])])
39-
40-
4120
def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
4221
*, method: str, lookup: str) -> Optional[MypyType]:
4322
try:
@@ -60,6 +39,16 @@ def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext
6039
return field_get_type
6140

6241

42+
def resolve_field_lookups(lookup_exprs: Sequence[Expression], django_context: DjangoContext) -> Optional[List[str]]:
43+
field_lookups = []
44+
for field_lookup_expr in lookup_exprs:
45+
field_lookup = helpers.resolve_string_attribute_value(field_lookup_expr, django_context)
46+
if field_lookup is None:
47+
return None
48+
field_lookups.append(field_lookup)
49+
return field_lookups
50+
51+
6352
def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, model_cls: Type[Model],
6453
flat: bool, named: bool) -> MypyType:
6554
field_lookups = resolve_field_lookups(ctx.args[0], django_context)
@@ -75,7 +64,7 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
7564
assert lookup_type is not None
7665
return lookup_type
7766
elif named:
78-
column_types: 'OrderedDict[str, MypyType]' = OrderedDict()
67+
column_types = collections.OrderedDict()
7968
for field in django_context.get_model_fields(model_cls):
8069
column_type = django_context.get_field_get_type(typechecker_api, field,
8170
method='values_list')
@@ -91,7 +80,7 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
9180
typechecker_api.fail("'flat' is not valid when 'values_list' is called with more than one field", ctx.context)
9281
return AnyType(TypeOfAny.from_error)
9382

94-
column_types = OrderedDict()
83+
column_types = collections.OrderedDict()
9584
for field_lookup in field_lookups:
9685
lookup_field_type = get_field_type_from_lookup(ctx, django_context, model_cls,
9786
lookup=field_lookup, method='values_list')
@@ -110,83 +99,83 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
11099
return row_type
111100

112101

113-
def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
114-
# called on the Instance, returns QuerySet of something
115-
assert isinstance(ctx.type, Instance)
116-
assert isinstance(ctx.default_return_type, Instance)
102+
class QuerySetMethodCallback(helpers.GetMethodCallback):
103+
def current_model_type(self) -> Optional[Instance]:
104+
for base_type in [self.callee_type, *self.callee_type.type.bases]:
105+
if (len(base_type.args)
106+
and isinstance(base_type.args[0], Instance)
107+
and base_type.args[0].type.has_base(fullnames.MODEL_CLASS_FULLNAME)):
108+
return base_type.args[0]
109+
return None
117110

118-
model_type = _extract_model_type_from_queryset(ctx.type)
119-
if model_type is None:
120-
return AnyType(TypeOfAny.from_omitted_generics)
121111

122-
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname)
123-
if model_cls is None:
124-
return ctx.default_return_type
112+
class QuerySetValuesCallback(QuerySetMethodCallback):
113+
def get_method_return_type(self) -> MypyType:
114+
assert isinstance(self.default_return_type, Instance)
125115

126-
flat_expr = chk_helpers.get_call_argument_by_name(ctx, 'flat')
127-
if flat_expr is not None and isinstance(flat_expr, NameExpr):
128-
flat = helpers.parse_bool(flat_expr)
129-
else:
130-
flat = False
116+
model_type = self.current_model_type()
117+
if model_type is None:
118+
return AnyType(TypeOfAny.from_omitted_generics)
131119

132-
named_expr = chk_helpers.get_call_argument_by_name(ctx, 'named')
133-
if named_expr is not None and isinstance(named_expr, NameExpr):
134-
named = helpers.parse_bool(named_expr)
135-
else:
136-
named = False
120+
model_cls = self.django_context.get_model_class_by_fullname(model_type.type.fullname)
121+
if model_cls is None:
122+
return self.default_return_type
137123

138-
if flat and named:
139-
ctx.api.fail("'flat' and 'named' can't be used together", ctx.context)
140-
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
124+
field_lookups = resolve_field_lookups(self.ctx.args[0], self.django_context)
125+
if field_lookups is None:
126+
return AnyType(TypeOfAny.from_error)
141127

142-
# account for possible None
143-
flat = flat or False
144-
named = named or False
128+
if len(field_lookups) == 0:
129+
for field in self.django_context.get_model_fields(model_cls):
130+
field_lookups.append(field.attname)
145131

146-
row_type = get_values_list_row_type(ctx, django_context, model_cls,
147-
flat=flat, named=named)
148-
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])
132+
column_types = collections.OrderedDict()
133+
for field_lookup in field_lookups:
134+
field_lookup_type = get_field_type_from_lookup(self.ctx, self.django_context, model_cls,
135+
lookup=field_lookup, method='values')
136+
if field_lookup_type is None:
137+
return helpers.reparametrize_instance(self.default_return_type,
138+
[model_type, AnyType(TypeOfAny.from_error)])
149139

140+
column_types[field_lookup] = field_lookup_type
150141

151-
def resolve_field_lookups(lookup_exprs: Sequence[Expression], django_context: DjangoContext) -> Optional[List[str]]:
152-
field_lookups = []
153-
for field_lookup_expr in lookup_exprs:
154-
field_lookup = helpers.resolve_string_attribute_value(field_lookup_expr, django_context)
155-
if field_lookup is None:
156-
return None
157-
field_lookups.append(field_lookup)
158-
return field_lookups
142+
row_type = chk_helpers.make_oneoff_typeddict(self.ctx.api, column_types, set(column_types.keys()))
143+
return helpers.reparametrize_instance(self.default_return_type, [model_type, row_type])
159144

160145

161-
def extract_proper_type_queryset_values(ctx: MethodContext, django_context: DjangoContext) -> MypyType:
162-
# called on QuerySet, return QuerySet of something
163-
assert isinstance(ctx.type, Instance)
164-
assert isinstance(ctx.default_return_type, Instance)
146+
class QuerySetValuesListCallback(QuerySetMethodCallback):
147+
def get_method_return_type(self) -> MypyType:
148+
# called on the Instance, returns QuerySet of something
149+
assert isinstance(self.default_return_type, Instance)
165150

166-
model_type = _extract_model_type_from_queryset(ctx.type)
167-
if model_type is None:
168-
return AnyType(TypeOfAny.from_omitted_generics)
151+
model_type = self.current_model_type()
152+
if model_type is None:
153+
return AnyType(TypeOfAny.from_omitted_generics)
169154

170-
model_cls = django_context.get_model_class_by_fullname(model_type.type.fullname)
171-
if model_cls is None:
172-
return ctx.default_return_type
155+
model_cls = self.django_context.get_model_class_by_fullname(model_type.type.fullname)
156+
if model_cls is None:
157+
return self.default_return_type
173158

174-
field_lookups = resolve_field_lookups(ctx.args[0], django_context)
175-
if field_lookups is None:
176-
return AnyType(TypeOfAny.from_error)
159+
flat_expr = chk_helpers.get_call_argument_by_name(self.ctx, 'flat')
160+
if flat_expr is not None and isinstance(flat_expr, NameExpr):
161+
flat = helpers.parse_bool(flat_expr)
162+
else:
163+
flat = False
177164

178-
if len(field_lookups) == 0:
179-
for field in django_context.get_model_fields(model_cls):
180-
field_lookups.append(field.attname)
165+
named_expr = chk_helpers.get_call_argument_by_name(self.ctx, 'named')
166+
if named_expr is not None and isinstance(named_expr, NameExpr):
167+
named = helpers.parse_bool(named_expr)
168+
else:
169+
named = False
181170

182-
column_types: 'OrderedDict[str, MypyType]' = OrderedDict()
183-
for field_lookup in field_lookups:
184-
field_lookup_type = get_field_type_from_lookup(ctx, django_context, model_cls,
185-
lookup=field_lookup, method='values')
186-
if field_lookup_type is None:
187-
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
171+
if flat and named:
172+
self.ctx.api.fail("'flat' and 'named' can't be used together", self.ctx.context)
173+
return helpers.reparametrize_instance(self.default_return_type, [model_type, AnyType(TypeOfAny.from_error)])
188174

189-
column_types[field_lookup] = field_lookup_type
175+
# account for possible None
176+
flat = flat or False
177+
named = named or False
190178

191-
row_type = chk_helpers.make_oneoff_typeddict(ctx.api, column_types, set(column_types.keys()))
192-
return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type])
179+
row_type = get_values_list_row_type(self.ctx, self.django_context, model_cls,
180+
flat=flat, named=named)
181+
return helpers.reparametrize_instance(self.default_return_type, [model_type, row_type])

0 commit comments

Comments
 (0)