diff --git a/mypy_django_plugin/django/context.py b/mypy_django_plugin/django/context.py index 36ac0f869..f9dd152b7 100644 --- a/mypy_django_plugin/django/context.py +++ b/mypy_django_plugin/django/context.py @@ -21,7 +21,7 @@ from mypy.types import Type as MypyType from mypy.types import TypeOfAny, UnionType -from mypy_django_plugin.lib import fullnames, helpers +from mypy_django_plugin.lib import chk_helpers, fullnames, helpers try: from django.contrib.postgres.fields import ArrayField @@ -356,11 +356,11 @@ def resolve_lookup_expected_type(self, ctx: MethodContext, model_cls: Type[Model return AnyType(TypeOfAny.explicit) if lookup_cls is None or isinstance(lookup_cls, Exact): - return self.get_field_lookup_exact_type(helpers.get_typechecker_api(ctx), field) + return self.get_field_lookup_exact_type(chk_helpers.get_typechecker_api(ctx), field) assert lookup_cls is not None - lookup_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), lookup_cls) + lookup_info = helpers.lookup_class_typeinfo(chk_helpers.get_typechecker_api(ctx), lookup_cls) if lookup_info is None: return AnyType(TypeOfAny.explicit) @@ -370,7 +370,7 @@ def resolve_lookup_expected_type(self, ctx: MethodContext, model_cls: Type[Model # if it's Field, consider lookup_type a __get__ of current field if (isinstance(lookup_type, Instance) and lookup_type.type.fullname == fullnames.FIELD_FULLNAME): - field_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), field.__class__) + field_info = helpers.lookup_class_typeinfo(chk_helpers.get_typechecker_api(ctx), field.__class__) if field_info is None: return AnyType(TypeOfAny.explicit) lookup_type = helpers.get_private_descriptor_type(field_info, '_pyi_private_get_type', diff --git a/mypy_django_plugin/lib/chk_helpers.py b/mypy_django_plugin/lib/chk_helpers.py new file mode 100644 index 000000000..3f657ff81 --- /dev/null +++ b/mypy_django_plugin/lib/chk_helpers.py @@ -0,0 +1,120 @@ +from typing import Dict, List, Optional, Set, Union + +from mypy import checker +from mypy.checker import TypeChecker +from mypy.nodes import ( + GDEF, MDEF, Expression, MypyFile, SymbolTableNode, TypeInfo, Var, +) +from mypy.plugin import ( + AttributeContext, CheckerPluginInterface, FunctionContext, MethodContext, +) +from mypy.types import AnyType, Instance, TupleType +from mypy.types import Type as MypyType +from mypy.types import TypedDictType, TypeOfAny + +from mypy_django_plugin.lib import helpers + + +def add_new_class_for_current_module(current_module: MypyFile, + name: str, + bases: List[Instance], + fields: Optional[Dict[str, MypyType]] = None + ) -> TypeInfo: + new_class_unique_name = checker.gen_unique_name(name, current_module.names) + new_typeinfo = helpers.new_typeinfo(new_class_unique_name, + bases=bases, + module_name=current_module.fullname) + # new_typeinfo = helpers.make_new_typeinfo_in_current_module(new_class_unique_name, + # bases=bases, + # current_module_fullname=current_module.fullname) + # add fields + if fields: + for field_name, field_type in fields.items(): + var = Var(field_name, type=field_type) + var.info = new_typeinfo + var._fullname = new_typeinfo.fullname + '.' + field_name + new_typeinfo.names[field_name] = SymbolTableNode(MDEF, var, plugin_generated=True) + + current_module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True) + current_module.defs.append(new_typeinfo.defn) + return new_typeinfo + + +def make_oneoff_named_tuple(api: TypeChecker, name: str, fields: 'Dict[str, MypyType]') -> TupleType: + current_module = helpers.get_current_module(api) + namedtuple_info = add_new_class_for_current_module(current_module, name, + bases=[api.named_generic_type('typing.NamedTuple', [])], + fields=fields) + return TupleType(list(fields.values()), fallback=Instance(namedtuple_info, [])) + + +def make_tuple(api: 'TypeChecker', fields: List[MypyType]) -> TupleType: + # fallback for tuples is any builtins.tuple instance + fallback = api.named_generic_type('builtins.tuple', + [AnyType(TypeOfAny.special_form)]) + return TupleType(fields, fallback=fallback) + + +def make_oneoff_typeddict(api: CheckerPluginInterface, fields: 'Dict[str, MypyType]', + required_keys: Set[str]) -> TypedDictType: + object_type = api.named_generic_type('mypy_extensions._TypedDict', []) + typed_dict_type = TypedDictType(fields, # type: ignore + required_keys=required_keys, + fallback=object_type) + return typed_dict_type + + +def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionContext]) -> TypeChecker: + if not isinstance(ctx.api, TypeChecker): + raise ValueError('Not a TypeChecker') + return ctx.api + + +def check_types_compatible(ctx: Union[FunctionContext, MethodContext], + *, expected_type: MypyType, actual_type: MypyType, error_message: str) -> None: + api = get_typechecker_api(ctx) + api.check_subtype(actual_type, expected_type, + ctx.context, error_message, + 'got', 'expected') + + +def get_call_argument_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[Expression]: + """ + Return the expression for the specific argument. + This helper should only be used with non-star arguments. + """ + if name not in ctx.callee_arg_names: + return None + idx = ctx.callee_arg_names.index(name) + args = ctx.args[idx] + if len(args) != 1: + # Either an error or no value passed. + return None + return args[0] + + +def get_call_argument_type_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[MypyType]: + """Return the type for the specific argument. + + This helper should only be used with non-star arguments. + """ + if name not in ctx.callee_arg_names: + return None + idx = ctx.callee_arg_names.index(name) + arg_types = ctx.arg_types[idx] + if len(arg_types) != 1: + # Either an error or no value passed. + return None + return arg_types[0] + + +def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType) -> None: + # type=: type of the variable itself + var = Var(name=name, type=sym_type) + # var.info: type of the object variable is bound to + var.info = info + var._fullname = info.fullname + '.' + name + var.is_initialized_in_class = True + var.is_inferred = True + info.names[name] = SymbolTableNode(MDEF, var, + plugin_generated=True) diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index d6d7f9c67..e4ca1acfe 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -1,70 +1,78 @@ -from collections import OrderedDict from typing import ( - TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, + TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union, ) from django.db.models.fields import Field from django.db.models.fields.related import RelatedField from django.db.models.fields.reverse_related import ForeignObjectRel -from mypy import checker from mypy.checker import TypeChecker from mypy.mro import calculate_mro from mypy.nodes import ( - GDEF, MDEF, Argument, Block, ClassDef, Expression, FuncDef, MemberExpr, MypyFile, NameExpr, PlaceholderNode, - StrExpr, SymbolNode, SymbolTable, SymbolTableNode, TypeInfo, Var, + Block, ClassDef, Expression, MemberExpr, MypyFile, NameExpr, StrExpr, SymbolTable, SymbolTableNode, + TypeInfo, Var, ) -from mypy.plugin import ( - AttributeContext, CheckerPluginInterface, ClassDefContext, DynamicClassDefContext, FunctionContext, MethodContext, -) -from mypy.plugins.common import add_method from mypy.semanal import SemanticAnalyzer -from mypy.types import AnyType, CallableType, Instance, NoneTyp, TupleType +from mypy.types import AnyType, Instance, NoneTyp from mypy.types import Type as MypyType -from mypy.types import TypedDictType, TypeOfAny, UnionType +from mypy.types import TypeOfAny, UnionType from mypy_django_plugin.lib import fullnames if TYPE_CHECKING: from mypy_django_plugin.django.context import DjangoContext +AnyPluginAPI = Union[TypeChecker, SemanticAnalyzer] + def get_django_metadata(model_info: TypeInfo) -> Dict[str, Any]: return model_info.metadata.setdefault('django', {}) -class IncompleteDefnException(Exception): - pass - - -def lookup_fully_qualified_sym(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolTableNode]: +def split_symbol_name(fullname: str, all_modules: Dict[str, MypyFile]) -> Optional[Tuple[str, str]]: if '.' not in fullname: return None - module, cls_name = fullname.rsplit('.', 1) - - module_file = all_modules.get(module) - if module_file is None: - return None - sym = module_file.names.get(cls_name) - if sym is None: - return None - return sym + module_name = None + parts = fullname.split('.') + for i in range(len(parts), 0, -1): + possible_module_name = '.'.join(parts[:i]) + if possible_module_name in all_modules: + module_name = possible_module_name + break -def lookup_fully_qualified_generic(name: str, all_modules: Dict[str, MypyFile]) -> Optional[SymbolNode]: - sym = lookup_fully_qualified_sym(name, all_modules) - if sym is None: + if module_name is None: return None - return sym.node + + cls_name = fullname.replace(module_name, '').lstrip('.') + return module_name, cls_name -def lookup_fully_qualified_typeinfo(api: Union[TypeChecker, SemanticAnalyzer], fullname: str) -> Optional[TypeInfo]: - node = lookup_fully_qualified_generic(fullname, api.modules) - if not isinstance(node, TypeInfo): +def lookup_fully_qualified_typeinfo(api: AnyPluginAPI, fullname: str) -> Optional[TypeInfo]: + split = split_symbol_name(fullname, api.modules) + if split is None: return None - return node + module_name, cls_name = split + + sym_table = api.modules[module_name].names # type: Dict[str, SymbolTableNode] + if '.' in cls_name: + parent_cls_name, _, cls_name = cls_name.rpartition('.') + # nested class + for parent_cls_name in parent_cls_name.split('.'): + sym = sym_table.get(parent_cls_name) + if (sym is None or sym.node is None + or not isinstance(sym.node, TypeInfo)): + return None + sym_table = sym.node.names + + sym = sym_table.get(cls_name) + if (sym is None + or sym.node is None + or not isinstance(sym.node, TypeInfo)): + return None + return sym.node -def lookup_class_typeinfo(api: TypeChecker, klass: type) -> Optional[TypeInfo]: +def lookup_class_typeinfo(api: AnyPluginAPI, klass: type) -> Optional[TypeInfo]: fullname = get_class_fullname(klass) field_info = lookup_fully_qualified_typeinfo(api, fullname) return field_info @@ -79,36 +87,6 @@ def get_class_fullname(klass: type) -> str: return klass.__module__ + '.' + klass.__qualname__ -def get_call_argument_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[Expression]: - """ - Return the expression for the specific argument. - This helper should only be used with non-star arguments. - """ - if name not in ctx.callee_arg_names: - return None - idx = ctx.callee_arg_names.index(name) - args = ctx.args[idx] - if len(args) != 1: - # Either an error or no value passed. - return None - return args[0] - - -def get_call_argument_type_by_name(ctx: Union[FunctionContext, MethodContext], name: str) -> Optional[MypyType]: - """Return the type for the specific argument. - - This helper should only be used with non-star arguments. - """ - if name not in ctx.callee_arg_names: - return None - idx = ctx.callee_arg_names.index(name) - arg_types = ctx.arg_types[idx] - if len(arg_types) != 1: - # Either an error or no value passed. - return None - return arg_types[0] - - def make_optional(typ: MypyType) -> MypyType: return UnionType.make_union([typ, NoneTyp()]) @@ -153,7 +131,7 @@ def get_private_descriptor_type(type_info: TypeInfo, private_field_name: str, is return AnyType(TypeOfAny.explicit) -def get_field_lookup_exact_type(api: TypeChecker, field: Field) -> MypyType: +def get_field_lookup_exact_type(api: AnyPluginAPI, field: Field) -> MypyType: if isinstance(field, (RelatedField, ForeignObjectRel)): lookup_type_class = field.related_model rel_model_info = lookup_class_typeinfo(api, lookup_type_class) @@ -168,44 +146,10 @@ def get_field_lookup_exact_type(api: TypeChecker, field: Field) -> MypyType: is_nullable=field.null) -def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]: - metaclass_sym = info.names.get('Meta') - if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo): - return metaclass_sym.node - return None - - -def add_new_class_for_module(module: MypyFile, - name: str, - bases: List[Instance], - fields: Optional[Dict[str, MypyType]] = None - ) -> TypeInfo: - new_class_unique_name = checker.gen_unique_name(name, module.names) - - # make new class expression - classdef = ClassDef(new_class_unique_name, Block([])) - classdef.fullname = module.fullname + '.' + new_class_unique_name +def get_current_module(api: AnyPluginAPI) -> MypyFile: + if isinstance(api, SemanticAnalyzer): + return api.cur_mod_node - # make new TypeInfo - new_typeinfo = TypeInfo(SymbolTable(), classdef, module.fullname) - new_typeinfo.bases = bases - calculate_mro(new_typeinfo) - new_typeinfo.calculate_metaclass_type() - - # add fields - if fields: - for field_name, field_type in fields.items(): - var = Var(field_name, type=field_type) - var.info = new_typeinfo - var._fullname = new_typeinfo.fullname + '.' + field_name - new_typeinfo.names[field_name] = SymbolTableNode(MDEF, var, plugin_generated=True) - - classdef.info = new_typeinfo - module.names[new_class_unique_name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True) - return new_typeinfo - - -def get_current_module(api: TypeChecker) -> MypyFile: current_module = None for item in reversed(api.scope.stack): if isinstance(item, MypyFile): @@ -215,21 +159,6 @@ def get_current_module(api: TypeChecker) -> MypyFile: return current_module -def make_oneoff_named_tuple(api: TypeChecker, name: str, fields: 'OrderedDict[str, MypyType]') -> TupleType: - current_module = get_current_module(api) - namedtuple_info = add_new_class_for_module(current_module, name, - bases=[api.named_generic_type('typing.NamedTuple', [])], - fields=fields) - return TupleType(list(fields.values()), fallback=Instance(namedtuple_info, [])) - - -def make_tuple(api: 'TypeChecker', fields: List[MypyType]) -> TupleType: - # fallback for tuples is any builtins.tuple instance - fallback = api.named_generic_type('builtins.tuple', - [AnyType(TypeOfAny.special_form)]) - return TupleType(fields, fallback=fallback) - - def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType: if isinstance(typ, UnionType): converted_items = [] @@ -252,13 +181,6 @@ def convert_any_to_type(typ: MypyType, referred_to_type: MypyType) -> MypyType: return typ -def make_typeddict(api: CheckerPluginInterface, fields: 'OrderedDict[str, MypyType]', - required_keys: Set[str]) -> TypedDictType: - object_type = api.named_generic_type('mypy_extensions._TypedDict', []) - typed_dict_type = TypedDictType(fields, required_keys=required_keys, fallback=object_type) - return typed_dict_type - - def resolve_string_attribute_value(attr_expr: Expression, django_context: 'DjangoContext') -> Optional[str]: if isinstance(attr_expr, StrExpr): return attr_expr.value @@ -272,104 +194,25 @@ def resolve_string_attribute_value(attr_expr: Expression, django_context: 'Djang return None -def get_semanal_api(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> SemanticAnalyzer: - if not isinstance(ctx.api, SemanticAnalyzer): - raise ValueError('Not a SemanticAnalyzer') - return ctx.api - - -def get_typechecker_api(ctx: Union[AttributeContext, MethodContext, FunctionContext]) -> TypeChecker: - if not isinstance(ctx.api, TypeChecker): - raise ValueError('Not a TypeChecker') - return ctx.api - - -def is_model_subclass_info(info: TypeInfo, django_context: 'DjangoContext') -> bool: +def is_subclass_of_model(info: TypeInfo, django_context: 'DjangoContext') -> bool: return (info.fullname in django_context.all_registered_model_class_fullnames or info.has_base(fullnames.MODEL_CLASS_FULLNAME)) -def check_types_compatible(ctx: Union[FunctionContext, MethodContext], - *, expected_type: MypyType, actual_type: MypyType, error_message: str) -> None: - api = get_typechecker_api(ctx) - api.check_subtype(actual_type, expected_type, - ctx.context, error_message, - 'got', 'expected') - - -def add_new_sym_for_info(info: TypeInfo, *, name: str, sym_type: MypyType) -> None: - # type=: type of the variable itself - var = Var(name=name, type=sym_type) - # var.info: type of the object variable is bound to - var.info = info - var._fullname = info.fullname + '.' + name - var.is_initialized_in_class = True - var.is_inferred = True - info.names[name] = SymbolTableNode(MDEF, var, - plugin_generated=True) - - -def build_unannotated_method_args(method_node: FuncDef) -> Tuple[List[Argument], MypyType]: - prepared_arguments = [] - for argument in method_node.arguments[1:]: - argument.type_annotation = AnyType(TypeOfAny.unannotated) - prepared_arguments.append(argument) - return_type = AnyType(TypeOfAny.unannotated) - return prepared_arguments, return_type - - -def copy_method_to_another_class(ctx: ClassDefContext, self_type: Instance, - new_method_name: str, method_node: FuncDef) -> None: - semanal_api = get_semanal_api(ctx) - if method_node.type is None: - if not semanal_api.final_iteration: - semanal_api.defer() - return - - arguments, return_type = build_unannotated_method_args(method_node) - add_method(ctx, - new_method_name, - args=arguments, - return_type=return_type, - self_type=self_type) - return - - method_type = method_node.type - if not isinstance(method_type, CallableType): - if not semanal_api.final_iteration: - semanal_api.defer() - return - - arguments = [] - bound_return_type = semanal_api.anal_type(method_type.ret_type, - allow_placeholder=True) - assert bound_return_type is not None - - if isinstance(bound_return_type, PlaceholderNode): - return - - for arg_name, arg_type, original_argument in zip(method_type.arg_names[1:], - method_type.arg_types[1:], - method_node.arguments[1:]): - bound_arg_type = semanal_api.anal_type(arg_type, allow_placeholder=True) - assert bound_arg_type is not None - - if isinstance(bound_arg_type, PlaceholderNode): - return - - var = Var(name=original_argument.variable.name, - type=arg_type) - var.line = original_argument.variable.line - var.column = original_argument.variable.column - argument = Argument(variable=var, - type_annotation=bound_arg_type, - initializer=original_argument.initializer, - kind=original_argument.kind) - argument.set_line(original_argument) - arguments.append(argument) - - add_method(ctx, - new_method_name, - args=arguments, - return_type=bound_return_type, - self_type=self_type) +def new_typeinfo(name: str, + *, + bases: List[Instance], + module_name: str) -> TypeInfo: + """ + Construct new TypeInfo instance. Cannot be used for nested classes. + """ + class_def = ClassDef(name, Block([])) + class_def.fullname = module_name + '.' + name + + info = TypeInfo(SymbolTable(), class_def, module_name) + info.bases = bases + calculate_mro(info) + info.metaclass_type = info.calculate_metaclass_type() + + class_def.info = info + return info diff --git a/mypy_django_plugin/lib/sem_helpers.py b/mypy_django_plugin/lib/sem_helpers.py new file mode 100644 index 000000000..139d604b9 --- /dev/null +++ b/mypy_django_plugin/lib/sem_helpers.py @@ -0,0 +1,120 @@ +from typing import List, NamedTuple, Optional, Tuple, Union, cast + +from mypy.nodes import Argument, FuncDef, TypeInfo, Var +from mypy.plugin import ClassDefContext, DynamicClassDefContext +from mypy.plugins.common import add_method +from mypy.semanal import SemanticAnalyzer +from mypy.types import AnyType, CallableType, Instance, PlaceholderType +from mypy.types import Type as MypyType +from mypy.types import TypeOfAny, get_proper_type + + +class IncompleteDefnException(Exception): + def __init__(self, error_message: str = '') -> None: + super().__init__(error_message) + + +class BoundNameNotFound(IncompleteDefnException): + def __init__(self, fullname: str) -> None: + super().__init__(f'No {fullname!r} found') + + +def get_semanal_api(ctx: Union[ClassDefContext, DynamicClassDefContext]) -> SemanticAnalyzer: + return cast(SemanticAnalyzer, ctx.api) + + +def get_nested_meta_node_for_current_class(info: TypeInfo) -> Optional[TypeInfo]: + metaclass_sym = info.names.get('Meta') + if metaclass_sym is not None and isinstance(metaclass_sym.node, TypeInfo): + return metaclass_sym.node + return None + + +def prepare_unannotated_method_signature(method_node: FuncDef) -> Tuple[List[Argument], MypyType]: + prepared_arguments = [] + for argument in method_node.arguments[1:]: + argument.type_annotation = AnyType(TypeOfAny.unannotated) + prepared_arguments.append(argument) + return_type = AnyType(TypeOfAny.unannotated) + return prepared_arguments, return_type + + +class SignatureTuple(NamedTuple): + arguments: List[Argument] + return_type: Optional[MypyType] + cannot_be_bound: bool + + +def analyze_callable_signature(api: SemanticAnalyzer, method_node: FuncDef) -> SignatureTuple: + method_type = method_node.type + assert isinstance(method_type, CallableType) + + arguments = [] + unbound = False + for arg_name, arg_type, original_argument in zip(method_type.arg_names[1:], + method_type.arg_types[1:], + method_node.arguments[1:]): + analyzed_arg_type = api.anal_type(get_proper_type(arg_type), allow_placeholder=True) + assert analyzed_arg_type is not None + if isinstance(analyzed_arg_type, PlaceholderType): + unbound = True + + var = Var(name=original_argument.variable.name, + type=analyzed_arg_type) + var.set_line(original_argument.variable) + + argument = Argument(variable=var, + type_annotation=analyzed_arg_type, + initializer=original_argument.initializer, + kind=original_argument.kind) + argument.set_line(original_argument) + arguments.append(argument) + + analyzed_ret_type = api.anal_type(get_proper_type(method_type.ret_type), allow_placeholder=True) + assert analyzed_ret_type is not None + if isinstance(analyzed_ret_type, PlaceholderType): + unbound = True + return SignatureTuple(arguments, analyzed_ret_type, unbound) + + +def copy_method_or_incomplete_defn_exception(ctx: ClassDefContext, + self_type: Instance, + new_method_name: str, + method_node: FuncDef) -> None: + semanal_api = get_semanal_api(ctx) + + if method_node.type is None: + if not semanal_api.final_iteration: + raise IncompleteDefnException(f'Unannotated method {method_node.fullname!r}') + + arguments, return_type = prepare_unannotated_method_signature(method_node) + add_method(ctx, + new_method_name, + args=arguments, + return_type=return_type, + self_type=self_type) + return + + assert isinstance(method_node.type, CallableType) + + # copy global SymbolTableNode objects from original class to the current node, if not present + original_module = semanal_api.modules[method_node.info.module_name] + for name, sym in original_module.names.items(): + if (not sym.plugin_generated + and name not in semanal_api.cur_mod_node.names): + semanal_api.add_imported_symbol(name, sym, context=semanal_api.cur_mod_node) + + arguments, analyzed_return_type, unbound = analyze_callable_signature(semanal_api, method_node) + assert len(arguments) + 1 == len(method_node.arguments) + if unbound: + raise IncompleteDefnException(f'Signature of method {method_node.fullname!r} is not ready') + + assert analyzed_return_type is not None + + if new_method_name in ctx.cls.info.names: + del ctx.cls.info.names[new_method_name] + add_method(ctx, + new_method_name, + args=arguments, + return_type=analyzed_return_type, + self_type=self_type) diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 30ac0e0d3..93017d909 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -18,7 +18,8 @@ fields, forms, init_create, meta, querysets, request, settings, ) from mypy_django_plugin.transformers.managers import ( - create_new_manager_class_from_from_queryset_method, + create_manager_class_from_as_manager_method, create_new_manager_class_from_from_queryset_method, + instantiate_anonymous_queryset_from_as_manager, ) from mypy_django_plugin.transformers.models import process_model_class @@ -123,6 +124,10 @@ def _new_dependency(self, module: str) -> Tuple[int, str, int]: return 10, module, -1 def get_additional_deps(self, file: MypyFile) -> List[Tuple[int, str, int]]: + # load QuerySet and Manager together (for as_manager) + if file.fullname == 'django.db.models.query': + return [self._new_dependency('django.db.models.manager')] + # for settings if file.fullname == 'django.conf' and self.django_context.django_settings_module: return [self._new_dependency(self.django_context.django_settings_module)] @@ -180,7 +185,7 @@ def get_function_hook(self, fullname: str if info.has_base(fullnames.FIELD_FULLNAME): return partial(fields.transform_into_proper_return_type, django_context=self.django_context) - if helpers.is_model_subclass_info(info, self.django_context): + if helpers.is_subclass_of_model(info, self.django_context): return partial(init_create.redefine_and_typecheck_model_init, django_context=self.django_context) return None @@ -212,6 +217,11 @@ def get_method_hook(self, fullname: str if info and info.has_base(fullnames.OPTIONS_CLASS_FULLNAME): return partial(meta.return_proper_field_type_from_get_field, django_context=self.django_context) + if method_name == 'as_manager': + info = self._get_typeinfo_or_none(class_fullname) + if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME): + return instantiate_anonymous_queryset_from_as_manager + manager_classes = self._get_current_manager_bases() if class_fullname in manager_classes and method_name == 'create': return partial(init_create.redefine_and_typecheck_model_create, django_context=self.django_context) @@ -252,6 +262,12 @@ def get_dynamic_class_hook(self, fullname: str info = self._get_typeinfo_or_none(class_name) if info and info.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME): return create_new_manager_class_from_from_queryset_method + if fullname.endswith('as_manager'): + class_name, _, _ = fullname.rpartition('.') + info = self._get_typeinfo_or_none(class_name) + if info and info.has_base(fullnames.QUERYSET_CLASS_FULLNAME): + return create_manager_class_from_as_manager_method + return None diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index b88fdbf59..d2082b54e 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -9,13 +9,13 @@ from mypy.types import TypeOfAny from mypy_django_plugin.django.context import DjangoContext -from mypy_django_plugin.lib import fullnames, helpers +from mypy_django_plugin.lib import chk_helpers, fullnames, helpers def _get_current_field_from_assignment(ctx: FunctionContext, django_context: DjangoContext) -> Optional[Field]: - outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class() + outer_model_info = chk_helpers.get_typechecker_api(ctx).scope.active_class() if (outer_model_info is None - or not helpers.is_model_subclass_info(outer_model_info, django_context)): + or not helpers.is_subclass_of_model(outer_model_info, django_context)): return None field_name = None @@ -66,21 +66,21 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context # __get__/__set__ of ForeignKey of derived model for model_cls in django_context.all_registered_model_classes: if issubclass(model_cls, current_model_cls) and not model_cls._meta.abstract: - derived_model_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), model_cls) + derived_model_info = helpers.lookup_class_typeinfo(chk_helpers.get_typechecker_api(ctx), model_cls) if derived_model_info is not None: fk_ref_type = Instance(derived_model_info, []) derived_fk_type = reparametrize_related_field_type(default_related_field_type, set_type=fk_ref_type, get_type=fk_ref_type) - helpers.add_new_sym_for_info(derived_model_info, - name=current_field.name, - sym_type=derived_fk_type) + chk_helpers.add_new_sym_for_info(derived_model_info, + name=current_field.name, + sym_type=derived_fk_type) related_model = related_model_cls related_model_to_set = related_model_cls if related_model_to_set._meta.proxy_for_model is not None: related_model_to_set = related_model_to_set._meta.proxy_for_model - typechecker_api = helpers.get_typechecker_api(ctx) + typechecker_api = chk_helpers.get_typechecker_api(ctx) related_model_info = helpers.lookup_class_typeinfo(typechecker_api, related_model) if related_model_info is None: @@ -114,7 +114,7 @@ def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance: default_return_type = cast(Instance, ctx.default_return_type) is_nullable = False - null_expr = helpers.get_call_argument_by_name(ctx, 'null') + null_expr = chk_helpers.get_call_argument_by_name(ctx, 'null') if null_expr is not None: is_nullable = helpers.parse_bool(null_expr) or False @@ -122,10 +122,10 @@ def set_descriptor_types_for_field(ctx: FunctionContext) -> Instance: return helpers.reparametrize_instance(default_return_type, [set_type, get_type]) -def determine_type_of_array_field(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: +def determine_type_of_array_field(ctx: FunctionContext) -> MypyType: default_return_type = set_descriptor_types_for_field(ctx) - base_field_arg_type = helpers.get_call_argument_type_by_name(ctx, 'base_field') + base_field_arg_type = chk_helpers.get_call_argument_type_by_name(ctx, 'base_field') if not base_field_arg_type or not isinstance(base_field_arg_type, Instance): return default_return_type @@ -141,9 +141,9 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan default_return_type = ctx.default_return_type assert isinstance(default_return_type, Instance) - outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class() + outer_model_info = chk_helpers.get_typechecker_api(ctx).scope.active_class() if (outer_model_info is None - or not helpers.is_model_subclass_info(outer_model_info, django_context)): + or not helpers.is_subclass_of_model(outer_model_info, django_context)): return ctx.default_return_type assert isinstance(outer_model_info, TypeInfo) @@ -152,6 +152,6 @@ def transform_into_proper_return_type(ctx: FunctionContext, django_context: Djan return fill_descriptor_types_for_related_field(ctx, django_context) if default_return_type.type.has_base(fullnames.ARRAY_FIELD_FULLNAME): - return determine_type_of_array_field(ctx, django_context) + return determine_type_of_array_field(ctx) return set_descriptor_types_for_field(ctx) diff --git a/mypy_django_plugin/transformers/forms.py b/mypy_django_plugin/transformers/forms.py index 7bd0e1116..6f3741ff5 100644 --- a/mypy_django_plugin/transformers/forms.py +++ b/mypy_django_plugin/transformers/forms.py @@ -5,11 +5,11 @@ from mypy.types import Type as MypyType from mypy.types import TypeType -from mypy_django_plugin.lib import helpers +from mypy_django_plugin.lib import chk_helpers, sem_helpers def make_meta_nested_class_inherit_from_any(ctx: ClassDefContext) -> None: - meta_node = helpers.get_nested_meta_node_for_current_class(ctx.cls.info) + meta_node = sem_helpers.get_nested_meta_node_for_current_class(ctx.cls.info) if meta_node is None: if not ctx.api.final_iteration: ctx.api.defer() @@ -28,7 +28,7 @@ def extract_proper_type_for_get_form(ctx: MethodContext) -> MypyType: object_type = ctx.type assert isinstance(object_type, Instance) - form_class_type = helpers.get_call_argument_type_by_name(ctx, 'form_class') + form_class_type = chk_helpers.get_call_argument_type_by_name(ctx, 'form_class') if form_class_type is None or isinstance(form_class_type, NoneTyp): form_class_type = get_specified_form_class(object_type) diff --git a/mypy_django_plugin/transformers/init_create.py b/mypy_django_plugin/transformers/init_create.py index fe0b19ee2..549f826ba 100644 --- a/mypy_django_plugin/transformers/init_create.py +++ b/mypy_django_plugin/transformers/init_create.py @@ -6,7 +6,7 @@ from mypy.types import Type as MypyType from mypy_django_plugin.django.context import DjangoContext -from mypy_django_plugin.lib import helpers +from mypy_django_plugin.lib import chk_helpers def get_actual_types(ctx: Union[MethodContext, FunctionContext], @@ -32,7 +32,7 @@ def get_actual_types(ctx: Union[MethodContext, FunctionContext], def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_context: DjangoContext, model_cls: Type[Model], method: str) -> MypyType: - typechecker_api = helpers.get_typechecker_api(ctx) + typechecker_api = chk_helpers.get_typechecker_api(ctx) expected_types = django_context.get_expected_types(typechecker_api, model_cls, method=method) expected_keys = [key for key in expected_types.keys() if key != 'pk'] @@ -42,11 +42,11 @@ def typecheck_model_method(ctx: Union[FunctionContext, MethodContext], django_co model_cls.__name__), ctx.context) continue - helpers.check_types_compatible(ctx, - expected_type=expected_types[actual_name], - actual_type=actual_type, - error_message='Incompatible type for "{}" of "{}"'.format(actual_name, - model_cls.__name__)) + error_message = 'Incompatible type for "{}" of "{}"'.format(actual_name, model_cls.__name__) + chk_helpers.check_types_compatible(ctx, + expected_type=expected_types[actual_name], + actual_type=actual_type, + error_message=error_message) return ctx.default_return_type diff --git a/mypy_django_plugin/transformers/managers.py b/mypy_django_plugin/transformers/managers.py index 88201b439..78d223088 100644 --- a/mypy_django_plugin/transformers/managers.py +++ b/mypy_django_plugin/transformers/managers.py @@ -1,77 +1,311 @@ +from typing import Any, Dict, Iterator, Optional, Tuple + from mypy.nodes import ( - GDEF, FuncDef, MemberExpr, NameExpr, RefExpr, StrExpr, SymbolTableNode, TypeInfo, + GDEF, CallExpr, Context, Decorator, FuncDef, MemberExpr, NameExpr, OverloadedFuncDef, PlaceholderNode, RefExpr, + StrExpr, SymbolTable, SymbolTableNode, TypeInfo, ) -from mypy.plugin import ClassDefContext, DynamicClassDefContext -from mypy.types import AnyType, Instance, TypeOfAny +from mypy.plugin import ClassDefContext, DynamicClassDefContext, MethodContext +from mypy.semanal import SemanticAnalyzer, is_same_symbol, is_valid_replacement +from mypy.types import AnyType, CallableType, Instance +from mypy.types import Type as MypyType +from mypy.types import TypeOfAny -from mypy_django_plugin.lib import fullnames, helpers +from mypy_django_plugin.lib import chk_helpers, fullnames, helpers, sem_helpers -def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefContext) -> None: - semanal_api = helpers.get_semanal_api(ctx) +def iter_all_custom_queryset_methods(derived_queryset_info: TypeInfo) -> Iterator[Tuple[str, FuncDef]]: + for base_queryset_info in derived_queryset_info.mro: + if base_queryset_info.fullname == fullnames.QUERYSET_CLASS_FULLNAME: + break + for name, sym in base_queryset_info.names.items(): + if isinstance(sym.node, FuncDef): + yield name, sym.node + +def generate_from_queryset_name(base_manager_info: TypeInfo, queryset_info: TypeInfo) -> str: + return base_manager_info.name + 'From' + queryset_info.name + + +def resolve_callee_info_or_exception(ctx: DynamicClassDefContext) -> TypeInfo: callee = ctx.call.callee assert isinstance(callee, MemberExpr) assert isinstance(callee.expr, RefExpr) - base_manager_info = callee.expr.node - if base_manager_info is None: + callee_info = callee.expr.node + if (callee_info is None + or isinstance(callee_info, PlaceholderNode)): + raise sem_helpers.IncompleteDefnException(f'Definition of base manager {callee.fullname!r} ' + f'is incomplete.') + + assert isinstance(callee_info, TypeInfo) + return callee_info + + +def resolve_passed_queryset_info_or_exception(ctx: DynamicClassDefContext) -> TypeInfo: + api = sem_helpers.get_semanal_api(ctx) + + passed_queryset_name_expr = ctx.call.args[0] + assert isinstance(passed_queryset_name_expr, NameExpr) + + sym = api.lookup_qualified(passed_queryset_name_expr.name, ctx=ctx.call) + if (sym is None + or sym.node is None + or isinstance(sym.node, PlaceholderNode)): + bound_name = passed_queryset_name_expr.fullname or passed_queryset_name_expr.name + raise sem_helpers.BoundNameNotFound(bound_name) + + assert isinstance(sym.node, TypeInfo) + return sym.node + + +def resolve_django_manager_info_or_exception(ctx: DynamicClassDefContext) -> TypeInfo: + api = sem_helpers.get_semanal_api(ctx) + info = helpers.lookup_fully_qualified_typeinfo(api, fullnames.MANAGER_CLASS_FULLNAME) + if info is None: + raise sem_helpers.BoundNameNotFound(fullnames.MANAGER_CLASS_FULLNAME) + + return info + + +def new_manager_typeinfo(ctx: DynamicClassDefContext, callee_manager_info: TypeInfo) -> TypeInfo: + callee_manager_type = Instance(callee_manager_info, [AnyType(TypeOfAny.unannotated)]) + api = sem_helpers.get_semanal_api(ctx) + + new_manager_class_name = ctx.name + new_manager_info = helpers.new_typeinfo(new_manager_class_name, + bases=[callee_manager_type], module_name=api.cur_mod_id) + new_manager_info.set_line(ctx.call) + return new_manager_info + + +def get_generated_manager_fullname(call: CallExpr, base_manager_info: TypeInfo, queryset_info: TypeInfo) -> str: + if len(call.args) > 1: + # only for from_queryset() + expr = call.args[1] + assert isinstance(expr, StrExpr) + custom_manager_generated_name = expr.value + else: + custom_manager_generated_name = base_manager_info.name + 'From' + queryset_info.name + + custom_manager_generated_fullname = 'django.db.models.manager' + '.' + custom_manager_generated_name + return custom_manager_generated_fullname + + +def get_generated_managers_metadata(django_manager_info: TypeInfo) -> Dict[str, Any]: + return django_manager_info.metadata.setdefault('from_queryset_managers', {}) + + +def record_new_manager_info_fullname_into_metadata(ctx: DynamicClassDefContext, + new_manager_fullname: str, + callee_manager_info: TypeInfo, + queryset_info: TypeInfo, + django_manager_info: TypeInfo) -> None: + custom_manager_generated_fullname = get_generated_manager_fullname(ctx.call, + base_manager_info=callee_manager_info, + queryset_info=queryset_info) + metadata = get_generated_managers_metadata(django_manager_info) + metadata[custom_manager_generated_fullname] = new_manager_fullname + + +def create_new_manager_class_from_from_queryset_method(ctx: DynamicClassDefContext) -> None: + semanal_api = sem_helpers.get_semanal_api(ctx) + try: + callee_manager_info = resolve_callee_info_or_exception(ctx) + queryset_info = resolve_passed_queryset_info_or_exception(ctx) + django_manager_info = resolve_django_manager_info_or_exception(ctx) + except sem_helpers.IncompleteDefnException: if not semanal_api.final_iteration: semanal_api.defer() - return - - assert isinstance(base_manager_info, TypeInfo) - new_manager_info = semanal_api.basic_new_typeinfo(ctx.name, - basetype_or_fallback=Instance(base_manager_info, - [AnyType(TypeOfAny.unannotated)])) - new_manager_info.line = ctx.call.line - new_manager_info.defn.line = ctx.call.line - new_manager_info.metaclass_type = new_manager_info.calculate_metaclass_type() - - current_module = semanal_api.cur_mod_node - current_module.names[ctx.name] = SymbolTableNode(GDEF, new_manager_info, - plugin_generated=True) - passed_queryset = ctx.call.args[0] - assert isinstance(passed_queryset, NameExpr) - - derived_queryset_fullname = passed_queryset.fullname - assert derived_queryset_fullname is not None - - sym = semanal_api.lookup_fully_qualified_or_none(derived_queryset_fullname) - assert sym is not None - if sym.node is None: + return + else: + raise + + new_manager_info = new_manager_typeinfo(ctx, callee_manager_info) + record_new_manager_info_fullname_into_metadata(ctx, + new_manager_info.fullname, + callee_manager_info, + queryset_info, + django_manager_info) + + class_def_context = ClassDefContext(cls=new_manager_info.defn, + reason=ctx.call, api=semanal_api) + self_type = Instance(new_manager_info, [AnyType(TypeOfAny.explicit)]) + + try: + for name, method_node in iter_all_custom_queryset_methods(queryset_info): + sem_helpers.copy_method_or_incomplete_defn_exception(class_def_context, + self_type, + new_method_name=name, + method_node=method_node) + except sem_helpers.IncompleteDefnException: if not semanal_api.final_iteration: semanal_api.defer() + return else: - # inherit from Any to prevent false-positives, if queryset class cannot be resolved - new_manager_info.fallback_to_any = True - return + raise - derived_queryset_info = sym.node - assert isinstance(derived_queryset_info, TypeInfo) + new_manager_sym = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True) - if len(ctx.call.args) > 1: - expr = ctx.call.args[1] - assert isinstance(expr, StrExpr) - custom_manager_generated_name = expr.value - else: - custom_manager_generated_name = base_manager_info.name + 'From' + derived_queryset_info.name + # context=None - forcibly replace old node + added = semanal_api.add_symbol_table_node(ctx.name, new_manager_sym, context=None) + if added: + # replace all references to the old manager Var everywhere + for _, module in semanal_api.modules.items(): + if module.fullname != semanal_api.cur_mod_id: + for sym_name, sym in module.names.items(): + if sym.fullname == new_manager_info.fullname: + module.names[sym_name] = new_manager_sym.copy() + + # we need another iteration to process methods + if (not added + and not semanal_api.final_iteration): + semanal_api.defer() + + +def add_symbol_table_node(api: SemanticAnalyzer, + name: str, + symbol: SymbolTableNode, + context: Optional[Context] = None, + symbol_table: Optional[SymbolTable] = None, + can_defer: bool = True, + escape_comprehensions: bool = False) -> bool: + """Add symbol table node to the currently active symbol table. + + Return True if we actually added the symbol, or False if we refused + to do so (because something is not ready or it was a no-op). + + Generate an error if there is an invalid redefinition. + + If context is None, unconditionally add node, since we can't report + an error. Note that this is used by plugins to forcibly replace nodes! + + TODO: Prevent plugins from replacing nodes, as it could cause problems? + + Args: + name: short name of symbol + symbol: Node to add + can_defer: if True, defer current target if adding a placeholder + context: error context (see above about None value) + """ + names = symbol_table or api.current_symbol_table(escape_comprehensions=escape_comprehensions) + existing = names.get(name) + if isinstance(symbol.node, PlaceholderNode) and can_defer: + api.defer(context) + if (existing is not None + and context is not None + and not is_valid_replacement(existing, symbol)): + # There is an existing node, so this may be a redefinition. + # If the new node points to the same node as the old one, + # or if both old and new nodes are placeholders, we don't + # need to do anything. + old = existing.node + new = symbol.node + if isinstance(new, PlaceholderNode): + # We don't know whether this is okay. Let's wait until the next iteration. + return False + if not is_same_symbol(old, new): + if isinstance(new, (FuncDef, Decorator, OverloadedFuncDef, TypeInfo)): + api.add_redefinition(names, name, symbol) + if not (isinstance(new, (FuncDef, Decorator)) + and api.set_original_def(old, new)): + api.name_already_defined(name, context, existing) + elif name not in api.missing_names and '*' not in api.missing_names: + names[name] = symbol + api.progress = True + return True + return False + + +def create_manager_class_from_as_manager_method(ctx: DynamicClassDefContext) -> None: + semanal_api = sem_helpers.get_semanal_api(ctx) + try: + queryset_info = resolve_callee_info_or_exception(ctx) + django_manager_info = resolve_django_manager_info_or_exception(ctx) + except sem_helpers.IncompleteDefnException: + if not semanal_api.final_iteration: + semanal_api.defer() + return + else: + raise + + generic_param: MypyType = AnyType(TypeOfAny.explicit) + generic_param_name = 'Any' + if (semanal_api.scope.classes + and semanal_api.scope.classes[-1].has_base(fullnames.MODEL_CLASS_FULLNAME)): + info = semanal_api.scope.classes[-1] # type: TypeInfo + generic_param = Instance(info, []) + generic_param_name = info.name - custom_manager_generated_fullname = '.'.join(['django.db.models.manager', custom_manager_generated_name]) - if 'from_queryset_managers' not in base_manager_info.metadata: - base_manager_info.metadata['from_queryset_managers'] = {} - base_manager_info.metadata['from_queryset_managers'][custom_manager_generated_fullname] = new_manager_info.fullname + new_manager_class_name = queryset_info.name + '_AsManager_' + generic_param_name + new_manager_info = helpers.new_typeinfo(new_manager_class_name, + bases=[Instance(django_manager_info, [generic_param])], + module_name=semanal_api.cur_mod_id) + new_manager_info.set_line(ctx.call) + + record_new_manager_info_fullname_into_metadata(ctx, + new_manager_info.fullname, + django_manager_info, + queryset_info, + django_manager_info) class_def_context = ClassDefContext(cls=new_manager_info.defn, reason=ctx.call, api=semanal_api) - self_type = Instance(new_manager_info, []) - # we need to copy all methods in MRO before django.db.models.query.QuerySet - for class_mro_info in derived_queryset_info.mro: - if class_mro_info.fullname == fullnames.QUERYSET_CLASS_FULLNAME: - break - for name, sym in class_mro_info.names.items(): - if isinstance(sym.node, FuncDef): - helpers.copy_method_to_another_class(class_def_context, - self_type, - new_method_name=name, - method_node=sym.node) + self_type = Instance(new_manager_info, [AnyType(TypeOfAny.explicit)]) + + try: + for name, method_node in iter_all_custom_queryset_methods(queryset_info): + sem_helpers.copy_method_or_incomplete_defn_exception(class_def_context, + self_type, + new_method_name=name, + method_node=method_node) + except sem_helpers.IncompleteDefnException: + if not semanal_api.final_iteration: + semanal_api.defer() + return + else: + raise + + new_manager_sym = SymbolTableNode(GDEF, new_manager_info, plugin_generated=True) + + # context=None - forcibly replace old node + added = add_symbol_table_node(semanal_api, new_manager_class_name, new_manager_sym, + context=None, + symbol_table=semanal_api.globals) + if added: + # replace all references to the old manager Var everywhere + for _, module in semanal_api.modules.items(): + if module.fullname != semanal_api.cur_mod_id: + for sym_name, sym in module.names.items(): + if sym.fullname == new_manager_info.fullname: + module.names[sym_name] = new_manager_sym.copy() + + # we need another iteration to process methods + if (not added + and not semanal_api.final_iteration): + semanal_api.defer() + + +def instantiate_anonymous_queryset_from_as_manager(ctx: MethodContext) -> MypyType: + api = chk_helpers.get_typechecker_api(ctx) + django_manager_info = helpers.lookup_fully_qualified_typeinfo(api, fullnames.MANAGER_CLASS_FULLNAME) + assert django_manager_info is not None + + assert isinstance(ctx.type, CallableType) + assert isinstance(ctx.type.ret_type, Instance) + queryset_info = ctx.type.ret_type.type + + gen_name = django_manager_info.name + 'From' + queryset_info.name + gen_fullname = 'django.db.models.manager' + '.' + gen_name + + metadata = get_generated_managers_metadata(django_manager_info) + if gen_fullname not in metadata: + raise ValueError(f'{gen_fullname!r} is not present in generated managers list') + + module_name, _, class_name = metadata[gen_fullname].rpartition('.') + current_module = helpers.get_current_module(api) + assert module_name == current_module.fullname + + generated_manager_info = current_module.names[class_name].node + assert isinstance(generated_manager_info, TypeInfo) + + return Instance(generated_manager_info, []) diff --git a/mypy_django_plugin/transformers/meta.py b/mypy_django_plugin/transformers/meta.py index 64e6e12fe..0af47176b 100644 --- a/mypy_django_plugin/transformers/meta.py +++ b/mypy_django_plugin/transformers/meta.py @@ -5,12 +5,12 @@ from mypy.types import TypeOfAny from mypy_django_plugin.django.context import DjangoContext -from mypy_django_plugin.lib import helpers +from mypy_django_plugin.lib import chk_helpers, helpers def _get_field_instance(ctx: MethodContext, field_fullname: str) -> MypyType: - field_info = helpers.lookup_fully_qualified_typeinfo(helpers.get_typechecker_api(ctx), - field_fullname) + api = chk_helpers.get_typechecker_api(ctx) + field_info = helpers.lookup_fully_qualified_typeinfo(api, field_fullname) if field_info is None: return AnyType(TypeOfAny.unannotated) return Instance(field_info, [AnyType(TypeOfAny.explicit), AnyType(TypeOfAny.explicit)]) @@ -32,7 +32,7 @@ def return_proper_field_type_from_get_field(ctx: MethodContext, django_context: if model_cls is None: return ctx.default_return_type - field_name_expr = helpers.get_call_argument_by_name(ctx, 'field_name') + field_name_expr = chk_helpers.get_call_argument_by_name(ctx, 'field_name') if field_name_expr is None: return ctx.default_return_type diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index f0c436ca0..6e22aae97 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -1,21 +1,23 @@ -from typing import Dict, List, Optional, Type, cast +from typing import List, Optional, Type, cast from django.db.models.base import Model from django.db.models.fields import DateField, DateTimeField -from django.db.models.fields.related import ForeignKey +from django.db.models.fields.related import ForeignKey, OneToOneField from django.db.models.fields.reverse_related import ( ManyToManyRel, ManyToOneRel, OneToOneRel, ) -from mypy.nodes import ARG_STAR2, Argument, Context, FuncDef, TypeInfo, Var +from mypy.nodes import ( + ARG_STAR2, GDEF, MDEF, Argument, Context, FuncDef, SymbolTableNode, TypeInfo, Var, +) from mypy.plugin import ClassDefContext from mypy.plugins import common -from mypy.semanal import SemanticAnalyzer +from mypy.semanal import SemanticAnalyzer, dummy_context from mypy.types import AnyType, Instance from mypy.types import Type as MypyType from mypy.types import TypeOfAny from mypy_django_plugin.django.context import DjangoContext -from mypy_django_plugin.lib import fullnames, helpers +from mypy_django_plugin.lib import fullnames, helpers, sem_helpers from mypy_django_plugin.transformers import fields from mypy_django_plugin.transformers.fields import get_field_descriptor_types @@ -35,7 +37,7 @@ def lookup_typeinfo(self, fullname: str) -> Optional[TypeInfo]: def lookup_typeinfo_or_incomplete_defn_error(self, fullname: str) -> TypeInfo: info = self.lookup_typeinfo(fullname) if info is None: - raise helpers.IncompleteDefnException(f'No {fullname!r} found') + raise sem_helpers.IncompleteDefnException(f'No {fullname!r} found') return info def lookup_class_typeinfo_or_incomplete_defn_error(self, klass: type) -> TypeInfo: @@ -43,26 +45,52 @@ def lookup_class_typeinfo_or_incomplete_defn_error(self, klass: type) -> TypeInf field_info = self.lookup_typeinfo_or_incomplete_defn_error(fullname) return field_info - def create_new_var(self, name: str, typ: MypyType) -> Var: - # type=: type of the variable itself - var = Var(name=name, type=typ) - # var.info: type of the object variable is bound to + def model_class_has_attribute_defined(self, name: str, traverse_mro: bool = True) -> bool: + if not traverse_mro: + sym = self.model_classdef.info.names.get(name) + else: + sym = self.model_classdef.info.get(name) + return sym is not None + + def resolve_manager_fullname(self, manager_fullname: str) -> str: + base_manager_info = self.lookup_typeinfo(fullnames.MANAGER_CLASS_FULLNAME) + if (base_manager_info is None + or 'from_queryset_managers' not in base_manager_info.metadata): + return manager_fullname + + metadata = base_manager_info.metadata['from_queryset_managers'] + return metadata.get(manager_fullname, manager_fullname) + + def add_new_node_to_model_class(self, name: str, typ: MypyType, + force_replace_existing: bool = False) -> None: + if not force_replace_existing and name in self.model_classdef.info.names: + raise ValueError(f'Member {name!r} already defined at model {self.model_classdef.info.fullname!r}.') + + var = Var(name, type=typ) + # TypeInfo of the object variable is bound to var.info = self.model_classdef.info - var._fullname = self.model_classdef.info.fullname + '.' + name + var._fullname = self.api.qualified_name(name) var.is_initialized_in_class = True - var.is_inferred = True - return var - - def add_new_node_to_model_class(self, name: str, typ: MypyType) -> None: - helpers.add_new_sym_for_info(self.model_classdef.info, - name=name, - sym_type=typ) - def add_new_class_for_current_module(self, name: str, bases: List[Instance]) -> TypeInfo: - current_module = self.api.modules[self.model_classdef.info.module_name] - new_class_info = helpers.add_new_class_for_module(current_module, - name=name, bases=bases) - return new_class_info + sym = SymbolTableNode(MDEF, var, plugin_generated=True) + context: Optional[Context] = dummy_context() + if force_replace_existing: + context = None + self.api.add_symbol_table_node(name, sym, context=context) + + def add_new_class_for_current_module(self, name: str, bases: List[Instance], + force_replace_existing: bool = False) -> TypeInfo: + current_module = self.api.cur_mod_node + if not force_replace_existing and name in current_module.names: + raise ValueError(f'Class {name!r} already defined for module {current_module.fullname!r}') + + new_typeinfo = helpers.new_typeinfo(name, + bases=bases, + module_name=current_module.fullname) + if name in current_module.names: + del current_module.names[name] + current_module.names[name] = SymbolTableNode(GDEF, new_typeinfo, plugin_generated=True) + return new_typeinfo def run(self) -> None: model_cls = self.django_context.get_model_class_by_fullname(self.model_classdef.fullname) @@ -88,58 +116,90 @@ class Meta(Any): """ def run(self) -> None: - meta_node = helpers.get_nested_meta_node_for_current_class(self.model_classdef.info) + meta_node = sem_helpers.get_nested_meta_node_for_current_class(self.model_classdef.info) if meta_node is None: return None meta_node.fallback_to_any = True class AddDefaultPrimaryKey(ModelClassInitializer): + """ + Adds default primary key to models which does not define their own. + ``` + class User(models.Model): + name = models.TextField() + ``` + """ + def run_with_model_cls(self, model_cls: Type[Model]) -> None: auto_field = model_cls._meta.auto_field - if auto_field and not self.model_classdef.info.has_readable_member(auto_field.attname): - # autogenerated field - auto_field_fullname = helpers.get_class_fullname(auto_field.__class__) - auto_field_info = self.lookup_typeinfo_or_incomplete_defn_error(auto_field_fullname) + if auto_field is None: + return - set_type, get_type = fields.get_field_descriptor_types(auto_field_info, is_nullable=False) - self.add_new_node_to_model_class(auto_field.attname, Instance(auto_field_info, - [set_type, get_type])) + primary_key_attrname = auto_field.attname + if self.model_class_has_attribute_defined(primary_key_attrname): + return + + auto_field_class_fullname = helpers.get_class_fullname(auto_field.__class__) + auto_field_info = self.lookup_typeinfo_or_incomplete_defn_error(auto_field_class_fullname) + + set_type, get_type = fields.get_field_descriptor_types(auto_field_info, is_nullable=False) + self.add_new_node_to_model_class(primary_key_attrname, Instance(auto_field_info, + [set_type, get_type])) class AddRelatedModelsId(ModelClassInitializer): + """ + Adds `FIELDNAME_id` attributes to models. + ``` + class User(models.Model): + pass + class Blog(models.Model): + user = models.ForeignKey(User) + ``` + + `user_id` will be added to `Blog`. + """ + def run_with_model_cls(self, model_cls: Type[Model]) -> None: for field in model_cls._meta.get_fields(): - if isinstance(field, ForeignKey): - related_model_cls = self.django_context.get_field_related_model_cls(field) - if related_model_cls is None: - error_context: Context = self.ctx.cls - field_sym = self.ctx.cls.info.get(field.name) - if field_sym is not None and field_sym.node is not None: - error_context = field_sym.node - self.api.fail(f'Cannot find model {field.related_model!r} ' - f'referenced in field {field.name!r} ', - ctx=error_context) - self.add_new_node_to_model_class(field.attname, - AnyType(TypeOfAny.explicit)) - continue + if not isinstance(field, (OneToOneField, ForeignKey)): + continue + related_id_attr_name = field.attname + if self.model_class_has_attribute_defined(related_id_attr_name): + continue + # if self.get_model_class_attr(related_id_attr_name) is not None: + # continue - if related_model_cls._meta.abstract: - continue + related_model_cls = self.django_context.get_field_related_model_cls(field) + if related_model_cls is None: + error_context: Context = self.ctx.cls + field_sym = self.ctx.cls.info.get(field.name) + if field_sym is not None and field_sym.node is not None: + error_context = field_sym.node + self.api.fail(f'Cannot find model {field.related_model!r} ' + f'referenced in field {field.name!r} ', + ctx=error_context) + self.add_new_node_to_model_class(related_id_attr_name, + AnyType(TypeOfAny.explicit)) + continue - rel_primary_key_field = self.django_context.get_primary_key_field(related_model_cls) - try: - field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__) - except helpers.IncompleteDefnException as exc: - if not self.api.final_iteration: - raise exc - else: - continue + if related_model_cls._meta.abstract: + continue - is_nullable = self.django_context.get_field_nullability(field, None) - set_type, get_type = get_field_descriptor_types(field_info, is_nullable) - self.add_new_node_to_model_class(field.attname, - Instance(field_info, [set_type, get_type])) + rel_primary_key_field = self.django_context.get_primary_key_field(related_model_cls) + try: + field_info = self.lookup_class_typeinfo_or_incomplete_defn_error(rel_primary_key_field.__class__) + except sem_helpers.IncompleteDefnException as exc: + if not self.api.final_iteration: + raise exc + else: + continue + + is_nullable = self.django_context.get_field_nullability(field, None) + set_type, get_type = get_field_descriptor_types(field_info, is_nullable) + self.add_new_node_to_model_class(related_id_attr_name, + Instance(field_info, [set_type, get_type])) class AddManagers(ModelClassInitializer): @@ -152,25 +212,15 @@ def has_any_parametrized_manager_as_base(self, info: TypeInfo) -> bool: def is_any_parametrized_manager(self, typ: Instance) -> bool: return typ.type.fullname in fullnames.MANAGER_CLASSES and isinstance(typ.args[0], AnyType) - def get_generated_manager_mappings(self, base_manager_fullname: str) -> Dict[str, str]: - base_manager_info = self.lookup_typeinfo(base_manager_fullname) - if (base_manager_info is None - or 'from_queryset_managers' not in base_manager_info.metadata): - return {} - return base_manager_info.metadata['from_queryset_managers'] - def create_new_model_parametrized_manager(self, name: str, base_manager_info: TypeInfo) -> Instance: bases = [] for original_base in base_manager_info.bases: if self.is_any_parametrized_manager(original_base): - if original_base.type is None: - raise helpers.IncompleteDefnException() - original_base = helpers.reparametrize_instance(original_base, [Instance(self.model_classdef.info, [])]) bases.append(original_base) - new_manager_info = self.add_new_class_for_current_module(name, bases) + new_manager_info = self.add_new_class_for_current_module(name, bases, force_replace_existing=True) # copy fields to a new manager new_cls_def_context = ClassDefContext(cls=new_manager_info.defn, reason=self.ctx.reason, @@ -178,12 +228,15 @@ def create_new_model_parametrized_manager(self, name: str, base_manager_info: Ty custom_manager_type = Instance(new_manager_info, [Instance(self.model_classdef.info, [])]) for name, sym in base_manager_info.names.items(): + if name in new_manager_info.names: + raise ValueError(f'Name {name!r} already exists on newly-created {new_manager_info.fullname!r} class.') + # replace self type with new class, if copying method if isinstance(sym.node, FuncDef): - helpers.copy_method_to_another_class(new_cls_def_context, - self_type=custom_manager_type, - new_method_name=name, - method_node=sym.node) + sem_helpers.copy_method_or_incomplete_defn_exception(new_cls_def_context, + self_type=custom_manager_type, + new_method_name=name, + method_node=sym.node) continue new_sym = sym.copy() @@ -192,32 +245,36 @@ def create_new_model_parametrized_manager(self, name: str, base_manager_info: Ty new_var.info = new_manager_info new_var._fullname = new_manager_info.fullname + '.' + name new_sym.node = new_var + new_manager_info.names[name] = new_sym return custom_manager_type def run_with_model_cls(self, model_cls: Type[Model]) -> None: for manager_name, manager in model_cls._meta.managers_map.items(): - manager_class_name = manager.__class__.__name__ - manager_fullname = helpers.get_class_fullname(manager.__class__) - try: - manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname) - except helpers.IncompleteDefnException as exc: - if not self.api.final_iteration: - raise exc - else: - base_manager_fullname = helpers.get_class_fullname(manager.__class__.__bases__[0]) - generated_managers = self.get_generated_manager_mappings(base_manager_fullname) - if manager_fullname not in generated_managers: - # not a generated manager, continue with the loop - continue - real_manager_fullname = generated_managers[manager_fullname] - manager_info = self.lookup_typeinfo(real_manager_fullname) # type: ignore - if manager_info is None: - continue - manager_class_name = real_manager_fullname.rsplit('.', maxsplit=1)[1] + if self.model_class_has_attribute_defined(manager_name, traverse_mro=False): + sym = self.model_classdef.info.names.get(manager_name) + assert sym is not None + + if (sym.type is not None + and isinstance(sym.type, Instance) + and sym.type.type.has_base(fullnames.BASE_MANAGER_CLASS_FULLNAME) + and not self.has_any_parametrized_manager_as_base(sym.type.type)): + # already defined and parametrized properly + continue + + if getattr(manager, '_built_with_as_manager', False): + # as_manager is not supported yet + if not self.model_class_has_attribute_defined(manager_name, traverse_mro=True): + self.add_new_node_to_model_class(manager_name, AnyType(TypeOfAny.explicit)) + continue + + manager_fullname = self.resolve_manager_fullname(helpers.get_class_fullname(manager.__class__)) + manager_info = self.lookup_typeinfo_or_incomplete_defn_error(manager_fullname) + manager_class_name = manager_fullname.rsplit('.', maxsplit=1)[1] if manager_name not in self.model_classdef.info.names: + # manager not yet defined, just add models.Manager[ModelName] manager_type = Instance(manager_info, [Instance(self.model_classdef.info, [])]) self.add_new_node_to_model_class(manager_name, manager_type) else: @@ -226,56 +283,67 @@ def run_with_model_cls(self, model_cls: Type[Model]) -> None: continue custom_model_manager_name = manager.model.__name__ + '_' + manager_class_name - try: - custom_manager_type = self.create_new_model_parametrized_manager(custom_model_manager_name, - base_manager_info=manager_info) - except helpers.IncompleteDefnException: - continue + custom_manager_type = self.create_new_model_parametrized_manager(custom_model_manager_name, + base_manager_info=manager_info) - self.add_new_node_to_model_class(manager_name, custom_manager_type) + self.add_new_node_to_model_class(manager_name, custom_manager_type, + force_replace_existing=True) class AddDefaultManagerAttribute(ModelClassInitializer): def run_with_model_cls(self, model_cls: Type[Model]) -> None: - # add _default_manager - if '_default_manager' not in self.model_classdef.info.names: - default_manager_fullname = helpers.get_class_fullname(model_cls._meta.default_manager.__class__) - default_manager_info = self.lookup_typeinfo_or_incomplete_defn_error(default_manager_fullname) - default_manager = Instance(default_manager_info, [Instance(self.model_classdef.info, [])]) - self.add_new_node_to_model_class('_default_manager', default_manager) + if self.model_class_has_attribute_defined('_default_manager', traverse_mro=False): + return + if model_cls._meta.default_manager is None: + return + if getattr(model_cls._meta.default_manager, '_built_with_as_manager', False): + self.add_new_node_to_model_class('_default_manager', + AnyType(TypeOfAny.explicit)) + return + + default_manager_fullname = helpers.get_class_fullname(model_cls._meta.default_manager.__class__) + resolved_default_manager_fullname = self.resolve_manager_fullname(default_manager_fullname) + + default_manager_info = self.lookup_typeinfo_or_incomplete_defn_error(resolved_default_manager_fullname) + default_manager = Instance(default_manager_info, [Instance(self.model_classdef.info, [])]) + self.add_new_node_to_model_class('_default_manager', default_manager) class AddRelatedManagers(ModelClassInitializer): def run_with_model_cls(self, model_cls: Type[Model]) -> None: # add related managers for relation in self.django_context.get_model_relations(model_cls): - attname = relation.get_accessor_name() - if attname is None: + related_manager_attr_name = relation.get_accessor_name() + if related_manager_attr_name is None: # no reverse accessor continue + if self.model_class_has_attribute_defined(related_manager_attr_name, traverse_mro=False): + continue + related_model_cls = self.django_context.get_field_related_model_cls(relation) if related_model_cls is None: continue try: related_model_info = self.lookup_class_typeinfo_or_incomplete_defn_error(related_model_cls) - except helpers.IncompleteDefnException as exc: + except sem_helpers.IncompleteDefnException as exc: if not self.api.final_iteration: raise exc else: continue if isinstance(relation, OneToOneRel): - self.add_new_node_to_model_class(attname, Instance(related_model_info, [])) + self.add_new_node_to_model_class(related_manager_attr_name, Instance(related_model_info, [])) continue if isinstance(relation, (ManyToOneRel, ManyToManyRel)): try: - related_manager_info = self.lookup_typeinfo_or_incomplete_defn_error(fullnames.RELATED_MANAGER_CLASS) # noqa: E501 + related_manager_info = self.lookup_typeinfo_or_incomplete_defn_error( + fullnames.RELATED_MANAGER_CLASS) # noqa: E501 if 'objects' not in related_model_info.names: - raise helpers.IncompleteDefnException() - except helpers.IncompleteDefnException as exc: + raise sem_helpers.IncompleteDefnException() + except sem_helpers.IncompleteDefnException as exc: if not self.api.final_iteration: raise exc else: @@ -288,14 +356,15 @@ def run_with_model_cls(self, model_cls: Type[Model]) -> None: if (default_manager_type is None or not isinstance(default_manager_type, Instance) or default_manager_type.type.fullname == fullnames.MANAGER_CLASS_FULLNAME): - self.add_new_node_to_model_class(attname, parametrized_related_manager_type) + self.add_new_node_to_model_class(related_manager_attr_name, parametrized_related_manager_type) continue name = related_model_cls.__name__ + '_' + 'RelatedManager' bases = [parametrized_related_manager_type, default_manager_type] - new_related_manager_info = self.add_new_class_for_current_module(name, bases) - - self.add_new_node_to_model_class(attname, Instance(new_related_manager_info, [])) + new_related_manager_info = self.add_new_class_for_current_module(name, bases, + force_replace_existing=True) + self.add_new_node_to_model_class(related_manager_attr_name, + Instance(new_related_manager_info, [])) class AddExtraFieldMethods(ModelClassInitializer): @@ -355,6 +424,8 @@ def process_model_class(ctx: ClassDefContext, for initializer_cls in initializers: try: initializer_cls(ctx, django_context).run() - except helpers.IncompleteDefnException: + except sem_helpers.IncompleteDefnException as exc: if not ctx.api.final_iteration: ctx.api.defer() + continue + raise exc diff --git a/mypy_django_plugin/transformers/orm_lookups.py b/mypy_django_plugin/transformers/orm_lookups.py index 0aa516be0..1dacdc874 100644 --- a/mypy_django_plugin/transformers/orm_lookups.py +++ b/mypy_django_plugin/transformers/orm_lookups.py @@ -4,7 +4,7 @@ from mypy.types import TypeOfAny from mypy_django_plugin.django.context import DjangoContext -from mypy_django_plugin.lib import fullnames, helpers +from mypy_django_plugin.lib import chk_helpers, fullnames, helpers def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext) -> MypyType: @@ -35,10 +35,10 @@ def typecheck_queryset_filter(ctx: MethodContext, django_context: DjangoContext) fullnames.QUERYSET_CLASS_FULLNAME))): return ctx.default_return_type - helpers.check_types_compatible(ctx, - expected_type=lookup_type, - actual_type=provided_type, - error_message=f'Incompatible type for lookup {lookup_kwarg!r}:') + chk_helpers.check_types_compatible(ctx, + expected_type=lookup_type, + actual_type=provided_type, + error_message=f'Incompatible type for lookup {lookup_kwarg!r}:') return ctx.default_return_type diff --git a/mypy_django_plugin/transformers/querysets.py b/mypy_django_plugin/transformers/querysets.py index c157bb4a0..1476e105e 100644 --- a/mypy_django_plugin/transformers/querysets.py +++ b/mypy_django_plugin/transformers/querysets.py @@ -14,7 +14,7 @@ from mypy_django_plugin.django.context import ( DjangoContext, LookupsAreUnsupported, ) -from mypy_django_plugin.lib import fullnames, helpers +from mypy_django_plugin.lib import chk_helpers, fullnames, helpers def _extract_model_type_from_queryset(queryset_type: Instance) -> Optional[Instance]: @@ -30,7 +30,7 @@ def determine_proper_manager_type(ctx: FunctionContext) -> MypyType: default_return_type = ctx.default_return_type assert isinstance(default_return_type, Instance) - outer_model_info = helpers.get_typechecker_api(ctx).scope.active_class() + outer_model_info = chk_helpers.get_typechecker_api(ctx).scope.active_class() if (outer_model_info is None or not outer_model_info.has_base(fullnames.MODEL_CLASS_FULLNAME)): return default_return_type @@ -55,7 +55,7 @@ def get_field_type_from_lookup(ctx: MethodContext, django_context: DjangoContext return AnyType(TypeOfAny.from_error) lookup_field = django_context.get_primary_key_field(related_model_cls) - field_get_type = django_context.get_field_get_type(helpers.get_typechecker_api(ctx), + field_get_type = django_context.get_field_get_type(chk_helpers.get_typechecker_api(ctx), lookup_field, method=method) return field_get_type @@ -66,7 +66,7 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, if field_lookups is None: return AnyType(TypeOfAny.from_error) - typechecker_api = helpers.get_typechecker_api(ctx) + typechecker_api = chk_helpers.get_typechecker_api(ctx) if len(field_lookups) == 0: if flat: primary_key_field = django_context.get_primary_key_field(model_cls) @@ -80,7 +80,7 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, column_type = django_context.get_field_get_type(typechecker_api, field, method='values_list') column_types[field.attname] = column_type - return helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types) + return chk_helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types) else: # flat=False, named=False, all fields field_lookups = [] @@ -103,9 +103,9 @@ def get_values_list_row_type(ctx: MethodContext, django_context: DjangoContext, assert len(column_types) == 1 row_type = next(iter(column_types.values())) elif named: - row_type = helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types) + row_type = chk_helpers.make_oneoff_named_tuple(typechecker_api, 'Row', column_types) else: - row_type = helpers.make_tuple(typechecker_api, list(column_types.values())) + row_type = chk_helpers.make_tuple(typechecker_api, list(column_types.values())) return row_type @@ -123,13 +123,13 @@ def extract_proper_type_queryset_values_list(ctx: MethodContext, django_context: if model_cls is None: return ctx.default_return_type - flat_expr = helpers.get_call_argument_by_name(ctx, 'flat') + flat_expr = chk_helpers.get_call_argument_by_name(ctx, 'flat') if flat_expr is not None and isinstance(flat_expr, NameExpr): flat = helpers.parse_bool(flat_expr) else: flat = False - named_expr = helpers.get_call_argument_by_name(ctx, 'named') + named_expr = chk_helpers.get_call_argument_by_name(ctx, 'named') if named_expr is not None and isinstance(named_expr, NameExpr): named = helpers.parse_bool(named_expr) else: @@ -188,5 +188,5 @@ def extract_proper_type_queryset_values(ctx: MethodContext, django_context: Djan column_types[field_lookup] = field_lookup_type - row_type = helpers.make_typeddict(ctx.api, column_types, set(column_types.keys())) + row_type = chk_helpers.make_oneoff_typeddict(ctx.api, column_types, set(column_types.keys())) return helpers.reparametrize_instance(ctx.default_return_type, [model_type, row_type]) diff --git a/mypy_django_plugin/transformers/request.py b/mypy_django_plugin/transformers/request.py index be584ab42..a126f91b0 100644 --- a/mypy_django_plugin/transformers/request.py +++ b/mypy_django_plugin/transformers/request.py @@ -3,13 +3,13 @@ from mypy.types import Type as MypyType from mypy_django_plugin.django.context import DjangoContext -from mypy_django_plugin.lib import helpers +from mypy_django_plugin.lib import chk_helpers, helpers def set_auth_user_model_as_type_for_request_user(ctx: AttributeContext, django_context: DjangoContext) -> MypyType: auth_user_model = django_context.settings.AUTH_USER_MODEL model_cls = django_context.apps_registry.get_model(auth_user_model) - model_info = helpers.lookup_class_typeinfo(helpers.get_typechecker_api(ctx), model_cls) + model_info = helpers.lookup_class_typeinfo(chk_helpers.get_typechecker_api(ctx), model_cls) if model_info is None: return ctx.default_attr_type diff --git a/mypy_django_plugin/transformers/settings.py b/mypy_django_plugin/transformers/settings.py index ba6490b4e..1b5972485 100644 --- a/mypy_django_plugin/transformers/settings.py +++ b/mypy_django_plugin/transformers/settings.py @@ -5,7 +5,7 @@ from mypy.types import TypeOfAny, TypeType from mypy_django_plugin.django.context import DjangoContext -from mypy_django_plugin.lib import helpers +from mypy_django_plugin.lib import chk_helpers, helpers def get_user_model_hook(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: @@ -13,7 +13,7 @@ def get_user_model_hook(ctx: FunctionContext, django_context: DjangoContext) -> model_cls = django_context.apps_registry.get_model(auth_user_model) model_cls_fullname = helpers.get_class_fullname(model_cls) - model_info = helpers.lookup_fully_qualified_typeinfo(helpers.get_typechecker_api(ctx), + model_info = helpers.lookup_fully_qualified_typeinfo(chk_helpers.get_typechecker_api(ctx), model_cls_fullname) if model_info is None: return AnyType(TypeOfAny.unannotated) @@ -28,7 +28,7 @@ def get_type_of_settings_attribute(ctx: AttributeContext, django_context: Django ctx.api.fail(f"'Settings' object has no attribute {setting_name!r}", ctx.context) return ctx.default_attr_type - typechecker_api = helpers.get_typechecker_api(ctx) + typechecker_api = chk_helpers.get_typechecker_api(ctx) # first look for the setting in the project settings file, then global settings settings_module = typechecker_api.modules.get(django_context.django_settings_module) diff --git a/scripts/enabled_test_modules.py b/scripts/enabled_test_modules.py index b1b5345d5..9861e6bcb 100644 --- a/scripts/enabled_test_modules.py +++ b/scripts/enabled_test_modules.py @@ -314,6 +314,9 @@ 'model_enums': [ "'bool' is not a valid base class", ], + 'multiple_database': [ + 'Unexpected attribute "extra_arg" for model "Book"', + ], 'null_queries': [ "Cannot resolve keyword 'foo' into field" ], diff --git a/test-data/typecheck/managers/querysets/test_as_manager.yml b/test-data/typecheck/managers/querysets/test_as_manager.yml new file mode 100644 index 000000000..fe2c20dc8 --- /dev/null +++ b/test-data/typecheck/managers/querysets/test_as_manager.yml @@ -0,0 +1,95 @@ +- case: anonymous_queryset_from_as_manager_inside_model + main: | + from myapp.models import MyModel + + reveal_type(MyModel.objects) # N: Revealed type is 'myapp.models.MyQuerySet_AsManager_MyModel' + reveal_type(MyModel.objects.get()) # N: Revealed type is 'myapp.models.MyModel*' + reveal_type(MyModel.objects.queryset_method) # N: Revealed type is 'def () -> builtins.int' + reveal_type(MyModel.objects.queryset_method()) # N: Revealed type is 'builtins.int' + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class MyQuerySet(models.QuerySet): + def queryset_method(self) -> int: + pass + class MyModel(models.Model): + objects = MyQuerySet.as_manager() + + +- case: two_invocations_parametrized_with_different_models + main: | + from myapp.models import User, Blog + reveal_type(User.objects) # N: Revealed type is 'myapp.models.MyQuerySet_AsManager_User' + reveal_type(User.objects.get()) # N: Revealed type is 'myapp.models.User*' + reveal_type(User.objects.queryset_method) # N: Revealed type is 'def () -> builtins.int' + reveal_type(User.objects.queryset_method()) # N: Revealed type is 'builtins.int' + + reveal_type(Blog.objects) # N: Revealed type is 'myapp.models.MyQuerySet_AsManager_Blog' + reveal_type(Blog.objects.get()) # N: Revealed type is 'myapp.models.Blog*' + reveal_type(Blog.objects.queryset_method) # N: Revealed type is 'def () -> builtins.int' + reveal_type(Blog.objects.queryset_method()) # N: Revealed type is 'builtins.int' + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class MyQuerySet(models.QuerySet): + def queryset_method(self) -> int: + pass + class User(models.Model): + objects = MyQuerySet.as_manager() + class Blog(models.Model): + objects = MyQuerySet.as_manager() + + +- case: as_manager_outside_model_parametrized_with_any + main: | + from myapp.models import NotModel, outside_objects + reveal_type(NotModel.objects) # N: Revealed type is 'myapp.models.MyQuerySet_AsManager_Any' + reveal_type(NotModel.objects.get()) # N: Revealed type is 'Any' + reveal_type(outside_objects) # N: Revealed type is 'myapp.models.MyQuerySet_AsManager_Any' + reveal_type(outside_objects.get()) # N: Revealed type is 'Any' + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class MyQuerySet(models.QuerySet): + def queryset_method(self) -> int: + pass + outside_objects = MyQuerySet.as_manager() + class NotModel: + objects = MyQuerySet.as_manager() + +- case: test_as_manager_without_name_to_bind_in_different_files + main: | + from myapp.models import MyQuerySet + reveal_type(MyQuerySet.as_manager()) # N: Revealed type is 'Any' + reveal_type(MyQuerySet.as_manager().get()) # N: Revealed type is 'Any' + reveal_type(MyQuerySet.as_manager().mymethod()) # N: Revealed type is 'Any' + + from myapp import helpers + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class MyQuerySet(models.QuerySet): + def mymethod(self) -> int: + pass + class MyModel(models.Model): + objects = MyQuerySet.as_manager() + - path: myapp/helpers.py + content: | + from myapp.models import MyQuerySet + MyQuerySet.as_manager() \ No newline at end of file diff --git a/test-data/typecheck/managers/querysets/test_from_queryset.yml b/test-data/typecheck/managers/querysets/test_from_queryset.yml index e9f2ad4ff..96bdf9a3a 100644 --- a/test-data/typecheck/managers/querysets/test_from_queryset.yml +++ b/test-data/typecheck/managers/querysets/test_from_queryset.yml @@ -3,6 +3,7 @@ from myapp.models import MyModel reveal_type(MyModel().objects) # N: Revealed type is 'myapp.models.MyModel_NewManager[myapp.models.MyModel]' reveal_type(MyModel().objects.get()) # N: Revealed type is 'myapp.models.MyModel*' + reveal_type(MyModel().objects.queryset_method) # N: Revealed type is 'def () -> builtins.str' reveal_type(MyModel().objects.queryset_method()) # N: Revealed type is 'builtins.str' installed_apps: - myapp @@ -178,4 +179,57 @@ from django.db import models class BaseQuerySet(models.QuerySet): def base_queryset_method(self, param: Union[int, str]) -> NoReturn: - raise ValueError \ No newline at end of file + raise ValueError + + +- case: from_queryset_with_inherited_manager_and_fk_to_auth_contrib + disable_cache: true + main: | + from myapp.base_queryset import BaseQuerySet + reveal_type(BaseQuerySet().base_queryset_method) # N: Revealed type is 'def (param: builtins.dict[builtins.str, Union[builtins.int, builtins.str]]) -> Union[builtins.int, builtins.str]' + + from django.contrib.auth.models import Permission + reveal_type(Permission().another_models) # N: Revealed type is 'django.db.models.manager.RelatedManager[myapp.models.AnotherModelInProjectWithContribAuthM2M]' + + from myapp.managers import NewManager + reveal_type(NewManager()) # N: Revealed type is 'myapp.managers.NewManager' + reveal_type(NewManager().base_queryset_method) # N: Revealed type is 'def (param: builtins.dict[builtins.str, Union[builtins.int, builtins.str]]) -> Union[builtins.int, builtins.str]' + + from myapp.models import MyModel + reveal_type(MyModel().objects) # N: Revealed type is 'myapp.models.MyModel_NewManager[myapp.models.MyModel]' + reveal_type(MyModel().objects.get()) # N: Revealed type is 'myapp.models.MyModel*' + reveal_type(MyModel().objects.base_queryset_method) # N: Revealed type is 'def (param: builtins.dict[builtins.str, Union[builtins.int, builtins.str]]) -> Union[builtins.int, builtins.str]' + installed_apps: + - myapp + - django.contrib.auth + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + from myapp.managers import NewManager + from django.contrib.auth.models import Permission + + class MyModel(models.Model): + objects = NewManager() + + class AnotherModelInProjectWithContribAuthM2M(models.Model): + permissions = models.ForeignKey( + Permission, + on_delete=models.PROTECT, + related_name='another_models' + ) + - path: myapp/managers.py + content: | + from django.db import models + from myapp.base_queryset import BaseQuerySet + class ModelQuerySet(BaseQuerySet): + pass + NewManager = models.Manager.from_queryset(ModelQuerySet) + - path: myapp/base_queryset.py + content: | + from typing import Union, Dict + from django.db import models + class BaseQuerySet(models.QuerySet): + def base_queryset_method(self, param: Dict[str, Union[int, str]]) -> Union[int, str]: + return param["hello"] \ No newline at end of file diff --git a/test-data/typecheck/managers/test_managers.yml b/test-data/typecheck/managers/test_managers.yml index 82206d645..7df83d4ca 100644 --- a/test-data/typecheck/managers/test_managers.yml +++ b/test-data/typecheck/managers/test_managers.yml @@ -307,15 +307,15 @@ - case: custom_manager_returns_proper_model_types main: | from myapp.models import User - reveal_type(User.objects) # N: Revealed type is 'myapp.models.User_MyManager2[myapp.models.User]' - reveal_type(User.objects.select_related()) # N: Revealed type is 'myapp.models.User_MyManager2[myapp.models.User]' + reveal_type(User.objects) # N: Revealed type is 'myapp.models.User_MyManager[myapp.models.User]' + reveal_type(User.objects.select_related()) # N: Revealed type is 'myapp.models.User_MyManager[myapp.models.User]' reveal_type(User.objects.get()) # N: Revealed type is 'myapp.models.User*' reveal_type(User.objects.get_instance()) # N: Revealed type is 'builtins.int' reveal_type(User.objects.get_instance_untyped('hello')) # N: Revealed type is 'Any' from myapp.models import ChildUser - reveal_type(ChildUser.objects) # N: Revealed type is 'myapp.models.ChildUser_MyManager2[myapp.models.ChildUser]' - reveal_type(ChildUser.objects.select_related()) # N: Revealed type is 'myapp.models.ChildUser_MyManager2[myapp.models.ChildUser]' + reveal_type(ChildUser.objects) # N: Revealed type is 'myapp.models.ChildUser_MyManager[myapp.models.ChildUser]' + reveal_type(ChildUser.objects.select_related()) # N: Revealed type is 'myapp.models.ChildUser_MyManager[myapp.models.ChildUser]' reveal_type(ChildUser.objects.get()) # N: Revealed type is 'myapp.models.ChildUser*' reveal_type(ChildUser.objects.get_instance()) # N: Revealed type is 'builtins.int' reveal_type(ChildUser.objects.get_instance_untyped('hello')) # N: Revealed type is 'Any' @@ -335,3 +335,23 @@ objects = MyManager() class ChildUser(models.Model): objects = MyManager() + + +- case: manager_defined_in_the_nested_class + main: | + from myapp.models import MyModel + reveal_type(MyModel.objects) # N: Revealed type is 'myapp.models.MyModel_MyManager[myapp.models.MyModel]' + reveal_type(MyModel.objects.get()) # N: Revealed type is 'myapp.models.MyModel*' + reveal_type(MyModel.objects.mymethod()) # N: Revealed type is 'builtins.int' + installed_apps: + - myapp + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + class MyModel(models.Model): + class MyManager(models.Manager): + def mymethod(self) -> int: + pass + objects = MyManager() \ No newline at end of file