Skip to content

Commit

Permalink
[Dy2stat]Fix error in tensor_shape_transformer. (#37999)
Browse files Browse the repository at this point in the history
* fix error when tensor_shape_transformer. Before in stmt like `if len(paddle.shape(x)[0]) > 0`, `paddle` will be used as a variable

* handle other call like `fluid.layers.mean` and `fluid.layers.shape`

* add unit test
  • Loading branch information
0x45f authored Dec 15, 2021
1 parent 141b285 commit 5082249
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,10 @@ def _is_var_shape(self, node):
return False

if isinstance(node, gast.Attribute):
# If node is `paddle.shape`, return False
if (node.attr == 'shape' and isinstance(node.value, gast.Name) and
node.value.id == 'paddle'):
return False
if node.attr != 'shape':
return False
return True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,12 @@ def dyfunc_change_shape_after_assign(x):
return res


def dyfunc_len_paddle_shape():
x = paddle.to_tensor([1, 2, 3])
if len(paddle.shape(x)) > 0:
print(x)


# 1. Basic tests without control flow
class TestTensorShapeBasic(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -582,5 +588,11 @@ def test(self):
func.concrete_program


class TestPaddleShape(unittest.TestCase):
def test_paddle_shape(self):
func = paddle.jit.to_static(dyfunc_len_paddle_shape)
self.assertEqual('paddle.shape(x)' in func.code, True)


if __name__ == '__main__':
unittest.main()

0 comments on commit 5082249

Please sign in to comment.