Skip to content

Commit ffc48c0

Browse files
committed
Reparametrize managers without explicit type parameters
This extracts the reparametrization logic from typeddjango#1030 in addition to removing the codepath that copied methods from querysets to managers. That code path seems to not be needed with this change.
1 parent db14454 commit ffc48c0

File tree

9 files changed

+102
-161
lines changed

9 files changed

+102
-161
lines changed

django-stubs/contrib/sessions/base_session.pyi

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from datetime import datetime
2-
from typing import Any, Dict, Optional, Type
2+
from typing import Any, Dict, Optional, Type, TypeVar
33

44
from django.contrib.sessions.backends.base import SessionBase
55
from django.db import models
66

7-
class BaseSessionManager(models.Manager):
7+
_T = TypeVar("_T", bound="AbstractBaseSession")
8+
9+
class BaseSessionManager(models.Manager[_T]):
810
def encode(self, session_dict: Dict[str, int]) -> str: ...
9-
def save(self, session_key: str, session_dict: Dict[str, int], expire_date: datetime) -> AbstractBaseSession: ...
11+
def save(self, session_key: str, session_dict: Dict[str, int], expire_date: datetime) -> _T: ...
1012

1113
class AbstractBaseSession(models.Model):
1214
expire_date: datetime
Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
from typing import TypeVar
2+
13
from django.contrib.sessions.base_session import AbstractBaseSession, BaseSessionManager
24

3-
class SessionManager(BaseSessionManager): ...
5+
_T = TypeVar("_T", bound="Session")
6+
7+
class SessionManager(BaseSessionManager[_T]): ...
48
class Session(AbstractBaseSession): ...
Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
from typing import Optional
1+
from typing import Optional, TypeVar
22

3+
from django.contrib.sites.models import Site
34
from django.db import models
45

5-
class CurrentSiteManager(models.Manager):
6+
_T = TypeVar("_T", bound=Site)
7+
8+
class CurrentSiteManager(models.Manager[_T]):
69
def __init__(self, field_name: Optional[str] = ...) -> None: ...

mypy_django_plugin/lib/helpers.py

Lines changed: 3 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections import OrderedDict
2-
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union
2+
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Union
33

44
from django.db.models.fields import Field
55
from django.db.models.fields.related import RelatedField
@@ -10,11 +10,9 @@
1010
from mypy.nodes import (
1111
GDEF,
1212
MDEF,
13-
Argument,
1413
Block,
1514
ClassDef,
1615
Expression,
17-
FuncDef,
1816
MemberExpr,
1917
MypyFile,
2018
NameExpr,
@@ -34,11 +32,10 @@
3432
MethodContext,
3533
SemanticAnalyzerPluginInterface,
3634
)
37-
from mypy.plugins.common import add_method_to_class
3835
from mypy.semanal import SemanticAnalyzer
39-
from mypy.types import AnyType, CallableType, Instance, NoneTyp, TupleType
36+
from mypy.types import AnyType, Instance, NoneTyp, TupleType
4037
from mypy.types import Type as MypyType
41-
from mypy.types import TypedDictType, TypeOfAny, UnboundType, UnionType
38+
from mypy.types import TypedDictType, TypeOfAny, UnionType
4239

4340
from mypy_django_plugin.lib import fullnames
4441
from mypy_django_plugin.lib.fullnames import WITH_ANNOTATIONS_FULLNAME
@@ -361,86 +358,6 @@ def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType, no_se
361358
info.names[name] = SymbolTableNode(MDEF, var, plugin_generated=True, no_serialize=no_serialize)
362359

363360

