diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 88badf25e690..ba44f0ea673a 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -56,7 +56,11 @@ from mypy.util import split_module_names from mypy.typevars import fill_typevars from mypy.visitor import ExpressionVisitor -from mypy.plugin import Plugin, MethodContext, MethodSigContext, FunctionContext +from mypy.plugin import ( + Plugin, + MethodContext, MethodSigContext, + FunctionContext, FunctionSigContext, +) from mypy.typeops import ( tuple_fallback, make_simplified_union, true_only, false_only, erase_to_union_or_bound, function_type, callable_type, try_getting_str_literals, custom_special_method, @@ -730,12 +734,15 @@ def apply_function_plugin(self, callee.arg_names, formal_arg_names, callee.ret_type, formal_arg_exprs, context, self.chk)) - def apply_method_signature_hook( + def apply_signature_hook( self, callee: FunctionLike, args: List[Expression], - arg_kinds: List[int], context: Context, - arg_names: Optional[Sequence[Optional[str]]], object_type: Type, - signature_hook: Callable[[MethodSigContext], CallableType]) -> FunctionLike: - """Apply a plugin hook that may infer a more precise signature for a method.""" + arg_kinds: List[int], + arg_names: Optional[Sequence[Optional[str]]], + hook: Callable[ + [List[List[Expression]], CallableType], + CallableType, + ]) -> FunctionLike: + """Helper to apply a signature hook for either a function or method""" if isinstance(callee, CallableType): num_formals = len(callee.arg_kinds) formal_to_actual = map_actuals_to_formals( @@ -746,19 +753,40 @@ def apply_method_signature_hook( for formal, actuals in enumerate(formal_to_actual): for actual in actuals: formal_arg_exprs[formal].append(args[actual]) - object_type = get_proper_type(object_type) - return signature_hook( - MethodSigContext(object_type, formal_arg_exprs, callee, context, self.chk)) + return hook(formal_arg_exprs, callee) else: assert isinstance(callee, Overloaded) items = [] for item in callee.items(): - adjusted = self.apply_method_signature_hook( - item, args, arg_kinds, context, arg_names, object_type, signature_hook) + adjusted = self.apply_signature_hook( + item, args, arg_kinds, arg_names, hook) assert isinstance(adjusted, CallableType) items.append(adjusted) return Overloaded(items) + def apply_function_signature_hook( + self, callee: FunctionLike, args: List[Expression], + arg_kinds: List[int], context: Context, + arg_names: Optional[Sequence[Optional[str]]], + signature_hook: Callable[[FunctionSigContext], CallableType]) -> FunctionLike: + """Apply a plugin hook that may infer a more precise signature for a function.""" + return self.apply_signature_hook( + callee, args, arg_kinds, arg_names, + (lambda args, sig: + signature_hook(FunctionSigContext(args, sig, context, self.chk)))) + + def apply_method_signature_hook( + self, callee: FunctionLike, args: List[Expression], + arg_kinds: List[int], context: Context, + arg_names: Optional[Sequence[Optional[str]]], object_type: Type, + signature_hook: Callable[[MethodSigContext], CallableType]) -> FunctionLike: + """Apply a plugin hook that may infer a more precise signature for a method.""" + pobject_type = get_proper_type(object_type) + return self.apply_signature_hook( + callee, args, arg_kinds, arg_names, + (lambda args, sig: + signature_hook(MethodSigContext(pobject_type, args, sig, context, self.chk)))) + def transform_callee_type( self, callable_name: Optional[str], callee: Type, args: List[Expression], arg_kinds: List[int], context: Context, @@ -779,13 +807,17 @@ def transform_callee_type( (if appropriate) before the signature is passed to check_call. """ callee = get_proper_type(callee) - if (callable_name is not None - and object_type is not None - and isinstance(callee, FunctionLike)): - signature_hook = self.plugin.get_method_signature_hook(callable_name) - if signature_hook: - return self.apply_method_signature_hook( - callee, args, arg_kinds, context, arg_names, object_type, signature_hook) + if callable_name is not None and isinstance(callee, FunctionLike): + if object_type is not None: + method_sig_hook = self.plugin.get_method_signature_hook(callable_name) + if method_sig_hook: + return self.apply_method_signature_hook( + callee, args, arg_kinds, context, arg_names, object_type, method_sig_hook) + else: + function_sig_hook = self.plugin.get_function_signature_hook(callable_name) + if function_sig_hook: + return self.apply_function_signature_hook( + callee, args, arg_kinds, context, arg_names, function_sig_hook) return callee diff --git a/mypy/plugin.py b/mypy/plugin.py index fcc372dacf8c..52c44d457c1b 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -365,6 +365,16 @@ def final_iteration(self) -> bool: ('is_check', bool) # Is this invocation for checking whether the config matches ]) +# A context for a function signature hook that infers a better signature for a +# function. Note that argument types aren't available yet. If you need them, +# you have to use a method hook instead. +FunctionSigContext = NamedTuple( + 'FunctionSigContext', [ + ('args', List[List[Expression]]), # Actual expressions for each formal argument + ('default_signature', CallableType), # Original signature of the method + ('context', Context), # Relevant location context (e.g. for error messages) + ('api', CheckerPluginInterface)]) + # A context for a function hook that infers the return type of a function with # a special signature. # @@ -395,7 +405,7 @@ def final_iteration(self) -> bool: # TODO: document ProperType in the plugin changelog/update issue. MethodSigContext = NamedTuple( 'MethodSigContext', [ - ('type', ProperType), # Base object type for method call + ('type', ProperType), # Base object type for method call ('args', List[List[Expression]]), # Actual expressions for each formal argument ('default_signature', CallableType), # Original signature of the method ('context', Context), # Relevant location context (e.g. for error messages) @@ -407,7 +417,7 @@ def final_iteration(self) -> bool: # This is very similar to FunctionContext (only differences are documented). MethodContext = NamedTuple( 'MethodContext', [ - ('type', ProperType), # Base object type for method call + ('type', ProperType), # Base object type for method call ('arg_types', List[List[Type]]), # List of actual caller types for each formal argument # see FunctionContext for details about names and kinds ('arg_kinds', List[List[int]]), @@ -421,7 +431,7 @@ def final_iteration(self) -> bool: # A context for an attribute type hook that infers the type of an attribute. AttributeContext = NamedTuple( 'AttributeContext', [ - ('type', ProperType), # Type of object with attribute + ('type', ProperType), # Type of object with attribute ('default_attr_type', Type), # Original attribute type ('context', Context), # Relevant location context (e.g. for error messages) ('api', CheckerPluginInterface)]) @@ -533,6 +543,22 @@ def func(x: Other[int]) -> None: """ return None + def get_function_signature_hook(self, fullname: str + ) -> Optional[Callable[[FunctionSigContext], CallableType]]: + """Adjust the signature a function. + + This method is called before type checking a function call. Plugin + may infer a better type for the function. + + from lib import Class, do_stuff + + do_stuff(42) + Class() + + This method will be called with 'lib.do_stuff' and then with 'lib.Class'. + """ + return None + def get_function_hook(self, fullname: str ) -> Optional[Callable[[FunctionContext], Type]]: """Adjust the return type of a function call. @@ -721,6 +747,10 @@ def get_type_analyze_hook(self, fullname: str ) -> Optional[Callable[[AnalyzeTypeContext], Type]]: return self._find_hook(lambda plugin: plugin.get_type_analyze_hook(fullname)) + def get_function_signature_hook(self, fullname: str + ) -> Optional[Callable[[FunctionSigContext], CallableType]]: + return self._find_hook(lambda plugin: plugin.get_function_signature_hook(fullname)) + def get_function_hook(self, fullname: str ) -> Optional[Callable[[FunctionContext], Type]]: return self._find_hook(lambda plugin: plugin.get_function_hook(fullname)) diff --git a/test-data/unit/check-custom-plugin.test b/test-data/unit/check-custom-plugin.test index 6e7f6a066a95..9ab79bafd244 100644 --- a/test-data/unit/check-custom-plugin.test +++ b/test-data/unit/check-custom-plugin.test @@ -721,3 +721,12 @@ Cls().attr = "foo" # E: Incompatible types in assignment (expression has type " [file mypy.ini] \[mypy] plugins=/test-data/unit/plugins/descriptor.py + +[case testFunctionSigPluginFile] +# flags: --config-file tmp/mypy.ini + +def dynamic_signature(arg1: str) -> str: ... +reveal_type(dynamic_signature(1)) # N: Revealed type is 'builtins.int' +[file mypy.ini] +\[mypy] +plugins=/test-data/unit/plugins/function_sig_hook.py diff --git a/test-data/unit/plugins/function_sig_hook.py b/test-data/unit/plugins/function_sig_hook.py new file mode 100644 index 000000000000..d83c7df26209 --- /dev/null +++ b/test-data/unit/plugins/function_sig_hook.py @@ -0,0 +1,26 @@ +from mypy.plugin import CallableType, CheckerPluginInterface, FunctionSigContext, Plugin +from mypy.types import Instance, Type + +class FunctionSigPlugin(Plugin): + def get_function_signature_hook(self, fullname): + if fullname == '__main__.dynamic_signature': + return my_hook + return None + +def _str_to_int(api: CheckerPluginInterface, typ: Type) -> Type: + if isinstance(typ, Instance): + if typ.type.fullname == 'builtins.str': + return api.named_generic_type('builtins.int', []) + elif typ.args: + return typ.copy_modified(args=[_str_to_int(api, t) for t in typ.args]) + + return typ + +def my_hook(ctx: FunctionSigContext) -> CallableType: + return ctx.default_signature.copy_modified( + arg_types=[_str_to_int(ctx.api, t) for t in ctx.default_signature.arg_types], + ret_type=_str_to_int(ctx.api, ctx.default_signature.ret_type), + ) + +def plugin(version): + return FunctionSigPlugin