Skip to content

Commit

Permalink
Add support for classmethods and staticmethods in add_method (#13397)
Browse files Browse the repository at this point in the history
Co-authored-by: Nikita Sobolev <mail@sobolevn.me>
  • Loading branch information
svalentin and sobolevn authored Aug 20, 2022
1 parent 48bd26e commit 551f8f4
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 3 deletions.
43 changes: 40 additions & 3 deletions mypy/plugins/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Block,
CallExpr,
ClassDef,
Decorator,
Expression,
FuncDef,
JsonDict,
Expand All @@ -26,6 +27,7 @@
CallableType,
Overloaded,
Type,
TypeType,
TypeVarType,
deserialize_type,
get_proper_type,
Expand Down Expand Up @@ -102,6 +104,8 @@ def add_method(
return_type: Type,
self_type: Type | None = None,
tvar_def: TypeVarType | None = None,
is_classmethod: bool = False,
is_staticmethod: bool = False,
) -> None:
"""
Adds a new method to a class.
Expand All @@ -115,6 +119,8 @@ def add_method(
return_type=return_type,
self_type=self_type,
tvar_def=tvar_def,
is_classmethod=is_classmethod,
is_staticmethod=is_staticmethod,
)


Expand All @@ -126,8 +132,15 @@ def add_method_to_class(
return_type: Type,
self_type: Type | None = None,
tvar_def: TypeVarType | None = None,
is_classmethod: bool = False,
is_staticmethod: bool = False,
) -> None:
"""Adds a new method to a class definition."""

assert not (
is_classmethod is True and is_staticmethod is True
), "Can't add a new method that's both staticmethod and classmethod."

info = cls.info

# First remove any previously generated methods with the same name
Expand All @@ -137,13 +150,21 @@ def add_method_to_class(
if sym.plugin_generated and isinstance(sym.node, FuncDef):
cls.defs.body.remove(sym.node)

self_type = self_type or fill_typevars(info)
if isinstance(api, SemanticAnalyzerPluginInterface):
function_type = api.named_type("builtins.function")
else:
function_type = api.named_generic_type("builtins.function", [])

args = [Argument(Var("self"), self_type, None, ARG_POS)] + args
if is_classmethod:
self_type = self_type or TypeType(fill_typevars(info))
first = [Argument(Var("_cls"), self_type, None, ARG_POS, True)]
elif is_staticmethod:
first = []
else:
self_type = self_type or fill_typevars(info)
first = [Argument(Var("self"), self_type, None, ARG_POS)]
args = first + args

arg_types, arg_names, arg_kinds = [], [], []
for arg in args:
assert arg.type_annotation, "All arguments must be fully typed."
Expand All @@ -158,6 +179,8 @@ def add_method_to_class(
func = FuncDef(name, args, Block([PassStmt()]))
func.info = info
func.type = set_callable_name(signature, func)
func.is_class = is_classmethod
func.is_static = is_staticmethod
func._fullname = info.fullname + "." + name
func.line = info.line

Expand All @@ -168,7 +191,21 @@ def add_method_to_class(
r_name = get_unique_redefinition_name(name, info.names)
info.names[r_name] = info.names[name]

info.names[name] = SymbolTableNode(MDEF, func, plugin_generated=True)
# Add decorator for is_staticmethod. It's unnecessary for is_classmethod.
if is_staticmethod:
func.is_decorated = True
v = Var(name, func.type)
v.info = info
v._fullname = func._fullname
v.is_staticmethod = True
dec = Decorator(func, [], v)
dec.line = info.line
sym = SymbolTableNode(MDEF, dec)
else:
sym = SymbolTableNode(MDEF, func)
sym.plugin_generated = True
info.names[name] = sym

info.defn.defs.body.append(func)


Expand Down
14 changes: 14 additions & 0 deletions test-data/unit/check-custom-plugin.test
Original file line number Diff line number Diff line change
Expand Up @@ -991,3 +991,17 @@ class Cls:
[file mypy.ini]
\[mypy]
plugins=<ROOT>/test-data/unit/plugins/class_attr_hook.py

[case testAddClassMethodPlugin]
# flags: --config-file tmp/mypy.ini
class BaseAddMethod: pass

class MyClass(BaseAddMethod):
pass

my_class = MyClass()
reveal_type(MyClass.foo_classmethod) # N: Revealed type is "def ()"
reveal_type(MyClass.foo_staticmethod) # N: Revealed type is "def (builtins.int) -> builtins.str"
[file mypy.ini]
\[mypy]
plugins=<ROOT>/test-data/unit/plugins/add_classmethod.py
32 changes: 32 additions & 0 deletions test-data/unit/check-incremental.test
Original file line number Diff line number Diff line change
Expand Up @@ -5911,6 +5911,38 @@ tmp/c.py:4: note: Revealed type is "TypedDict('a.N', {'r': Union[TypedDict('b.M'
tmp/c.py:5: error: Incompatible types in assignment (expression has type "Optional[N]", variable has type "int")
tmp/c.py:7: note: Revealed type is "TypedDict('a.N', {'r': Union[TypedDict('b.M', {'r': Union[..., None], 'x': builtins.int}), None], 'x': builtins.int})"

[case testIncrementalAddClassMethodPlugin]
# flags: --config-file tmp/mypy.ini
import b

[file mypy.ini]
\[mypy]
plugins=<ROOT>/test-data/unit/plugins/add_classmethod.py

[file a.py]
class BaseAddMethod: pass

class MyClass(BaseAddMethod):
pass

[file b.py]
import a

[file b.py.2]
import a

my_class = a.MyClass()
reveal_type(a.MyClass.foo_classmethod)
reveal_type(a.MyClass.foo_staticmethod)
reveal_type(my_class.foo_classmethod)
reveal_type(my_class.foo_staticmethod)

[rechecked b]
[out2]
tmp/b.py:4: note: Revealed type is "def ()"
tmp/b.py:5: note: Revealed type is "def (builtins.int) -> builtins.str"
tmp/b.py:6: note: Revealed type is "def ()"
tmp/b.py:7: note: Revealed type is "def (builtins.int) -> builtins.str"
[case testGenericNamedTupleSerialization]
import b
[file a.py]
Expand Down
28 changes: 28 additions & 0 deletions test-data/unit/plugins/add_classmethod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Callable, Optional

from mypy.nodes import ARG_POS, Argument, Var
from mypy.plugin import ClassDefContext, Plugin
from mypy.plugins.common import add_method
from mypy.types import NoneType


class ClassMethodPlugin(Plugin):
def get_base_class_hook(self, fullname: str) -> Optional[Callable[[ClassDefContext], None]]:
if "BaseAddMethod" in fullname:
return add_extra_methods_hook
return None


def add_extra_methods_hook(ctx: ClassDefContext) -> None:
add_method(ctx, "foo_classmethod", [], NoneType(), is_classmethod=True)
add_method(
ctx,
"foo_staticmethod",
[Argument(Var(""), ctx.api.named_type("builtins.int"), None, ARG_POS)],
ctx.api.named_type("builtins.str"),
is_staticmethod=True,
)


def plugin(version):
return ClassMethodPlugin

0 comments on commit 551f8f4

Please sign in to comment.