Skip to content

Commit

Permalink
[Dy2Stat]Polish for zip in dy2stat (PaddlePaddle#37846)
Browse files Browse the repository at this point in the history
* polish for zip in dy2stat

* polish comment

* polish is_builtin_len

* fix comment
  • Loading branch information
0x45f authored and Zjq9409 committed Dec 10, 2021
1 parent 2a8aec1 commit 96a99bd
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,19 @@ def _no_need_convert_call(self, node):
Determines whether a function needs to be transformed by `convert_call`.
It doesn't need to be transformed when a function satisfies the following conditions:
1. It's a api of paddle
2. It's a python builtin function not include `len`
2. It's a python builtin function not include `len` and `zip`
"""
assert isinstance(node, gast.Call)
if is_paddle_api(node):
return True

func_str = ast_to_source_code(node.func).strip()
try:
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin_len, is_builtin
from paddle.fluid.dygraph.dygraph_to_static.convert_call_func import is_builtin_len, is_builtin, is_builtin_zip
is_builtin = eval("is_builtin({})".format(func_str))
is_builtin_len = eval("is_builtin_len({})".format(func_str))
return is_builtin and not is_builtin_len
is_builtin_zip = eval("is_builtin_zip({})".format(func_str))
return is_builtin and not is_builtin_len and not is_builtin_zip
except Exception:
return False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import six

from paddle.fluid.dygraph.container import Sequential
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len
from paddle.fluid.dygraph.dygraph_to_static.convert_operators import convert_len, convert_zip
from paddle.fluid.dygraph.dygraph_to_static.logging_utils import TranslatorLogger
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction
from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
Expand Down Expand Up @@ -79,6 +79,10 @@ def is_builtin_len(func):
return False


def is_builtin_zip(func):
return is_builtin(func) and func.__name__ == 'zip'


def is_unsupported(func):
"""
Checks whether the func is supported by dygraph to static graph.
Expand Down Expand Up @@ -164,6 +168,9 @@ def dyfunc(x):
if is_builtin_len(func):
return convert_len

if is_builtin_zip(func):
return convert_zip

if is_builtin(func) or is_unsupported(func):
return func

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,15 @@ def convert_len(var):
return len(var)


def convert_zip(*args):
for i, arg in enumerate(args):
if isinstance(arg, Variable) and arg.shape[0] == -1:
raise RuntimeError(
"Not support zip(tensor, ...) when tensor.shape[0] == -1, "
"but found args[{}].shape[0] == -1 in 'zip'".format(str(i)))
return zip(*args)


def convert_var_shape(x, idx=None, in_control_flow=False):
"""
A function representation of the shape of variable.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.static import InputSpec

program_translator = ProgramTranslator()

Expand Down Expand Up @@ -322,6 +323,24 @@ def for_original_tuple():
return z


# 23. for zip error
@paddle.jit.to_static(
input_spec=[InputSpec(shape=[None, 10]), InputSpec(shape=[None, 10])])
def for_zip_error(x, y):
for i, j in zip(x, y):
a = i + j
return x + y


# 24. for zip
@paddle.jit.to_static(
input_spec=[InputSpec(shape=[2, 10]), InputSpec(shape=[2, 10])])
def for_zip(x, y):
for i, j in zip(x, y):
a = i + j
return x + y


class TestTransformBase(unittest.TestCase):
def setUp(self):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
Expand Down Expand Up @@ -512,5 +531,14 @@ def test_transformed_result_compare(self):
self.transformed_result_compare()


class TestForZip(unittest.TestCase):
def test_for_zip_error(self):
with self.assertRaises(RuntimeError):
paddle.jit.save(for_zip_error, './for_zip_error')

def test_for_zip(self):
paddle.jit.save(for_zip, './for_zip')


if __name__ == '__main__':
unittest.main()

0 comments on commit 96a99bd

Please sign in to comment.