364-
def build_unannotated_method_args(method_node: FuncDef) -> Tuple[List[Argument], MypyType]:
365-
prepared_arguments = []
366-
try:
367-
arguments = method_node.arguments[1:]
368-
except AttributeError:
369-
arguments = []
370-
for argument in arguments:
371-
argument.type_annotation = AnyType(TypeOfAny.unannotated)
372-
prepared_arguments.append(argument)
373-
return_type = AnyType(TypeOfAny.unannotated)
374-
return prepared_arguments, return_type
375-
376-
377-
def bind_or_analyze_type(t: MypyType, api: SemanticAnalyzer, module_name: Optional[str] = None) -> Optional[MypyType]:
378-
"""Analyze a type. If an unbound type, try to look it up in the given module name.
379-
380-
That should hopefully give a bound type."""
381-
if isinstance(t, UnboundType) and module_name is not None:
382-
node = api.lookup_fully_qualified_or_none(module_name + "." + t.name)
383-
if node is not None and node.type is not None:
384-
return node.type
385-
386-
return api.anal_type(t)
387-
388-
389-
def copy_method_to_another_class(
390-
api: SemanticAnalyzer,
391-
cls: ClassDef,
392-
self_type: Instance,
393-
new_method_name: str,
394-
method_node: FuncDef,
395-
return_type: Optional[MypyType] = None,
396-
original_module_name: Optional[str] = None,
397-
) -> bool:
398-
if method_node.type is None:
399-
arguments, return_type = build_unannotated_method_args(method_node)
400-
add_method_to_class(api, cls, new_method_name, args=arguments, return_type=return_type, self_type=self_type)
401-
return True
402-
403-
method_type = method_node.type
404-
if not isinstance(method_type, CallableType):
405-
if not api.final_iteration:
406-
api.defer()
407-
return False
408-
409-
if return_type is None:
410-
return_type = bind_or_analyze_type(method_type.ret_type, api, original_module_name)
411-
if return_type is None:
412-
return False
413-
414-
# We build the arguments from the method signature (`CallableType`), because if we were to
415-
# use the arguments from the method node (`FuncDef.arguments`) we're not compatible with
416-
# a method loaded from cache. As mypy doesn't serialize `FuncDef.arguments` when caching
417-
arguments = []
418-
# Note that the first argument is excluded, as that's `self`
419-
for pos, (arg_type, arg_kind, arg_name) in enumerate(
420-
zip(method_type.arg_types[1:], method_type.arg_kinds[1:], method_type.arg_names[1:]),
421-
start=1,
422-
):
423-
bound_arg_type = bind_or_analyze_type(arg_type, api, original_module_name)
424-
if bound_arg_type is None:
425-
return False
426-
if arg_name is None and hasattr(method_node, "arguments"):
427-
arg_name = method_node.arguments[pos].variable.name
428-
arguments.append(
429-
Argument(
430-
# Positional only arguments can have name as `None`, if we can't find a name, we just invent one..
431-
variable=Var(name=arg_name if arg_name is not None else str(pos), type=arg_type),
432-
type_annotation=bound_arg_type,
433-
initializer=None,
434-
kind=arg_kind,
435-
pos_only=arg_name is None,
436-
)
437-
)
438-
439-
add_method_to_class(api, cls, new_method_name, args=arguments, return_type=return_type, self_type=self_type)
440-
441-
return True
442-
443-
444361
def add_new_manager_base(api: SemanticAnalyzerPluginInterface, fullname: str) -> None:
445362
sym = api.lookup_fully_qualified_or_none(fullnames.MANAGER_CLASS_FULLNAME)
446363
if sym is not None and isinstance(sym.node, TypeInfo):

