From eb8b7b97f1a7dc10b6bb0e6aea0f3d9eec6571c8 Mon Sep 17 00:00:00 2001 From: Reid Barton Date: Tue, 19 Apr 2016 12:04:43 -0700 Subject: [PATCH 1/2] Rewrite references to inner functions in treetransform Fixes #1323. --- mypy/treetransform.py | 38 ++++++++++++++++++++++-- test-data/unit/check-typevar-values.test | 30 +++++++++++++++++++ 2 files changed, 66 insertions(+), 2 deletions(-) diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 06ea5cbed629..5b57c4b1f648 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -22,6 +22,7 @@ YieldExpr, ExecStmt, Argument, BackquoteExpr, AwaitExpr, ) from mypy.types import Type, FunctionLike, Instance +from mypy.traverser import TraverserVisitor from mypy.visitor import NodeVisitor @@ -36,7 +37,7 @@ class TransformVisitor(NodeVisitor[Node]): * Do not duplicate TypeInfo nodes. This would generally not be desirable. * Only update some name binding cross-references, but only those that - refer to Var nodes, not those targeting ClassDef, TypeInfo or FuncDef + refer to Var or FuncDef nodes, not those targeting ClassDef or TypeInfo nodes. * Types are not transformed, but you can override type() to also perform type transformation. @@ -48,6 +49,7 @@ def __init__(self) -> None: # There may be multiple references to a Var node. Keep track of # Var translations using a dictionary. self.var_map = {} # type: Dict[Var, Var] + self.func_map = {} # type: Dict[FuncDef, FuncDef] def visit_mypy_file(self, node: MypyFile) -> Node: # NOTE: The 'names' and 'imports' instance variables will be empty! @@ -98,6 +100,18 @@ def copy_argument(self, argument: Argument) -> Argument: def visit_func_def(self, node: FuncDef) -> FuncDef: # Note that a FuncDef must be transformed to a FuncDef. + + # These contortions are needed to handle the case of recursive + # references inside the function being transformed. + # Set up empty nodes for references within this function + # to other functions defined inside it. + # Don't create an entry for this function itself though, + # since we want self-references to point to the original + # function if this is the top-level node we are transforming. + init = FuncMapInitializer(self) + for stmt in node.body.body: + stmt.accept(init) + new = FuncDef(node.name(), [self.copy_argument(arg) for arg in node.arguments], self.block(node.body), @@ -113,7 +127,13 @@ def visit_func_def(self, node: FuncDef) -> FuncDef: new.is_class = node.is_class new.is_property = node.is_property new.original_def = node.original_def - return new + + if node in self.func_map: + result = self.func_map[node] + result.__dict__ = new.__dict__ + return result + else: + return new def visit_func_expr(self, node: FuncExpr) -> Node: new = FuncExpr([self.copy_argument(arg) for arg in node.arguments], @@ -330,6 +350,9 @@ def copy_ref(self, new: RefExpr, original: RefExpr) -> None: target = original.node if isinstance(target, Var): target = self.visit_var(target) + elif isinstance(target, FuncDef): + if target in self.func_map: + target = self.func_map[target] new.node = target new.is_def = original.is_def @@ -527,3 +550,14 @@ def types(self, types: List[Type]) -> List[Type]: def optional_types(self, types: List[Type]) -> List[Type]: return [self.optional_type(type) for type in types] + + +class FuncMapInitializer(TraverserVisitor): + def __init__(self, transformer: TransformVisitor) -> None: + self.transformer = transformer + + def visit_func_def(self, node: FuncDef) -> None: + if node not in self.transformer.func_map: + self.transformer.func_map[node] = FuncDef( + node.name(), node.arguments, node.body, None) + super().visit_func_def(node) diff --git a/test-data/unit/check-typevar-values.test b/test-data/unit/check-typevar-values.test index b165c25fe638..238be6e5e918 100644 --- a/test-data/unit/check-typevar-values.test +++ b/test-data/unit/check-typevar-values.test @@ -479,3 +479,33 @@ a = g b = g b = g b = f # E: Incompatible types in assignment (expression has type Callable[[T], T], variable has type Callable[[U], U]) + +[case testInnerFunctionWithTypevarValues] +from typing import TypeVar +T = TypeVar('T', int, str) +U = TypeVar('U', int, str) +def outer(x: T) -> T: + def inner(y: T) -> T: + return x + def inner2(y: U) -> U: + return y + inner(x) + inner(3) # E: Argument 1 to "inner" has incompatible type "int"; expected "str" + inner2(x) + inner2(3) + outer(3) + return x +[out] +main: note: In function "outer": + +[case testInnerFunctionMutualRecursionWithTypevarValues] +from typing import TypeVar +T = TypeVar('T', int, str) +def outer(x: T) -> T: + def inner1(y: T) -> T: + return inner2(y) + def inner2(y: T) -> T: + return inner1('a') # E: Argument 1 to "inner1" has incompatible type "str"; expected "int" + return inner1(x) +[out] +main: note: In function "inner2": From ebc3a8bdf2b59a2a46e202668506c5078ec25a2a Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Mon, 29 Aug 2016 15:01:12 +0100 Subject: [PATCH 2/2] Add comments and minor code cleanup --- mypy/treetransform.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/mypy/treetransform.py b/mypy/treetransform.py index 5b57c4b1f648..798ce8fbeeb4 100644 --- a/mypy/treetransform.py +++ b/mypy/treetransform.py @@ -49,7 +49,11 @@ def __init__(self) -> None: # There may be multiple references to a Var node. Keep track of # Var translations using a dictionary. self.var_map = {} # type: Dict[Var, Var] - self.func_map = {} # type: Dict[FuncDef, FuncDef] + # These are uninitialized placeholder nodes used temporarily for nested + # functions while we are transforming a top-level function. This maps an + # untransformed node to a placeholder (which will later become the + # transformed node). + self.func_placeholder_map = {} # type: Dict[FuncDef, FuncDef] def visit_mypy_file(self, node: MypyFile) -> Node: # NOTE: The 'names' and 'imports' instance variables will be empty! @@ -103,7 +107,7 @@ def visit_func_def(self, node: FuncDef) -> FuncDef: # These contortions are needed to handle the case of recursive # references inside the function being transformed. - # Set up empty nodes for references within this function + # Set up placholder nodes for references within this function # to other functions defined inside it. # Don't create an entry for this function itself though, # since we want self-references to point to the original @@ -128,8 +132,12 @@ def visit_func_def(self, node: FuncDef) -> FuncDef: new.is_property = node.is_property new.original_def = node.original_def - if node in self.func_map: - result = self.func_map[node] + if node in self.func_placeholder_map: + # There is a placeholder definition for this function. Replace + # the attributes of the placeholder with those form the transformed + # function. We know that the classes will be identical (otherwise + # this wouldn't work). + result = self.func_placeholder_map[node] result.__dict__ = new.__dict__ return result else: @@ -351,8 +359,8 @@ def copy_ref(self, new: RefExpr, original: RefExpr) -> None: if isinstance(target, Var): target = self.visit_var(target) elif isinstance(target, FuncDef): - if target in self.func_map: - target = self.func_map[target] + # Use a placeholder node for the function if it exists. + target = self.func_placeholder_map.get(target, target) new.node = target new.is_def = original.is_def @@ -553,11 +561,17 @@ def optional_types(self, types: List[Type]) -> List[Type]: class FuncMapInitializer(TraverserVisitor): + """This traverser creates mappings from nested FuncDefs to placeholder FuncDefs. + + The placholders will later be replaced with transformed nodes. + """ + def __init__(self, transformer: TransformVisitor) -> None: self.transformer = transformer def visit_func_def(self, node: FuncDef) -> None: - if node not in self.transformer.func_map: - self.transformer.func_map[node] = FuncDef( + if node not in self.transformer.func_placeholder_map: + # Haven't seen this FuncDef before, so create a placeholder node. + self.transformer.func_placeholder_map[node] = FuncDef( node.name(), node.arguments, node.body, None) super().visit_func_def(node)