diff --git a/mypy/plugin.py b/mypy/plugin.py index 27917a6216f5..4ffa9395afc5 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -4,7 +4,7 @@ from abc import abstractmethod from typing import Callable, List, Tuple, Optional, NamedTuple, TypeVar -from mypy.nodes import Expression, StrExpr, IntExpr, UnaryExpr, Context, DictExpr +from mypy.nodes import Expression, StrExpr, IntExpr, UnaryExpr, Context, DictExpr, ClassDef from mypy.types import ( Type, Instance, CallableType, TypedDictType, UnionType, NoneTyp, FunctionLike, TypeVarType, AnyType, TypeList, UnboundType, TypeOfAny @@ -13,7 +13,7 @@ from mypy.options import Options -class AnalyzerPluginInterface: +class TypeAnalyzerPluginInterface: """Interface for accessing semantic analyzer functionality in plugins.""" @abstractmethod @@ -40,7 +40,7 @@ def analyze_callable_args(self, arglist: TypeList) -> Optional[Tuple[List[Type], 'AnalyzeTypeContext', [ ('type', UnboundType), # Type to analyze ('context', Context), - ('api', AnalyzerPluginInterface)]) + ('api', TypeAnalyzerPluginInterface)]) class CheckerPluginInterface: @@ -53,6 +53,23 @@ def named_generic_type(self, name: str, args: List[Type]) -> Instance: raise NotImplementedError +class SemanticAnalyzerPluginInterface: + """Interface for accessing semantic analyzer functionality in plugins.""" + + @abstractmethod + def named_type(self, qualified_name: str, args: Optional[List[Type]] = None) -> Instance: + raise NotImplementedError + + @abstractmethod + def parse_bool(self, expr: Expression) -> Optional[bool]: + raise NotImplementedError + + @abstractmethod + def fail(self, msg: str, ctx: Context, serious: bool = False, *, + blocker: bool = False) -> None: + raise NotImplementedError + + # A context for a function hook that infers the return type of a function with # a special signature. # @@ -98,6 +115,14 @@ def named_generic_type(self, name: str, args: List[Type]) -> Instance: ('context', Context), ('api', CheckerPluginInterface)]) +# A context for a class hook that modifies the class definition. +ClassDefContext = NamedTuple( + 'ClassDecoratorContext', [ + ('cls', ClassDef), # The class definition + ('reason', Expression), # The expression being applied (decorator, metaclass, base class) + ('api', SemanticAnalyzerPluginInterface) + ]) + class Plugin: """Base class of all type checker plugins. @@ -136,7 +161,17 @@ def get_attribute_hook(self, fullname: str ) -> Optional[Callable[[AttributeContext], Type]]: return None - # TODO: metaclass / class decorator hook + def get_class_decorator_hook(self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + return None + + def get_metaclass_hook(self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + return None + + def get_base_class_hook(self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + return None T = TypeVar('T') @@ -182,6 +217,18 @@ def get_attribute_hook(self, fullname: str ) -> Optional[Callable[[AttributeContext], Type]]: return self._find_hook(lambda plugin: plugin.get_attribute_hook(fullname)) + def get_class_decorator_hook(self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + return self._find_hook(lambda plugin: plugin.get_class_decorator_hook(fullname)) + + def get_metaclass_hook(self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + return self._find_hook(lambda plugin: plugin.get_metaclass_hook(fullname)) + + def get_base_class_hook(self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + return self._find_hook(lambda plugin: plugin.get_base_class_hook(fullname)) + def _find_hook(self, lookup: Callable[[Plugin], T]) -> Optional[T]: for plugin in self._plugins: hook = lookup(plugin) diff --git a/mypy/semanal.py b/mypy/semanal.py index e9d71eb4ecf3..d8ff72ef7af7 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -81,7 +81,7 @@ from mypy.sametypes import is_same_type from mypy.options import Options from mypy import experiments -from mypy.plugin import Plugin +from mypy.plugin import Plugin, ClassDefContext, SemanticAnalyzerPluginInterface from mypy import join from mypy.util import get_prefix @@ -172,7 +172,7 @@ } -class SemanticAnalyzerPass2(NodeVisitor[None]): +class SemanticAnalyzerPass2(NodeVisitor[None], SemanticAnalyzerPluginInterface): """Semantically analyze parsed mypy files. The analyzer binds names and does various consistency checks for a @@ -719,9 +719,48 @@ def analyze_class_body(self, defn: ClassDef) -> Iterator[bool]: yield True self.calculate_abstract_status(defn.info) self.setup_type_promotion(defn) - + self.apply_class_plugin_hooks(defn) self.leave_class() + def apply_class_plugin_hooks(self, defn: ClassDef) -> None: + """Apply a plugin hook that may infer a more precise definition for a class.""" + def get_fullname(expr: Expression) -> Optional[str]: + if isinstance(expr, CallExpr): + return get_fullname(expr.callee) + elif isinstance(expr, IndexExpr): + return get_fullname(expr.base) + elif isinstance(expr, RefExpr): + if expr.fullname: + return expr.fullname + # If we don't have a fullname look it up. This happens because base classes are + # analyzed in a different manner (see exprtotype.py) and therefore those AST + # nodes will not have full names. + sym = self.lookup_type_node(expr) + if sym: + return sym.fullname + return None + + for decorator in defn.decorators: + decorator_name = get_fullname(decorator) + if decorator_name: + hook = self.plugin.get_class_decorator_hook(decorator_name) + if hook: + hook(ClassDefContext(defn, decorator, self)) + + if defn.metaclass: + metaclass_name = get_fullname(defn.metaclass) + if metaclass_name: + hook = self.plugin.get_metaclass_hook(metaclass_name) + if hook: + hook(ClassDefContext(defn, defn.metaclass, self)) + + for base_expr in defn.base_type_exprs: + base_name = get_fullname(base_expr) + if base_name: + hook = self.plugin.get_base_class_hook(base_name) + if hook: + hook(ClassDefContext(defn, base_expr, self)) + def analyze_class_keywords(self, defn: ClassDef) -> None: for value in defn.keywords.values(): value.accept(self) diff --git a/mypy/typeanal.py b/mypy/typeanal.py index f5cfcc472d19..1b4467d031db 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -27,7 +27,7 @@ from mypy.sametypes import is_same_type from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError from mypy.subtypes import is_subtype -from mypy.plugin import Plugin, AnalyzerPluginInterface, AnalyzeTypeContext +from mypy.plugin import Plugin, TypeAnalyzerPluginInterface, AnalyzeTypeContext from mypy import nodes, messages @@ -132,7 +132,7 @@ def no_subscript_builtin_alias(name: str, propose_alt: bool = True) -> str: return msg -class TypeAnalyser(SyntheticTypeVisitor[Type], AnalyzerPluginInterface): +class TypeAnalyser(SyntheticTypeVisitor[Type], TypeAnalyzerPluginInterface): """Semantic analyzer for types (semantic analysis pass 2). Converts unbound types into bound types.