Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2St]Refine ifelse early return #43328

Merged
merged 10 commits into from
Jun 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# See details in https://github.com/serge-sans-paille/gast/
import os
from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.early_return_transformer import EarlyReturnTransformer
from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer
from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakContinueTransformer
Expand Down Expand Up @@ -87,6 +88,7 @@ def transfer_from_node_type(self, node_wrapper):
self.visit(node_wrapper.node)

transformers = [
EarlyReturnTransformer,
BasicApiTransformer, # Basic Api
TensorShapeTransformer, # Tensor.shape -> layers.shape(Tensor)
ListTransformer, # List used in control flow
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper


class EarlyReturnTransformer(gast.NodeTransformer):
"""
Transform if/else return statement of Dygraph into Static Graph.
"""

def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Type of input node should be AstNodeWrapper, but received %s ." % type(
wrapper_root)
self.root = wrapper_root.node

def transform(self):
"""
Main function to transform AST.
"""
self.visit(self.root)

def is_define_return_in_if(self, node):
assert isinstance(
node, gast.If
), "Type of input node should be gast.If, but received %s ." % type(
node)
for child in node.body:
if isinstance(child, gast.Return):
return True
return False

def visit_block_nodes(self, nodes):
result_nodes = []
destination_nodes = result_nodes
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么不直接使用 destination_nodes = [] ? 另外,函数为什么不直接return destination_nodes? 是有什么特殊含义么?

for node in nodes:
rewritten_node = self.visit(node)

if isinstance(rewritten_node, (list, tuple)):
destination_nodes.extend(rewritten_node)
else:
destination_nodes.append(rewritten_node)

# append other nodes to if.orelse even though if.orelse is not empty
if isinstance(node, gast.If) and self.is_define_return_in_if(node):
destination_nodes = node.orelse
# handle stmt like `if/elif/elif`
while len(destination_nodes) > 0 and \
isinstance(destination_nodes[0], gast.If) and \
self.is_define_return_in_if(destination_nodes[0]):
destination_nodes = destination_nodes[0].orelse

return result_nodes

def visit_If(self, node):
node.body = self.visit_block_nodes(node.body)
node.orelse = self.visit_block_nodes(node.orelse)
return node

def visit_While(self, node):
node.body = self.visit_block_nodes(node.body)
node.orelse = self.visit_block_nodes(node.orelse)
return node

def visit_For(self, node):
node.body = self.visit_block_nodes(node.body)
node.orelse = self.visit_block_nodes(node.orelse)
return node

def visit_FunctionDef(self, node):
node.body = self.visit_block_nodes(node.body)
return node
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,30 @@ def false_fn_0(q, x, y):
return x


def dyfunc_with_if_else_early_return1():
x = paddle.to_tensor([10])
if x == 0:
a = paddle.zeros([2, 2])
b = paddle.zeros([3, 3])
return a, b
a = paddle.zeros([2, 2]) + 1
return a


def dyfunc_with_if_else_early_return2():
x = paddle.to_tensor([10])
if x == 0:
a = paddle.zeros([2, 2])
b = paddle.zeros([3, 3])
return a, b
elif x == 1:
c = paddle.zeros([2, 2]) + 1
d = paddle.zeros([3, 3]) + 1
return c, d
e = paddle.zeros([2, 2]) + 3
return e


def dyfunc_with_if_else_with_list_geneator(x):
if 10 > 5:
y = paddle.add_n(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
import paddle.jit.dy2static as _jst

from ifelse_simple_func import dyfunc_with_if_else
from ifelse_simple_func import dyfunc_with_if_else, dyfunc_with_if_else_early_return1, dyfunc_with_if_else_early_return2

np.random.seed(0)

Expand Down Expand Up @@ -83,34 +83,22 @@ def false_fn_0(x_v):
x_v = _jst.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ),
(x_v, ))
__return_0 = _jst.create_bool_as_type(label is not None, False)

