From 66822d6e452bd5625b05333d95dd348f79869958 Mon Sep 17 00:00:00 2001 From: 0x45f Date: Thu, 9 Dec 2021 06:35:12 +0000 Subject: [PATCH 1/3] fix error when tensor_shape_transformer. Before in stmt like `if len(paddle.shape(x)[0]) > 0`, `paddle` will be used as a variable --- .../dygraph/dygraph_to_static/tensor_shape_transformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py index 0bc167132e3ed..5aa50f0d21ede 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py @@ -282,7 +282,8 @@ def _is_var_shape(self, node): return False if isinstance(node, gast.Attribute): - if node.attr != 'shape': + # If node is `paddle.shape`, return False + if node.value.id == 'paddle' or node.attr != 'shape': return False return True From c1a8668aa7ea9293b8bd08c35ca775c6ea9f19b2 Mon Sep 17 00:00:00 2001 From: 0x45f Date: Thu, 9 Dec 2021 08:32:39 +0000 Subject: [PATCH 2/3] handle other call like `fluid.layers.mean` and `fluid.layers.shape` --- .../dygraph/dygraph_to_static/tensor_shape_transformer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py index 5aa50f0d21ede..e1df2324889b4 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py @@ -283,7 +283,10 @@ def _is_var_shape(self, node): if isinstance(node, gast.Attribute): # If node is `paddle.shape`, return False - if node.value.id == 'paddle' or node.attr != 'shape': + if (node.attr == 'shape' and isinstance(node.value, gast.Name) and + node.value.id == 'paddle'): + return False + if node.attr != 'shape': return False return True From 31b8941d01f47885bb5afef1c065002e14875bc5 Mon Sep 17 00:00:00 2001 From: 0x45f Date: Thu, 9 Dec 2021 12:34:07 +0000 Subject: [PATCH 3/3] add unit test --- .../unittests/dygraph_to_static/test_tensor_shape.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py index f7cdb12a1ab67..06d69daa75d1c 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py @@ -217,6 +217,12 @@ def dyfunc_change_shape_after_assign(x): return res +def dyfunc_len_paddle_shape(): + x = paddle.to_tensor([1, 2, 3]) + if len(paddle.shape(x)) > 0: + print(x) + + # 1. Basic tests without control flow class TestTensorShapeBasic(unittest.TestCase): def setUp(self): @@ -582,5 +588,11 @@ def test(self): func.concrete_program +class TestPaddleShape(unittest.TestCase): + def test_paddle_shape(self): + func = paddle.jit.to_static(dyfunc_len_paddle_shape) + self.assertEqual('paddle.shape(x)' in func.code, True) + + if __name__ == '__main__': unittest.main()