diff --git a/docs/source/extending_mypy.rst b/docs/source/extending_mypy.rst index 5c59bef506cc..00c328be7728 100644 --- a/docs/source/extending_mypy.rst +++ b/docs/source/extending_mypy.rst @@ -198,6 +198,10 @@ fields which already exist on the class. *Exception:* if :py:meth:`__getattr__ < :py:meth:`__getattribute__ ` is a method on the class, the hook is called for all fields which do not refer to methods. +**get_class_attribute_hook()** is similar to above, but for attributes on classes rather than instances. +Unlike above, this does not have special casing for :py:meth:`__getattr__ ` or +:py:meth:`__getattribute__ `. + **get_class_decorator_hook()** can be used to update class definition for given class decorators. For example, you can add some attributes to the class to match runtime behaviour: diff --git a/mypy/checkmember.py b/mypy/checkmember.py index c01f52de5a77..b1058dfb2132 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -703,10 +703,13 @@ def analyze_class_attribute_access(itype: Instance, if override_info: info = override_info + fullname = '{}.{}'.format(info.fullname, name) + hook = mx.chk.plugin.get_class_attribute_hook(fullname) + node = info.get(name) if not node: if info.fallback_to_any: - return AnyType(TypeOfAny.special_form) + return apply_class_attr_hook(mx, hook, AnyType(TypeOfAny.special_form)) return None is_decorated = isinstance(node.node, Decorator) @@ -731,14 +734,16 @@ def analyze_class_attribute_access(itype: Instance, if info.is_enum and not (mx.is_lvalue or is_decorated or is_method): enum_class_attribute_type = analyze_enum_class_attribute_access(itype, name, mx) if enum_class_attribute_type: - return enum_class_attribute_type + return apply_class_attr_hook(mx, hook, enum_class_attribute_type) t = node.type if t: if isinstance(t, PartialType): symnode = node.node assert isinstance(symnode, Var) - return mx.chk.handle_partial_var_type(t, mx.is_lvalue, symnode, mx.context) + return apply_class_attr_hook(mx, hook, + mx.chk.handle_partial_var_type(t, mx.is_lvalue, symnode, + mx.context)) # Find the class where method/variable was defined. if isinstance(node.node, Decorator): @@ -789,7 +794,8 @@ def analyze_class_attribute_access(itype: Instance, mx.self_type, original_vars=original_vars) if not mx.is_lvalue: result = analyze_descriptor_access(result, mx) - return result + + return apply_class_attr_hook(mx, hook, result) elif isinstance(node.node, Var): mx.not_ready_callback(name, mx.context) return AnyType(TypeOfAny.special_form) @@ -813,7 +819,7 @@ def analyze_class_attribute_access(itype: Instance, if is_decorated: assert isinstance(node.node, Decorator) if node.node.type: - return node.node.type + return apply_class_attr_hook(mx, hook, node.node.type) else: mx.not_ready_callback(name, mx.context) return AnyType(TypeOfAny.from_error) @@ -825,7 +831,17 @@ def analyze_class_attribute_access(itype: Instance, # unannotated implicit class methods we do this here. if node.node.is_class: typ = bind_self(typ, is_classmethod=True) - return typ + return apply_class_attr_hook(mx, hook, typ) + + +def apply_class_attr_hook(mx: MemberContext, + hook: Optional[Callable[[AttributeContext], Type]], + result: Type, + ) -> Optional[Type]: + if hook: + result = hook(AttributeContext(get_proper_type(mx.original_type), + result, mx.context, mx.chk)) + return result def analyze_enum_class_attribute_access(itype: Instance, diff --git a/mypy/plugin.py b/mypy/plugin.py index 3772d7039b05..8a4f39186085 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -632,10 +632,10 @@ def get_method_hook(self, fullname: str def get_attribute_hook(self, fullname: str ) -> Optional[Callable[[AttributeContext], Type]]: - """Adjust type of a class attribute. + """Adjust type of an instance attribute. - This method is called with attribute full name using the class where the attribute was - defined (or Var.info.fullname for generated attributes). + This method is called with attribute full name using the class of the instance where + the attribute was defined (or Var.info.fullname for generated attributes). For classes without __getattr__ or __getattribute__, this hook is only called for names of fields/properties (but not methods) that exist in the instance MRO. @@ -662,6 +662,25 @@ class Derived(Base): """ return None + def get_class_attribute_hook(self, fullname: str + ) -> Optional[Callable[[AttributeContext], Type]]: + """ + Adjust type of a class attribute. + + This method is called with attribute full name using the class where the attribute was + defined (or Var.info.fullname for generated attributes). + + For example: + + class Cls: + x: Any + + Cls.x + + get_class_attribute_hook is called with '__main__.Cls.x' as fullname. + """ + return None + def get_class_decorator_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: """Update class definition for given class decorators. @@ -783,6 +802,10 @@ def get_attribute_hook(self, fullname: str ) -> Optional[Callable[[AttributeContext], Type]]: return self._find_hook(lambda plugin: plugin.get_attribute_hook(fullname)) + def get_class_attribute_hook(self, fullname: str + ) -> Optional[Callable[[AttributeContext], Type]]: + return self._find_hook(lambda plugin: plugin.get_class_attribute_hook(fullname)) + def get_class_decorator_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: return self._find_hook(lambda plugin: plugin.get_class_decorator_hook(fullname)) diff --git a/test-data/unit/check-custom-plugin.test b/test-data/unit/check-custom-plugin.test index 2707d886d64e..6f8dac77c442 100644 --- a/test-data/unit/check-custom-plugin.test +++ b/test-data/unit/check-custom-plugin.test @@ -902,3 +902,92 @@ reveal_type(f()) # N: Revealed type is "builtins.str" [file mypy.ini] \[mypy] plugins=/test-data/unit/plugins/method_in_decorator.py + +[case testClassAttrPluginClassVar] +# flags: --config-file tmp/mypy.ini + +from typing import Type + +class Cls: + attr = 'test' + unchanged = 'test' + +reveal_type(Cls().attr) # N: Revealed type is "builtins.str" +reveal_type(Cls.attr) # N: Revealed type is "builtins.int" +reveal_type(Cls.unchanged) # N: Revealed type is "builtins.str" +x: Type[Cls] +reveal_type(x.attr) # N: Revealed type is "builtins.int" +[file mypy.ini] +\[mypy] +plugins=/test-data/unit/plugins/class_attr_hook.py + +[case testClassAttrPluginMethod] +# flags: --config-file tmp/mypy.ini + +class Cls: + def attr(self) -> None: + pass + +reveal_type(Cls.attr) # N: Revealed type is "builtins.int" +[file mypy.ini] +\[mypy] +plugins=/test-data/unit/plugins/class_attr_hook.py + +[case testClassAttrPluginEnum] +# flags: --config-file tmp/mypy.ini + +import enum + +class Cls(enum.Enum): + attr = 'test' + +reveal_type(Cls.attr) # N: Revealed type is "builtins.int" +[file mypy.ini] +\[mypy] +plugins=/test-data/unit/plugins/class_attr_hook.py + +[case testClassAttrPluginMetaclassAnyBase] +# flags: --config-file tmp/mypy.ini + +from typing import Any, Type +class M(type): + attr = 'test' + +B: Any +class Cls(B, metaclass=M): + pass + +reveal_type(Cls.attr) # N: Revealed type is "builtins.int" +[file mypy.ini] +\[mypy] +plugins=/test-data/unit/plugins/class_attr_hook.py + +[case testClassAttrPluginMetaclassRegularBase] +# flags: --config-file tmp/mypy.ini + +from typing import Any, Type +class M(type): + attr = 'test' + +class B: + attr = None + +class Cls(B, metaclass=M): + pass + +reveal_type(Cls.attr) # N: Revealed type is "builtins.int" +[file mypy.ini] +\[mypy] +plugins=/test-data/unit/plugins/class_attr_hook.py + +[case testClassAttrPluginPartialType] +# flags: --config-file tmp/mypy.ini + +class Cls: + attr = None + def f(self) -> int: + return Cls.attr + 1 + +[file mypy.ini] +\[mypy] +plugins=/test-data/unit/plugins/class_attr_hook.py diff --git a/test-data/unit/plugins/class_attr_hook.py b/test-data/unit/plugins/class_attr_hook.py new file mode 100644 index 000000000000..348e5df0ee03 --- /dev/null +++ b/test-data/unit/plugins/class_attr_hook.py @@ -0,0 +1,20 @@ +from typing import Callable, Optional + +from mypy.plugin import AttributeContext, Plugin +from mypy.types import Type as MypyType + + +class ClassAttrPlugin(Plugin): + def get_class_attribute_hook(self, fullname: str + ) -> Optional[Callable[[AttributeContext], MypyType]]: + if fullname == '__main__.Cls.attr': + return my_hook + return None + + +def my_hook(ctx: AttributeContext) -> MypyType: + return ctx.api.named_generic_type('builtins.int', []) + + +def plugin(_version: str): + return ClassAttrPlugin