diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 06ea5cbed629..798ce8fbeeb4 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -22,6 +22,7 @@ YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr, ) from mypy.types import Type, FunctionLike, Instance +from mypy.traverser import TraverserVisitor from mypy.visitor import NodeVisitor @@ -36,7 +37,7 @@ class TransformVisitor(NodeVisitor[Node]): * Do not duplicate TypeInfo nodes. This would generally not be desirable. * Only update some name binding cross-references, but only those that - refer to Var nodes, not those targeting ClassDef, TypeInfo or FuncDef + refer to Var or FuncDef nodes, not those targeting ClassDef or TypeInfo nodes. * Types are not transformed, but you can override type() to also perform type transformation. @@ -48,6 +49,11 @@ def __init__(self) -> None: # There may be multiple references to a Var node. Keep track of # Var translations using a dictionary. self.var_map = {} # type: Dict[Var, Var] + # These are uninitialized placeholder nodes used temporarily for nested + # functions while we are transforming a top-level function. This maps an + # untransformed node to a placeholder (which will later become the + # transformed node). + self.func_placeholder_map = {} # type: Dict[FuncDef, FuncDef] def visit_mypy_file(self, node: MypyFile) -> Node: # NOTE: The 'names' and 'imports' instance variables will be empty! @@ -98,6 +104,18 @@ def copy_argument(self, argument: Argument) -> Argument: def visit_func_def(self, node: FuncDef) -> FuncDef: # Note that a FuncDef must be transformed to a FuncDef. + + # These contortions are needed to handle the case of recursive + # references inside the function being transformed. + # Set up placholder nodes for references within this function + # to other functions defined inside it. + # Don't create an entry for this function itself though, + # since we want self-references to point to the original + # function if this is the top-level node we are transforming. + init = FuncMapInitializer(self) + for stmt in node.body.body: + stmt.accept(init) + new = FuncDef(node.name(), [self.copy_argument(arg) for arg in node.arguments], self.block(node.body), @@ -113,7 +131,17 @@ def visit_func_def(self, node: FuncDef) -> FuncDef: new.is_class = node.is_class new.is_property = node.is_property new.original_def = node.original_def - return new + + if node in self.func_placeholder_map: + # There is a placeholder definition for this function. Replace + # the attributes of the placeholder with those form the transformed + # function. We know that the classes will be identical (otherwise + # this wouldn't work). + result = self.func_placeholder_map[node] + result.__dict__ = new.__dict__ + return result + else: + return new def visit_func_expr(self, node: FuncExpr) -> Node: new = FuncExpr([self.copy_argument(arg) for arg in node.arguments], @@ -330,6 +358,9 @@ def copy_ref(self, new: RefExpr, original: RefExpr) -> None: target = original.node if isinstance(target, Var): target = self.visit_var(target) + elif isinstance(target, FuncDef): + # Use a placeholder node for the function if it exists. + target = self.func_placeholder_map.get(target, target) new.node = target new.is_def = original.is_def @@ -527,3 +558,20 @@ def types(self, types: List[Type]) -> List[Type]: def optional_types(self, types: List[Type]) -> List[Type]: return [self.optional_type(type) for type in types] + + +class FuncMapInitializer(TraverserVisitor): + """This traverser creates mappings from nested FuncDefs to placeholder FuncDefs. + + The placholders will later be replaced with transformed nodes. + """ + + def __init__(self, transformer: TransformVisitor) -> None: + self.transformer = transformer + + def visit_func_def(self, node: FuncDef) -> None: + if node not in self.transformer.func_placeholder_map: + # Haven't seen this FuncDef before, so create a placeholder node. + self.transformer.func_placeholder_map[node] = FuncDef( + node.name(), node.arguments, node.body, None) + super().visit_func_def(node) diff --git a/test-data/unit/check-typevar-values.test b/test-data/unit/check-typevar-values.test index b165c25fe638..238be6e5e918 100644 --- a/test-data/unit/check-typevar-values.test +++ b/test-data/unit/check-typevar-values.test @@ -479,3 +479,33 @@ a = g b = g b = g b = f # E: Incompatible types in assignment (expression has type Callable[[T], T], variable has type Callable[[U], U]) + +[case testInnerFunctionWithTypevarValues] +from typing import TypeVar +T = TypeVar('T', int, str) +U = TypeVar('U', int, str) +def outer(x: T) -> T: + def inner(y: T) -> T: + return x + def inner2(y: U) -> U: + return y + inner(x) + inner(3) # E: Argument 1 to "inner" has incompatible type "int"; expected "str" + inner2(x) + inner2(3) + outer(3) + return x +[out] +main: note: In function "outer": + +[case testInnerFunctionMutualRecursionWithTypevarValues] +from typing import TypeVar +T = TypeVar('T', int, str) +def outer(x: T) -> T: + def inner1(y: T) -> T: + return inner2(y) + def inner2(y: T) -> T: + return inner1('a') # E: Argument 1 to "inner1" has incompatible type "str"; expected "int" + return inner1(x) +[out] +main: note: In function "inner2":