Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2Static] Add non-local for while and for. #43864

Merged
merged 12 commits into from
Jun 30, 2022
29 changes: 20 additions & 9 deletions python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar


def convert_while_loop(cond, body, loop_vars):
def convert_while_loop(cond, body, getter, setter):
"""
A function representation of a Python ``while`` statement.

Expand All @@ -39,25 +39,36 @@ def convert_while_loop(cond, body, loop_vars):

# NOTE: It may be slower if cond is very expensive, but usually cond is just O(1).
# If loop_vars is changed during cond callable, then it causes bug, but current logical_and/logical_not/... doesn't change the loop_vars.
pred = cond(*loop_vars)
pred = cond()
if isinstance(pred, Variable):
loop_vars = _run_paddle_while_loop(cond, body, loop_vars)
loop_vars = _run_paddle_while(cond, body, getter, setter)
else:
loop_vars = _run_py_while(cond, body, loop_vars)
loop_vars = _run_py_while(cond, body, getter, setter)

return loop_vars


def _run_paddle_while_loop(cond, body, loop_vars):
def _run_paddle_while(cond, body, getter, setter):
# NOTE: loop_vars of Paddle op `control_flow.while_loop` must be Paddle Tensors.
loop_vars = [to_static_variable(var) for var in loop_vars]
def to_list(x):
if isinstance(x, (tuple, list)): return x
return [x]

# UndefinedVar will become data layer not check.
loop_vars = [to_static_variable(var) for var in to_list(getter())]
setter(loop_vars if len(loop_vars) > 1 else
loop_vars[0]) # change the non-local var to variable
# variable maybe modified to inner var. change it into
loop_vars = control_flow.while_loop(cond, body, loop_vars)
setter(loop_vars if len(loop_vars) > 1 else
loop_vars[0]) # change the non-local var to variable
return loop_vars


def _run_py_while(cond, body, loop_vars):
while cond(*loop_vars):
loop_vars = body(*loop_vars)
def _run_py_while(cond, body, getter, setter):
loop_vars = getter()
while cond():
loop_vars = body()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里不需要调用setter来同步body里的修改么?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个是不需要调用setter的。因为我们的body中的修改是non local的,那么他的修改就是原来的变量。只有的paddle的时候才需要修改。

return loop_vars


Expand Down
80 changes: 11 additions & 69 deletions python/paddle/fluid/dygraph/dygraph_to_static/ifelse_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_undefined_var
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_nonlocal_stmt_node
from paddle.fluid.dygraph.dygraph_to_static.utils import create_nonlocal_stmt_node
from paddle.fluid.dygraph.dygraph_to_static.utils import create_get_args_node, create_set_args_node

TRUE_FUNC_PREFIX = 'true_fn'
FALSE_FUNC_PREFIX = 'false_fn'
Expand Down Expand Up @@ -415,17 +416,22 @@ def _vars_loaded(ids_dict):

# modified vars
body_modified_vars = _modified_vars(if_vars_dict, parent_vars_dict)
body_modified_vars = set(
filter(lambda x: x != ARGS_NAME, body_modified_vars))
orelse_modified_vars = _modified_vars(else_vars_dict, parent_vars_dict)
orelse_modified_vars = set(
filter(lambda x: x != ARGS_NAME, orelse_modified_vars))
modified_vars = body_modified_vars | orelse_modified_vars

# new vars
# TODO(remove __args when new FunctionScopeAnalysis has been used.)
body_new_vars = set([
var for var in _vars_with_store(if_vars_dict)
if var not in parent_vars_dict
if var not in parent_vars_dict and var != ARGS_NAME
])
orelse_new_vars = set([
var for var in _vars_with_store(else_vars_dict)
if var not in parent_vars_dict
if var not in parent_vars_dict and var != ARGS_NAME
])
new_vars_in_body_or_orelse = body_new_vars | orelse_new_vars
new_vars_in_one_of_body_or_orelse = body_new_vars ^ orelse_new_vars
Expand Down Expand Up @@ -511,11 +517,11 @@ def transform_if_else(node, root):
if any([not isinstance(ctx, gast.Load) for ctx in ctxs]):
parent_ids_set.add(k)

trun_args = parse_cond_args(parent_ids_set, body_name_ids,
true_args = parse_cond_args(parent_ids_set, body_name_ids,
modified_name_ids_from_parent)
false_args = parse_cond_args(parent_ids_set, orelse_name_ids,
modified_name_ids_from_parent)
nonlocal_names = list(trun_args | false_args | new_vars_to_create)
nonlocal_names = list(true_args | false_args | new_vars_to_create)
nonlocal_names.sort()
# NOTE: All var in return_name_ids should be in nonlocal_names.
nonlocal_names = _valid_nonlocal_names(return_name_ids, nonlocal_names)
Expand Down Expand Up @@ -552,70 +558,6 @@ def transform_if_else(node, root):
return create_new_vars_in_parent_stmts, true_func_node, false_func_node, get_args_node, set_args_node, return_name_ids


def create_get_args_node(names):
"""
Create get_args function as follows:

def get_args_0():
nonlocal x, y
return x, y
"""

def empty_node():
func_def = """
def {func_name}():
return
""".format(func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX))
return gast.parse(textwrap.dedent(func_def)).body[0]

assert isinstance(names, (list, tuple))
if not names:
return empty_node()

template = """
def {func_name}():
nonlocal {vars}
return {vars}
"""
func_def = template.format(
func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX),
vars=",".join(names))
return gast.parse(textwrap.dedent(func_def)).body[0]


def create_set_args_node(names):
"""
Create set_args function as follows:

def set_args_0(__args):
nonlocal x, y
x, y = __args
"""

def empty_node():
func_def = """
def {func_name}({args}):
pass
""".format(func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX),
args=ARGS_NAME)
return gast.parse(textwrap.dedent(func_def)).body[0]

assert isinstance(names, (list, tuple))
if not names:
return empty_node()

template = """
def {func_name}({args}):
nonlocal {vars}
{vars} = {args}
"""
func_def = template.format(
func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX),
args=ARGS_NAME,
vars=",".join(names))
return gast.parse(textwrap.dedent(func_def)).body[0]


def create_convert_ifelse_node(return_name_ids,
pred,
true_func,
Expand Down
Loading