Skip to content

Commit 3acbf3f

Browse files
sobolevnmsullivan
andauthored
Adds get_function_signature_hook (#9102)
This PR introduces get_function_signature_hook that behaves the similar way as get_method_signature_hook. Closes #9101 Co-authored-by: Michael Sullivan <sully@msully.net>
1 parent e4131a5 commit 3acbf3f

File tree

4 files changed

+118
-21
lines changed

4 files changed

+118
-21
lines changed

mypy/checkexpr.py

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,11 @@
5656
from mypy.util import split_module_names
5757
from mypy.typevars import fill_typevars
5858
from mypy.visitor import ExpressionVisitor
59-
from mypy.plugin import Plugin, MethodContext, MethodSigContext, FunctionContext
59+
from mypy.plugin import (
60+
Plugin,
61+
MethodContext, MethodSigContext,
62+
FunctionContext, FunctionSigContext,
63+
)
6064
from mypy.typeops import (
6165
tuple_fallback, make_simplified_union, true_only, false_only, erase_to_union_or_bound,
6266
function_type, callable_type, try_getting_str_literals, custom_special_method,
@@ -730,12 +734,15 @@ def apply_function_plugin(self,
730734
callee.arg_names, formal_arg_names,
731735
callee.ret_type, formal_arg_exprs, context, self.chk))
732736

733-
def apply_method_signature_hook(
737+
def apply_signature_hook(
734738
self, callee: FunctionLike, args: List[Expression],
735-
arg_kinds: List[int], context: Context,
736-
arg_names: Optional[Sequence[Optional[str]]], object_type: Type,
737-
signature_hook: Callable[[MethodSigContext], CallableType]) -> FunctionLike:
738-
"""Apply a plugin hook that may infer a more precise signature for a method."""
739+
arg_kinds: List[int],
740+
arg_names: Optional[Sequence[Optional[str]]],
741+
hook: Callable[
742+
[List[List[Expression]], CallableType],
743+
CallableType,
744+
]) -> FunctionLike:
745+
"""Helper to apply a signature hook for either a function or method"""
739746
if isinstance(callee, CallableType):
740747
num_formals = len(callee.arg_kinds)
741748
formal_to_actual = map_actuals_to_formals(
@@ -746,19 +753,40 @@ def apply_method_signature_hook(
746753
for formal, actuals in enumerate(formal_to_actual):
747754
for actual in actuals:
748755
formal_arg_exprs[formal].append(args[actual])
749-
object_type = get_proper_type(object_type)
750-
return signature_hook(
751-
MethodSigContext(object_type, formal_arg_exprs, callee, context, self.chk))
756+
return hook(formal_arg_exprs, callee)
752757
else:
753758
assert isinstance(callee, Overloaded)
754759
items = []
755760
for item in callee.items():
756-
adjusted = self.apply_method_signature_hook(
757-
item, args, arg_kinds, context, arg_names, object_type, signature_hook)
761+
adjusted = self.apply_signature_hook(
762+
item, args, arg_kinds, arg_names, hook)
758763
assert isinstance(adjusted, CallableType)
759764
items.append(adjusted)
760765
return Overloaded(items)
761766

767+
def apply_function_signature_hook(
768+
self, callee: FunctionLike, args: List[Expression],
769+
arg_kinds: List[int], context: Context,
770+
arg_names: Optional[Sequence[Optional[str]]],
771+
signature_hook: Callable[[FunctionSigContext], CallableType]) -> FunctionLike:
772+
"""Apply a plugin hook that may infer a more precise signature for a function."""
773+
return self.apply_signature_hook(
774+
callee, args, arg_kinds, arg_names,
775+
(lambda args, sig:
776+
signature_hook(FunctionSigContext(args, sig, context, self.chk))))
777+
778+
def apply_method_signature_hook(
779+
self, callee: FunctionLike, args: List[Expression],
780+
arg_kinds: List[int], context: Context,
781+
arg_names: Optional[Sequence[Optional[str]]], object_type: Type,
782+
signature_hook: Callable[[MethodSigContext], CallableType]) -> FunctionLike:
783+
"""Apply a plugin hook that may infer a more precise signature for a method."""
784+
pobject_type = get_proper_type(object_type)
785+
return self.apply_signature_hook(
786+
callee, args, arg_kinds, arg_names,
787+
(lambda args, sig:
788+
signature_hook(MethodSigContext(pobject_type, args, sig, context, self.chk))))
789+
762790
def transform_callee_type(
763791
self, callable_name: Optional[str], callee: Type, args: List[Expression],
764792
arg_kinds: List[int], context: Context,
@@ -779,13 +807,17 @@ def transform_callee_type(
779807
(if appropriate) before the signature is passed to check_call.
780808
"""
781809
callee = get_proper_type(callee)
782-
if (callable_name is not None
783-
and object_type is not None
784-
and isinstance(callee, FunctionLike)):
785-
signature_hook = self.plugin.get_method_signature_hook(callable_name)
786-
if signature_hook:
787-
return self.apply_method_signature_hook(
788-
callee, args, arg_kinds, context, arg_names, object_type, signature_hook)
810+
if callable_name is not None and isinstance(callee, FunctionLike):
811+
if object_type is not None:
812+
method_sig_hook = self.plugin.get_method_signature_hook(callable_name)
813+
if method_sig_hook:
814+
return self.apply_method_signature_hook(
815+
callee, args, arg_kinds, context, arg_names, object_type, method_sig_hook)
816+
else:
817+
function_sig_hook = self.plugin.get_function_signature_hook(callable_name)
818+
if function_sig_hook:
819+
return self.apply_function_signature_hook(
820+
callee, args, arg_kinds, context, arg_names, function_sig_hook)
789821

790822
return callee
791823

mypy/plugin.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,16 @@ def final_iteration(self) -> bool:
365365
('is_check', bool) # Is this invocation for checking whether the config matches
366366
])
367367

368+
# A context for a function signature hook that infers a better signature for a
369+
# function. Note that argument types aren't available yet. If you need them,
370+
# you have to use a method hook instead.
371+
FunctionSigContext = NamedTuple(
372+
'FunctionSigContext', [
373+
('args', List[List[Expression]]), # Actual expressions for each formal argument
374+
('default_signature', CallableType), # Original signature of the method
375+
('context', Context), # Relevant location context (e.g. for error messages)
376+
('api', CheckerPluginInterface)])
377+
368378
# A context for a function hook that infers the return type of a function with
369379
# a special signature.
370380
#
@@ -395,7 +405,7 @@ def final_iteration(self) -> bool:
395405
# TODO: document ProperType in the plugin changelog/update issue.
396406
MethodSigContext = NamedTuple(
397407
'MethodSigContext', [
398-
('type', ProperType), # Base object type for method call
408+
('type', ProperType), # Base object type for method call
399409
('args', List[List[Expression]]), # Actual expressions for each formal argument
400410
('default_signature', CallableType), # Original signature of the method
401411
('context', Context), # Relevant location context (e.g. for error messages)
@@ -407,7 +417,7 @@ def final_iteration(self) -> bool:
407417
# This is very similar to FunctionContext (only differences are documented).
408418
MethodContext = NamedTuple(
409419
'MethodContext', [
410-
('type', ProperType), # Base object type for method call
420+
('type', ProperType), # Base object type for method call
411421
('arg_types', List[List[Type]]), # List of actual caller types for each formal argument
412422
# see FunctionContext for details about names and kinds
413423
('arg_kinds', List[List[int]]),
@@ -421,7 +431,7 @@ def final_iteration(self) -> bool:
421431
# A context for an attribute type hook that infers the type of an attribute.
422432
AttributeContext = NamedTuple(
423433
'AttributeContext', [
424-
('type', ProperType), # Type of object with attribute
434+
('type', ProperType), # Type of object with attribute
425435
('default_attr_type', Type), # Original attribute type
426436
('context', Context), # Relevant location context (e.g. for error messages)
427437
('api', CheckerPluginInterface)])
@@ -533,6 +543,22 @@ def func(x: Other[int]) -> None:
533543
"""
534544
return None
535545

546+
def get_function_signature_hook(self, fullname: str
547+
) -> Optional[Callable[[FunctionSigContext], CallableType]]:
548+
"""Adjust the signature a function.
549+
550+
This method is called before type checking a function call. Plugin
551+
may infer a better type for the function.
552+
553+
from lib import Class, do_stuff
554+
555+
do_stuff(42)
556+
Class()
557+
558+
This method will be called with 'lib.do_stuff' and then with 'lib.Class'.
559+
"""
560+
return None
561+
536562
def get_function_hook(self, fullname: str
537563
) -> Optional[Callable[[FunctionContext], Type]]:
538564
"""Adjust the return type of a function call.
@@ -721,6 +747,10 @@ def get_type_analyze_hook(self, fullname: str
721747
) -> Optional[Callable[[AnalyzeTypeContext], Type]]:
722748
return self._find_hook(lambda plugin: plugin.get_type_analyze_hook(fullname))
723749

750+
def get_function_signature_hook(self, fullname: str
751+
) -> Optional[Callable[[FunctionSigContext], CallableType]]:
752+
return self._find_hook(lambda plugin: plugin.get_function_signature_hook(fullname))
753+
724754
def get_function_hook(self, fullname: str
725755
) -> Optional[Callable[[FunctionContext], Type]]:
726756
return self._find_hook(lambda plugin: plugin.get_function_hook(fullname))

test-data/unit/check-custom-plugin.test

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,3 +721,12 @@ Cls().attr = "foo" # E: Incompatible types in assignment (expression has type "
721721
[file mypy.ini]
722722
\[mypy]
723723
plugins=<ROOT>/test-data/unit/plugins/descriptor.py
724+
725+
[case testFunctionSigPluginFile]
726+
# flags: --config-file tmp/mypy.ini
727+
728+
def dynamic_signature(arg1: str) -> str: ...
729+
reveal_type(dynamic_signature(1)) # N: Revealed type is 'builtins.int'
730+
[file mypy.ini]
731+
\[mypy]
732+
plugins=<ROOT>/test-data/unit/plugins/function_sig_hook.py
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from mypy.plugin import CallableType, CheckerPluginInterface, FunctionSigContext, Plugin
2+
from mypy.types import Instance, Type
3+
4+
class FunctionSigPlugin(Plugin):
5+
def get_function_signature_hook(self, fullname):
6+
if fullname == '__main__.dynamic_signature':
7+
return my_hook
8+
return None
9+
10+
def _str_to_int(api: CheckerPluginInterface, typ: Type) -> Type:
11+
if isinstance(typ, Instance):
12+
if typ.type.fullname == 'builtins.str':
13+
return api.named_generic_type('builtins.int', [])
14+
elif typ.args:
15+
return typ.copy_modified(args=[_str_to_int(api, t) for t in typ.args])
16+
17+
return typ
18+
19+
def my_hook(ctx: FunctionSigContext) -> CallableType:
20+
return ctx.default_signature.copy_modified(
21+
arg_types=[_str_to_int(ctx.api, t) for t in ctx.default_signature.arg_types],
22+
ret_type=_str_to_int(ctx.api, ctx.default_signature.ret_type),
23+
)
24+
25+
def plugin(version):
26+
return FunctionSigPlugin

0 commit comments

Comments
 (0)