def true_fn_1(__return_0, __return_value_0, label, x_v):
def true_fn_1(__return_value_0, label, x_v):
loss = fluid.layers.cross_entropy(x_v, label)
__return_0 = _jst.create_bool_as_type(label is not None, True)
__return_value_0 = loss
return __return_0, __return_value_0

def false_fn_1(__return_0, __return_value_0):
return __return_0, __return_value_0

__return_0, __return_value_0 = _jst.convert_ifelse(
label is not None, true_fn_1, false_fn_1,
(__return_0, __return_value_0, label, x_v),
(__return_0, __return_value_0))

def true_fn_2(__return_0, __return_value_0, x_v):
__return_1 = _jst.create_bool_as_type(
_jst.convert_logical_not(__return_0), True)
__return_value_0 = x_v
return __return_value_0

def false_fn_2(__return_value_0):
def false_fn_1(__return_value_0, label, x_v):
__return_1 = _jst.create_bool_as_type(label is not None, True)
__return_value_0 = x_v
return __return_value_0

__return_value_0 = _jst.convert_ifelse(
_jst.convert_logical_not(__return_0), true_fn_2, false_fn_2,
(__return_0, __return_value_0, x_v), (__return_value_0, ))
__return_value_0 = _jst.convert_ifelse(label is not None, true_fn_1,
false_fn_1,
(__return_value_0, label, x_v),
(__return_value_0, label, x_v))
return __return_value_0


Expand All @@ -123,45 +111,33 @@ def dyfunc_with_if_else(x_v, label=None):
name='__return_value_init_1')
__return_value_1 = __return_value_init_1

def true_fn_3(x_v):
def true_fn_2(x_v):
x_v = x_v - 1
return x_v

def false_fn_3(x_v):
def false_fn_2(x_v):
x_v = x_v + 1
return x_v

x_v = _jst.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5, true_fn_3, false_fn_3, (x_v, ),
fluid.layers.mean(x_v)[0] > 5, true_fn_2, false_fn_2, (x_v, ),
(x_v, ))
__return_2 = _jst.create_bool_as_type(label is not None, False)

def true_fn_4(__return_2, __return_value_1, label, x_v):
def true_fn_3(__return_value_1, label, x_v):
loss = fluid.layers.cross_entropy(x_v, label)
__return_2 = _jst.create_bool_as_type(label is not None, True)
__return_value_1 = loss
return __return_2, __return_value_1

def false_fn_4(__return_2, __return_value_1):
return __return_2, __return_value_1

__return_2, __return_value_1 = _jst.convert_ifelse(
label is not None, true_fn_4, false_fn_4,
(__return_2, __return_value_1, label, x_v),
(__return_2, __return_value_1))

def true_fn_5(__return_2, __return_value_1, x_v):
__return_3 = _jst.create_bool_as_type(
_jst.convert_logical_not(__return_2), True)
__return_value_1 = x_v
return __return_value_1

def false_fn_5(__return_value_1):
def false_fn_3(__return_value_1, label, x_v):
__return_3 = _jst.create_bool_as_type(label is not None, True)
__return_value_1 = x_v
return __return_value_1

__return_value_1 = _jst.convert_ifelse(
_jst.convert_logical_not(__return_2), true_fn_5, false_fn_5,
(__return_2, __return_value_1, x_v), (__return_value_1, ))
__return_value_1 = _jst.convert_ifelse(label is not None, true_fn_3,
false_fn_3,
(__return_value_1, label, x_v),
(__return_value_1, label, x_v))
return __return_value_1


Expand Down Expand Up @@ -358,6 +334,21 @@ def test_raise_error(self):
net.foo.train()


class TestIfElseEarlyReturn(unittest.TestCase):

def test_ifelse_early_return1(self):
answer = np.zeros([2, 2]) + 1
static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return1)
out = static_func()
self.assertTrue(np.allclose(answer, out.numpy()))

def test_ifelse_early_return2(self):
answer = np.zeros([2, 2]) + 3
static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return2)
out = static_func()
self.assertTrue(np.allclose(answer, out.numpy()))


class TestRemoveCommentInDy2St(unittest.TestCase):

def func_with_comment(self):
Expand Down