mypy_django_plugin/main.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from mypy_django_plugin.transformers.managers import (
2727
create_new_manager_class_from_as_manager_method,
2828
create_new_manager_class_from_from_queryset_method,
29+
reparametrize_any_manager_hook,
2930
resolve_manager_method,
3031
)
3132
from mypy_django_plugin.transformers.models import (
@@ -240,6 +241,15 @@ def get_method_hook(self, fullname: str) -> Optional[Callable[[MethodContext], M
240241

241242
return None
242243

244+
def get_customize_class_mro_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
245+
sym = self.lookup_fully_qualified(fullname)
246+
if (
247+
sym is not None
248+
and isinstance(sym.node, TypeInfo)
249+
and sym.node.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME)
250+
):
251+
return reparametrize_any_manager_hook
252+
243253
def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
244254
# Base class is a Model class definition
245255
if (

mypy_django_plugin/transformers/managers.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
TypeInfo,
1616
Var,
1717
)
18-
from mypy.plugin import AttributeContext, DynamicClassDefContext
18+
from mypy.plugin import AttributeContext, ClassDefContext, DynamicClassDefContext
1919
from mypy.semanal import SemanticAnalyzer
2020
from mypy.semanal_shared import has_placeholder
2121
from mypy.types import AnyType, CallableType, Instance, ProperType
@@ -466,3 +466,63 @@ def create_new_manager_class_from_as_manager_method(ctx: DynamicClassDefContext)
466466
# Note that the generated manager type is always inserted at module level
467467
SymbolTableNode(GDEF, new_manager_info, plugin_generated=True),
468468
)
469+
470+
471+
def reparametrize_any_manager_hook(ctx: ClassDefContext) -> None:
472+
"""
473+
Add explicit generics to manager classes that are defined without generic.
474+
475+
Eg.
476+
477+
class MyManager(models.Manager): ...
478+
479+
is interpreted as::
480+
_T = TypeVar('_T', covariant=True)
481+
class MyManager(models.Manager[_T]): ...
482+
"""
483+
484+
manager = ctx.api.lookup_fully_qualified_or_none(ctx.cls.fullname)
485+
if manager is None or manager.node is None:
486+
return
487+
assert isinstance(manager.node, TypeInfo)
488+
489+
if manager.node.type_vars:
490+
# We've already been here
491+
return
492+
493+
parent_manager = next(
494+
(base for base in manager.node.bases if base.type.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME)),
495+
None,
496+
)
497+
if parent_manager is None:
498+
return
499+
500+
preserve_typevars = (
501+
# If args are missing, but tvars present, don't ignore: we have to reparametrize
502+
not parent_manager.type.type_vars
503+
or parent_manager.args
504+
and (
505+
not isinstance(parent_manager.args[0], AnyType) or parent_manager.args[0].type_of_any == TypeOfAny.explicit
506+
)
507+
)
508+
if preserve_typevars:
509+
return
510+
511+
base_manager = ctx.api.lookup_fully_qualified_or_none(fullnames.BASE_MANAGER_CLASS_FULLNAME)
512+
if base_manager is None:
513+
if not ctx.api.final_iteration:
514+
ctx.api.defer()
515+
return
516+
assert isinstance(base_manager.node, TypeInfo)
517+
518+
tvars = tuple(base_manager.node.defn.type_vars)
519+
# For some reason, we have to defer now, otherwise `defer` is called in other place
520+
# on final iteration (`SemanticAnalyzer.analyze_func_def`).
521+
if any(has_placeholder(tvar) for tvar in tvars):
522+
assert not ctx.api.final_iteration, "Too late to reparametrize"
523+
ctx.api.defer()
524+
525+
parent_manager.args = tvars
526+
manager.node.type_vars = []
527+
manager.node.defn.type_vars = list(tvars)
528+
manager.node.add_type_vars()

mypy_django_plugin/transformers/models.py

Lines changed: 4 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -282,45 +282,6 @@ def has_any_parametrized_manager_as_base(self, info: TypeInfo) -> bool:
282282
def is_any_parametrized_manager(self, typ: Instance) -> bool:
283283
return typ.type.fullname in fullnames.MANAGER_CLASSES and isinstance(typ.args[0], AnyType)
284284

285-
def create_new_model_parametrized_manager(self, name: str, base_manager_info: TypeInfo) -> Instance:
286-
bases = []
287-
for original_base in base_manager_info.bases:
288-
if self.is_any_parametrized_manager(original_base):
289-
original_base = helpers.reparametrize_instance(original_base, [Instance(self.model_classdef.info, [])])
290-
bases.append(original_base)
291-
292-
# TODO: This adds the manager to the module, even if we end up
293-
# deferring. That can be avoided by not adding it to the module first,
294-
# but rather waiting until we know we won't defer
295-
new_manager_info = self.add_new_class_for_current_module(name, bases)
296-
# copy fields to a new manager
297-
custom_manager_type = Instance(new_manager_info, [Instance(self.model_classdef.info, [])])
298-
299-
for name, sym in base_manager_info.names.items():
300-
# replace self type with new class, if copying method
301-
if isinstance(sym.node, FuncDef):
302-
copied_method = helpers.copy_method_to_another_class(
303-
api=self.api,
304-
cls=new_manager_info.defn,
305-
self_type=custom_manager_type,
306-
new_method_name=name,
307-
method_node=sym.node,
308-
original_module_name=base_manager_info.module_name,
309-
)
310-
if not copied_method and not self.api.final_iteration:
311-
raise helpers.IncompleteDefnException()
312-
continue
313-
314-
new_sym = sym.copy()
315-
if isinstance(new_sym.node, Var):
316-
new_var = Var(name, type=sym.type)
317-
new_var.info = new_manager_info
318-
new_var._fullname = new_manager_info.fullname + "." + name
319-
new_sym.node = new_var
320-
new_manager_info.names[name] = new_sym
321-
322-
return custom_manager_type
323-
324285
def lookup_manager(self, fullname: str, manager: "Manager[Any]") -> Optional[TypeInfo]:
325286
manager_info = self.lookup_typeinfo(fullname)
326287
if manager_info is None:
@@ -354,33 +315,17 @@ def run_with_model_cls(self, model_cls: Type[Model]) -> None:
354315
# Manager is already typed -> do nothing unless it's a dynamically generated manager
355316
self.reparametrize_dynamically_created_manager(manager_name, manager_info)
356317
continue
357-
elif manager_info is None:
318+
319+
if manager_info is None:
358320
# We couldn't find a manager type, see if we should create one
359321
manager_info = self.create_manager_from_from_queryset(manager_name)
360322

