From 96a99bd51f0da2132334fad7268daa3f5b4b9e8f Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Tue, 7 Dec 2021 14:11:17 +0800 Subject: [PATCH] [Dy2Stat]Polish for zip in dy2stat (#37846) * polish for zip in dy2stat * polish comment * polish is_builtin_len * fix comment --- .../dygraph_to_static/call_transformer.py | 7 +++-- .../dygraph_to_static/convert_call_func.py | 9 +++++- .../dygraph_to_static/convert_operators.py | 9 ++++++ .../dygraph_to_static/test_for_enumerate.py | 28 +++++++++++++++++++ 4 files changed, 49 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py b/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py index 3e606139245d6..a80dfa11402c5 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/call_transformer.py @@ -39,7 +39,7 @@ def _no_need_convert_call(self, node): Determines whether a function needs to be transformed by `convert_call`. It doesn't need to be transformed when a function satisfies the following conditions: 1. It's a api of paddle - 2. It's a python builtin function not include `len` + 2. It's a python builtin function not include `len` and `zip` """ assert isinstance(node, gast.Call) if is_paddle_api(node): @@ -47,10 +47,11 @@ def _no_need_convert_call(self, node): func_str = ast_to_source_code(node.func).strip() try: - from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin_len, is_builtin + from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin_len, is_builtin, is_builtin_zip is_builtin = eval("is_builtin({})".format(func_str)) is_builtin_len = eval("is_builtin_len({})".format(func_str)) - return is_builtin and not is_builtin_len + is_builtin_zip = eval("is_builtin_zip({})".format(func_str)) + return is_builtin and not is_builtin_len and not is_builtin_zip except Exception: return False diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py index 300586969ff65..0b009c0049dcb 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py @@ -27,7 +27,7 @@ import six from paddle.fluid.dygraph.container import Sequential -from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len +from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len, convert_zip from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogger from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static @@ -79,6 +79,10 @@ def is_builtin_len(func): return False +def is_builtin_zip(func): + return is_builtin(func) and func.__name__ == 'zip' + + def is_unsupported(func): """ Checks whether the func is supported by dygraph to static graph. @@ -164,6 +168,9 @@ def dyfunc(x): if is_builtin_len(func): return convert_len + if is_builtin_zip(func): + return convert_zip + if is_builtin(func) or is_unsupported(func): return func diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py index 0ac4da947a46b..ba45dedc40faa 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py @@ -298,6 +298,15 @@ def convert_len(var): return len(var) +def convert_zip(*args): + for i, arg in enumerate(args): + if isinstance(arg, Variable) and arg.shape[0] == -1: + raise RuntimeError( + "Not support zip(tensor, ...) when tensor.shape[0] == -1, " + "but found args[{}].shape[0] == -1 in 'zip'".format(str(i))) + return zip(*args) + + def convert_var_shape(x, idx=None, in_control_flow=False): """ A function representation of the shape of variable. diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py index 2aab27c03110d..750ed615e7109 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_for_enumerate.py @@ -20,6 +20,7 @@ import paddle import paddle.fluid as fluid from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator +from paddle.static import InputSpec program_translator = ProgramTranslator() @@ -322,6 +323,24 @@ def for_original_tuple(): return z +# 23. for zip error +@paddle.jit.to_static( + input_spec=[InputSpec(shape=[None, 10]), InputSpec(shape=[None, 10])]) +def for_zip_error(x, y): + for i, j in zip(x, y): + a = i + j + return x + y + + +# 24. for zip +@paddle.jit.to_static( + input_spec=[InputSpec(shape=[2, 10]), InputSpec(shape=[2, 10])]) +def for_zip(x, y): + for i, j in zip(x, y): + a = i + j + return x + y + + class TestTransformBase(unittest.TestCase): def setUp(self): self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda( @@ -512,5 +531,14 @@ def test_transformed_result_compare(self): self.transformed_result_compare() +class TestForZip(unittest.TestCase): + def test_for_zip_error(self): + with self.assertRaises(RuntimeError): + paddle.jit.save(for_zip_error, './for_zip_error') + + def test_for_zip(self): + paddle.jit.save(for_zip, './for_zip') + + if __name__ == '__main__': unittest.main()