Skip to content

Commit

Permalink
[Dy2stat] Fix to_tensor Bug Reported from QA (#32701)
Browse files Browse the repository at this point in the history
Dy2stat failed when user writes return paddle.to_tensor(xxx), the reason is that visit_Expr doesn't work when the Expr is in return. Some other statements may trigger same bug. To fix it, we re-wrote a transformer to transform paddle.to_tensor to paddle.assign for all Call nodes.
  • Loading branch information
zhhsplendid authored Apr 30, 2021
1 parent 0a0f324 commit 0026819
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ def __init__(self, wrapper_root):
self.root = wrapper_root.node
self.class_node_dict = {}

self.name_to_tensor_shape = {}

def transform(self):
to_tensor_transformer = ToTensorTransformer(self.root)
to_tensor_transformer.transform()
self.visit(self.root)

return self.wrapper_root

def visit_Assign(self, node):
Expand All @@ -62,11 +63,6 @@ def visit_Expr(self, node):

def _visit_Call(self, node):
assert isinstance(node, gast.Call)
# Replace API `to_variable` with `fluid.layers.assign`
if is_to_variable(node):
node = to_assign_node(node)
return node

func_name = astor.to_source(gast.gast_to_ast(node.func))

if self._is_dygraph_forward(func_name):
Expand Down Expand Up @@ -102,6 +98,29 @@ def _update_class_node_dict(self, node):
return False


class ToTensorTransformer(gast.NodeTransformer):
"""
Class to transform paddle.to_tensor and paddle.to_variable to paddle.assign
"""

def __init__(self, node):
assert isinstance(
node, gast.AST
), "Input non-gast.AST node for the initialization of ToTensorTransformer."
self.root = node

def transform(self):
self.visit(self.root)
return self.root

def visit_Call(self, node):
assert isinstance(node, gast.Call)
if is_to_variable(node):
node = to_assign_node(node)
self.generic_visit(node)
return node


def is_to_variable(node):
assert isinstance(node, gast.Call)
api_name = utils.ast_to_source_code(node.func).strip()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,11 @@ def dyfunc_int_to_tensor(x):


def dyfunc_float_to_tensor(x):
res = paddle.to_tensor(2.0)
return res
return paddle.to_tensor(2.0)


def dyfunc_bool_to_tensor(x):
res = paddle.to_tensor(True)
return res
return paddle.to_tensor(True)


class TestDygraphBasicApi_ToVariable(unittest.TestCase):
Expand Down

0 comments on commit 0026819

Please sign in to comment.