1- from collections import OrderedDict
1+ import collections
22from typing import List , Optional , Sequence , Type
33
44from django .core .exceptions import FieldError
55from django .db .models .base import Model
66from django .db .models .fields .related import RelatedField
77from django .db .models .fields .reverse_related import ForeignObjectRel
88from mypy .nodes import Expression , NameExpr
9- from mypy .plugin import FunctionContext , MethodContext
9+ from mypy .plugin import MethodContext
1010from mypy .types import AnyType , Instance
1111from mypy .types import Type as MypyType
1212from mypy .types import TypeOfAny
1717from mypy_django_plugin .lib import chk_helpers , fullnames , helpers
1818
1919
20- def _extract_model_type_from_queryset (queryset_type : Instance ) -> Optional [Instance ]:
21- for base_type in [queryset_type , * queryset_type .type .bases ]:
22- if (len (base_type .args )
23- and isinstance (base_type .args [0 ], Instance )
24- and base_type .args [0 ].type .has_base (fullnames .MODEL_CLASS_FULLNAME )):
25- return base_type .args [0 ]
26- return None
27-
28-
29- def determine_proper_manager_type (ctx : FunctionContext ) -> MypyType :
30- default_return_type = ctx .default_return_type
31- assert isinstance (default_return_type , Instance )
32-
33- outer_model_info = chk_helpers .get_typechecker_api (ctx ).scope .active_class ()
34- if (outer_model_info is None
35- or not outer_model_info .has_base (fullnames .MODEL_CLASS_FULLNAME )):
36- return default_return_type
37-
38- return helpers .reparametrize_instance (default_return_type , [Instance (outer_model_info , [])])
39-
40-
4120def get_field_type_from_lookup (ctx : MethodContext , django_context : DjangoContext , model_cls : Type [Model ],
4221 * , method : str , lookup : str ) -> Optional [MypyType ]:
4322 try :
@@ -60,6 +39,16 @@ def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext
6039 return field_get_type
6140
6241
42+ def resolve_field_lookups (lookup_exprs : Sequence [Expression ], django_context : DjangoContext ) -> Optional [List [str ]]:
43+ field_lookups = []
44+ for field_lookup_expr in lookup_exprs :
45+ field_lookup = helpers .resolve_string_attribute_value (field_lookup_expr , django_context )
46+ if field_lookup is None :
47+ return None
48+ field_lookups .append (field_lookup )
49+ return field_lookups
50+
51+
6352def get_values_list_row_type (ctx : MethodContext , django_context : DjangoContext , model_cls : Type [Model ],
6453 flat : bool , named : bool ) -> MypyType :
6554 field_lookups = resolve_field_lookups (ctx .args [0 ], django_context )
@@ -75,7 +64,7 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
7564 assert lookup_type is not None
7665 return lookup_type
7766 elif named :
78- column_types : 'OrderedDict[str, MypyType]' = OrderedDict ()
67+ column_types = collections . OrderedDict ()
7968 for field in django_context .get_model_fields (model_cls ):
8069 column_type = django_context .get_field_get_type (typechecker_api , field ,
8170 method = 'values_list' )
@@ -91,7 +80,7 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
9180 typechecker_api .fail ("'flat' is not valid when 'values_list' is called with more than one field" , ctx .context )
9281 return AnyType (TypeOfAny .from_error )
9382
94- column_types = OrderedDict ()
83+ column_types = collections . OrderedDict ()
9584 for field_lookup in field_lookups :
9685 lookup_field_type = get_field_type_from_lookup (ctx , django_context , model_cls ,
9786 lookup = field_lookup , method = 'values_list' )
@@ -110,83 +99,83 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext,
11099 return row_type
111100
112101
113- def extract_proper_type_queryset_values_list (ctx : MethodContext , django_context : DjangoContext ) -> MypyType :
114- # called on the Instance, returns QuerySet of something
115- assert isinstance (ctx .type , Instance )
116- assert isinstance (ctx .default_return_type , Instance )
102+ class QuerySetMethodCallback (helpers .GetMethodCallback ):
103+ def current_model_type (self ) -> Optional [Instance ]:
104+ for base_type in [self .callee_type , * self .callee_type .type .bases ]:
105+ if (len (base_type .args )
106+ and isinstance (base_type .args [0 ], Instance )
107+ and base_type .args [0 ].type .has_base (fullnames .MODEL_CLASS_FULLNAME )):
108+ return base_type .args [0 ]
109+ return None
117110
118- model_type = _extract_model_type_from_queryset (ctx .type )
119- if model_type is None :
120- return AnyType (TypeOfAny .from_omitted_generics )
121111
122- model_cls = django_context . get_model_class_by_fullname ( model_type . type . fullname )
123- if model_cls is None :
124- return ctx .default_return_type
112+ class QuerySetValuesCallback ( QuerySetMethodCallback ):
113+ def get_method_return_type ( self ) -> MypyType :
114+ assert isinstance ( self .default_return_type , Instance )
125115
126- flat_expr = chk_helpers .get_call_argument_by_name (ctx , 'flat' )
127- if flat_expr is not None and isinstance (flat_expr , NameExpr ):
128- flat = helpers .parse_bool (flat_expr )
129- else :
130- flat = False
116+ model_type = self .current_model_type ()
117+ if model_type is None :
118+ return AnyType (TypeOfAny .from_omitted_generics )
131119
132- named_expr = chk_helpers .get_call_argument_by_name (ctx , 'named' )
133- if named_expr is not None and isinstance (named_expr , NameExpr ):
134- named = helpers .parse_bool (named_expr )
135- else :
136- named = False
120+ model_cls = self .django_context .get_model_class_by_fullname (model_type .type .fullname )
121+ if model_cls is None :
122+ return self .default_return_type
137123
138- if flat and named :
139- ctx . api . fail ( "'flat' and 'named' can't be used together" , ctx . context )
140- return helpers . reparametrize_instance ( ctx . default_return_type , [ model_type , AnyType (TypeOfAny .from_error )] )
124+ field_lookups = resolve_field_lookups ( self . ctx . args [ 0 ], self . django_context )
125+ if field_lookups is None :
126+ return AnyType (TypeOfAny .from_error )
141127
142- # account for possible None
143- flat = flat or False
144- named = named or False
128+ if len ( field_lookups ) == 0 :
129+ for field in self . django_context . get_model_fields ( model_cls ):
130+ field_lookups . append ( field . attname )
145131
146- row_type = get_values_list_row_type (ctx , django_context , model_cls ,
147- flat = flat , named = named )
148- return helpers .reparametrize_instance (ctx .default_return_type , [model_type , row_type ])
132+ column_types = collections .OrderedDict ()
133+ for field_lookup in field_lookups :
134+ field_lookup_type = get_field_type_from_lookup (self .ctx , self .django_context , model_cls ,
135+ lookup = field_lookup , method = 'values' )
136+ if field_lookup_type is None :
137+ return helpers .reparametrize_instance (self .default_return_type ,
138+ [model_type , AnyType (TypeOfAny .from_error )])
149139
140+ column_types [field_lookup ] = field_lookup_type
150141
151- def resolve_field_lookups (lookup_exprs : Sequence [Expression ], django_context : DjangoContext ) -> Optional [List [str ]]:
152- field_lookups = []
153- for field_lookup_expr in lookup_exprs :
154- field_lookup = helpers .resolve_string_attribute_value (field_lookup_expr , django_context )
155- if field_lookup is None :
156- return None
157- field_lookups .append (field_lookup )
158- return field_lookups
142+ row_type = chk_helpers .make_oneoff_typeddict (self .ctx .api , column_types , set (column_types .keys ()))
143+ return helpers .reparametrize_instance (self .default_return_type , [model_type , row_type ])
159144
160145
161- def extract_proper_type_queryset_values ( ctx : MethodContext , django_context : DjangoContext ) -> MypyType :
162- # called on QuerySet, return QuerySet of something
163- assert isinstance ( ctx . type , Instance )
164- assert isinstance (ctx .default_return_type , Instance )
146+ class QuerySetValuesListCallback ( QuerySetMethodCallback ) :
147+ def get_method_return_type ( self ) -> MypyType :
148+ # called on the Instance, returns QuerySet of something
149+ assert isinstance (self .default_return_type , Instance )
165150
166- model_type = _extract_model_type_from_queryset ( ctx . type )
167- if model_type is None :
168- return AnyType (TypeOfAny .from_omitted_generics )
151+ model_type = self . current_model_type ( )
152+ if model_type is None :
153+ return AnyType (TypeOfAny .from_omitted_generics )
169154
170- model_cls = django_context .get_model_class_by_fullname (model_type .type .fullname )
171- if model_cls is None :
172- return ctx .default_return_type
155+ model_cls = self . django_context .get_model_class_by_fullname (model_type .type .fullname )
156+ if model_cls is None :
157+ return self .default_return_type
173158
174- field_lookups = resolve_field_lookups (ctx .args [0 ], django_context )
175- if field_lookups is None :
176- return AnyType (TypeOfAny .from_error )
159+ flat_expr = chk_helpers .get_call_argument_by_name (self .ctx , 'flat' )
160+ if flat_expr is not None and isinstance (flat_expr , NameExpr ):
161+ flat = helpers .parse_bool (flat_expr )
162+ else :
163+ flat = False
177164
178- if len (field_lookups ) == 0 :
179- for field in django_context .get_model_fields (model_cls ):
180- field_lookups .append (field .attname )
165+ named_expr = chk_helpers .get_call_argument_by_name (self .ctx , 'named' )
166+ if named_expr is not None and isinstance (named_expr , NameExpr ):
167+ named = helpers .parse_bool (named_expr )
168+ else :
169+ named = False
181170
182- column_types : 'OrderedDict[str, MypyType]' = OrderedDict ()
183- for field_lookup in field_lookups :
184- field_lookup_type = get_field_type_from_lookup (ctx , django_context , model_cls ,
185- lookup = field_lookup , method = 'values' )
186- if field_lookup_type is None :
187- return helpers .reparametrize_instance (ctx .default_return_type , [model_type , AnyType (TypeOfAny .from_error )])
171+ if flat and named :
172+ self .ctx .api .fail ("'flat' and 'named' can't be used together" , self .ctx .context )
173+ return helpers .reparametrize_instance (self .default_return_type , [model_type , AnyType (TypeOfAny .from_error )])
188174
189- column_types [field_lookup ] = field_lookup_type
175+ # account for possible None
176+ flat = flat or False
177+ named = named or False
190178
191- row_type = chk_helpers .make_oneoff_typeddict (ctx .api , column_types , set (column_types .keys ()))
192- return helpers .reparametrize_instance (ctx .default_return_type , [model_type , row_type ])
179+ row_type = get_values_list_row_type (self .ctx , self .django_context , model_cls ,
180+ flat = flat , named = named )
181+ return helpers .reparametrize_instance (self .default_return_type , [model_type , row_type ])
0 commit comments