Skip to content

Commit

Permalink
[Dy2stat]Fix error in tensor_shape_transformer. (#37999) (#38168)
Browse files Browse the repository at this point in the history
修复tensor_shape_transformer中的错误。
之前在类似if len(paddle.shape(x)[0]) > 0中,paddle会被当做一个变量被传入convert_var_shape函数中
  • Loading branch information
0x45f authored Dec 16, 2021
1 parent 8100c16 commit 19eb510
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,10 @@ def _is_var_shape(self, node):
return False

if isinstance(node, gast.Attribute):
# If node is `paddle.shape`, return False
if (node.attr == 'shape' and isinstance(node.value, gast.Name) and
node.value.id == 'paddle'):
return False
if node.attr != 'shape':
return False
return True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,12 @@ def dyfunc_change_shape_after_assign(x):
return res


def dyfunc_len_paddle_shape():
x = paddle.to_tensor([1, 2, 3])
if len(paddle.shape(x)) > 0:
print(x)


# 1. Basic tests without control flow
class TestTensorShapeBasic(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -582,5 +588,11 @@ def test(self):
func.concrete_program


class TestPaddleShape(unittest.TestCase):
def test_paddle_shape(self):
func = paddle.jit.to_static(dyfunc_len_paddle_shape)
self.assertEqual('paddle.shape(x)' in func.code, True)


if __name__ == '__main__':
unittest.main()

0 comments on commit 19eb510

Please sign in to comment.