361323
if manager_info is None:
362324
incomplete_manager_defs.add(manager_name)
363325
continue
364326

365-
if manager_name not in self.model_classdef.info.names or self.is_manager_dynamically_generated(
366-
manager_info
367-
):
368-
manager_type = Instance(manager_info, [Instance(self.model_classdef.info, [])])
369-
self.add_new_node_to_model_class(manager_name, manager_type)
370-
elif self.has_any_parametrized_manager_as_base(manager_info):
371-
# Ending up here could for instance be due to having a custom _Manager_
372-
# that is not built from a custom QuerySet. Another example is a
373-
# related manager.
374-
manager_class_name = manager.__class__.__name__
375-
custom_model_manager_name = manager.model.__name__ + "_" + manager_class_name
376-
try:
377-
manager_type = self.create_new_model_parametrized_manager(
378-
custom_model_manager_name, base_manager_info=manager_info
379-
)
380-
except helpers.IncompleteDefnException:
381-
continue
382-
383-
self.add_new_node_to_model_class(manager_name, manager_type)
327+
manager_type = Instance(manager_info, [Instance(self.model_classdef.info, [])])
328+
self.add_new_node_to_model_class(manager_name, manager_type)
384329

385330
if incomplete_manager_defs:
386331
if not self.api.final_iteration:

tests/typecheck/managers/querysets/test_from_queryset.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,9 @@
162162
- case: from_queryset_returns_intersection_of_manager_and_queryset
163163
main: |
164164
from myapp.models import MyModel, NewManager
165-
reveal_type(NewManager()) # N: Revealed type is "myapp.models.ModelBaseManagerFromModelQuerySet"
165+
reveal_type(NewManager()) # N: Revealed type is "myapp.models.ModelBaseManagerFromModelQuerySet[<nothing>]"
166166
reveal_type(MyModel.objects) # N: Revealed type is "myapp.models.ModelBaseManagerFromModelQuerySet[myapp.models.MyModel]"
167-
reveal_type(MyModel.objects.get()) # N: Revealed type is "Any"
167+
reveal_type(MyModel.objects.get()) # N: Revealed type is "myapp.models.MyModel"
168168
reveal_type(MyModel.objects.manager_only_method()) # N: Revealed type is "builtins.int"
169169
reveal_type(MyModel.objects.manager_and_queryset_method()) # N: Revealed type is "builtins.str"
170170
installed_apps:
@@ -188,12 +188,12 @@
188188
- case: from_queryset_with_class_name_provided
189189
main: |
190190
from myapp.models import MyModel, NewManager, OtherModel, OtherManager
191-
reveal_type(NewManager()) # N: Revealed type is "myapp.models.NewManager"
191+
reveal_type(NewManager()) # N: Revealed type is "myapp.models.NewManager[<nothing>]"
192192
reveal_type(MyModel.objects) # N: Revealed type is "myapp.models.NewManager[myapp.models.MyModel]"
193-
reveal_type(MyModel.objects.get()) # N: Revealed type is "Any"
193+
reveal_type(MyModel.objects.get()) # N: Revealed type is "myapp.models.MyModel"
194194
reveal_type(MyModel.objects.manager_only_method()) # N: Revealed type is "builtins.int"
195195
reveal_type(MyModel.objects.manager_and_queryset_method()) # N: Revealed type is "builtins.str"
196-
reveal_type(OtherManager()) # N: Revealed type is "myapp.models.X"
196+
reveal_type(OtherManager()) # N: Revealed type is "myapp.models.X[<nothing>]"
197197
reveal_type(OtherModel.objects) # N: Revealed type is "myapp.models.X[myapp.models.OtherModel]"
198198
reveal_type(OtherModel.objects.manager_only_method()) # N: Revealed type is "builtins.int"
199199
reveal_type(OtherModel.objects.manager_and_queryset_method()) # N: Revealed type is "builtins.str"

0 commit comments

Comments
 (0)