From 31c68e3d062def638f35f54ff5c533f0085b2aa1 Mon Sep 17 00:00:00 2001 From: zhhsplendid Date: Fri, 30 Apr 2021 06:24:37 +0000 Subject: [PATCH 1/3] Fix to_tensor Bug Reported from QA --- .../basic_api_transformer.py | 33 +++++++++++++++---- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py index 198c2920eec7f..953def62191dd 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py @@ -33,10 +33,11 @@ def __init__(self, wrapper_root): self.root = wrapper_root.node self.class_node_dict = {} - self.name_to_tensor_shape = {} - def transform(self): + to_tensor_transformer = ToTensorTransformer(self.root) + to_tensor_transformer.transform() self.visit(self.root) + return self.wrapper_root def visit_Assign(self, node): @@ -62,11 +63,6 @@ def visit_Expr(self, node): def _visit_Call(self, node): assert isinstance(node, gast.Call) - # Replace API `to_variable` with `fluid.layers.assign` - if is_to_variable(node): - node = to_assign_node(node) - return node - func_name = astor.to_source(gast.gast_to_ast(node.func)) if self._is_dygraph_forward(func_name): @@ -102,6 +98,29 @@ def _update_class_node_dict(self, node): return False +class ToTensorTransformer(gast.NodeTransformer): + """ + Class to transform paddle.to_tensor and paddle.to_variable to paddle.assign + """ + + def __init__(self, node): + assert isinstance( + node, gast.AST + ), "Input non-gast.AST node for the initialization of ToTensorTransformer." + self.root = node + + def transform(self): + self.visit(self.root) + return self.root + + def visit_Call(node): + assert isinstance(node, gast.Call) + if is_to_variable(node): + node = to_assign_node(node) + generic_visit(node) + return node + + def is_to_variable(node): assert isinstance(node, gast.Call) api_name = utils.ast_to_source_code(node.func).strip() From 7b9239ed32dbcb0351e56695ba315c74ec7b66f9 Mon Sep 17 00:00:00 2001 From: zhhsplendid Date: Fri, 30 Apr 2021 06:34:05 +0000 Subject: [PATCH 2/3] Add test --- .../dygraph_to_static/test_basic_api_transformation.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py index 630b804f9a2fb..ea745ad661425 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py @@ -64,13 +64,11 @@ def dyfunc_int_to_tensor(x): def dyfunc_float_to_tensor(x): - res = paddle.to_tensor(2.0) - return res + return paddle.to_tensor(2.0) def dyfunc_bool_to_tensor(x): - res = paddle.to_tensor(True) - return res + return paddle.to_tensor(True) class TestDygraphBasicApi_ToVariable(unittest.TestCase): From 127f85d292035b368606c7239206db5fa67b9f5b Mon Sep 17 00:00:00 2001 From: zhhsplendid Date: Fri, 30 Apr 2021 06:58:03 +0000 Subject: [PATCH 3/3] Add test and fix bug during test --- .../fluid/dygraph/dygraph_to_static/basic_api_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py index 953def62191dd..5ea1fdfac0928 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py @@ -113,11 +113,11 @@ def transform(self): self.visit(self.root) return self.root - def visit_Call(node): + def visit_Call(self, node): assert isinstance(node, gast.Call) if is_to_variable(node): node = to_assign_node(node) - generic_visit(node) + self.generic_visit(node) return node