Skip to content

Commit

Permalink
[Dy2St]Refine ifelse early return (#43328)
Browse files Browse the repository at this point in the history
* Refine ifelse early return
  • Loading branch information
0x45f authored Jun 14, 2022
1 parent 083d769 commit 1950a36
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# See details in https://github.com/serge-sans-paille/gast/
import os
from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.early_return_transformer import EarlyReturnTransformer
from paddle.fluid.dygraph.dygraph_to_static.assert_transformer import AssertTransformer
from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakContinueTransformer
Expand Down Expand Up @@ -87,6 +88,7 @@ def transfer_from_node_type(self, node_wrapper):
self.visit(node_wrapper.node)

transformers = [
EarlyReturnTransformer,
BasicApiTransformer, # Basic Api
TensorShapeTransformer, # Tensor.shape -> layers.shape(Tensor)
ListTransformer, # List used in control flow
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

from paddle.utils import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper


class EarlyReturnTransformer(gast.NodeTransformer):
"""
Transform if/else return statement of Dygraph into Static Graph.
"""

def __init__(self, wrapper_root):
assert isinstance(
wrapper_root, AstNodeWrapper
), "Type of input node should be AstNodeWrapper, but received %s ." % type(
wrapper_root)
self.root = wrapper_root.node

def transform(self):
"""
Main function to transform AST.
"""
self.visit(self.root)

def is_define_return_in_if(self, node):
assert isinstance(
node, gast.If
), "Type of input node should be gast.If, but received %s ." % type(
node)
for child in node.body:
if isinstance(child, gast.Return):
return True
return False

def visit_block_nodes(self, nodes):
result_nodes = []
destination_nodes = result_nodes
for node in nodes:
rewritten_node = self.visit(node)

if isinstance(rewritten_node, (list, tuple)):
destination_nodes.extend(rewritten_node)
else:
destination_nodes.append(rewritten_node)

# append other nodes to if.orelse even though if.orelse is not empty
if isinstance(node, gast.If) and self.is_define_return_in_if(node):
destination_nodes = node.orelse
# handle stmt like `if/elif/elif`
while len(destination_nodes) > 0 and \
isinstance(destination_nodes[0], gast.If) and \
self.is_define_return_in_if(destination_nodes[0]):
destination_nodes = destination_nodes[0].orelse

return result_nodes

def visit_If(self, node):
node.body = self.visit_block_nodes(node.body)
node.orelse = self.visit_block_nodes(node.orelse)
return node

def visit_While(self, node):
node.body = self.visit_block_nodes(node.body)
node.orelse = self.visit_block_nodes(node.orelse)
return node

def visit_For(self, node):
node.body = self.visit_block_nodes(node.body)
node.orelse = self.visit_block_nodes(node.orelse)
return node

def visit_FunctionDef(self, node):
node.body = self.visit_block_nodes(node.body)
return node
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,30 @@ def false_fn_0(q, x, y):
return x


def dyfunc_with_if_else_early_return1():
x = paddle.to_tensor([10])
if x == 0:
a = paddle.zeros([2, 2])
b = paddle.zeros([3, 3])
return a, b
a = paddle.zeros([2, 2]) + 1
return a


def dyfunc_with_if_else_early_return2():
x = paddle.to_tensor([10])
if x == 0:
a = paddle.zeros([2, 2])
b = paddle.zeros([3, 3])
return a, b
elif x == 1:
c = paddle.zeros([2, 2]) + 1
d = paddle.zeros([3, 3]) + 1
return c, d
e = paddle.zeros([2, 2]) + 3
return e


def dyfunc_with_if_else_with_list_geneator(x):
if 10 > 5:
y = paddle.add_n(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
import paddle.jit.dy2static as _jst

from ifelse_simple_func import dyfunc_with_if_else
from ifelse_simple_func import dyfunc_with_if_else, dyfunc_with_if_else_early_return1, dyfunc_with_if_else_early_return2

np.random.seed(0)

Expand Down Expand Up @@ -83,34 +83,22 @@ def false_fn_0(x_v):
x_v = _jst.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5, true_fn_0, false_fn_0, (x_v, ),
(x_v, ))
__return_0 = _jst.create_bool_as_type(label is not None, False)

def true_fn_1(__return_0, __return_value_0, label, x_v):
def true_fn_1(__return_value_0, label, x_v):
loss = fluid.layers.cross_entropy(x_v, label)
__return_0 = _jst.create_bool_as_type(label is not None, True)
__return_value_0 = loss
return __return_0, __return_value_0

def false_fn_1(__return_0, __return_value_0):
return __return_0, __return_value_0

__return_0, __return_value_0 = _jst.convert_ifelse(
label is not None, true_fn_1, false_fn_1,
(__return_0, __return_value_0, label, x_v),
(__return_0, __return_value_0))

def true_fn_2(__return_0, __return_value_0, x_v):
__return_1 = _jst.create_bool_as_type(
_jst.convert_logical_not(__return_0), True)
__return_value_0 = x_v
return __return_value_0

def false_fn_2(__return_value_0):
def false_fn_1(__return_value_0, label, x_v):
__return_1 = _jst.create_bool_as_type(label is not None, True)
__return_value_0 = x_v
return __return_value_0

__return_value_0 = _jst.convert_ifelse(
_jst.convert_logical_not(__return_0), true_fn_2, false_fn_2,
(__return_0, __return_value_0, x_v), (__return_value_0, ))
__return_value_0 = _jst.convert_ifelse(label is not None, true_fn_1,
false_fn_1,
(__return_value_0, label, x_v),
(__return_value_0, label, x_v))
return __return_value_0


Expand All @@ -123,45 +111,33 @@ def dyfunc_with_if_else(x_v, label=None):
name='__return_value_init_1')
__return_value_1 = __return_value_init_1

def true_fn_3(x_v):
def true_fn_2(x_v):
x_v = x_v - 1
return x_v

def false_fn_3(x_v):
def false_fn_2(x_v):
x_v = x_v + 1
return x_v

x_v = _jst.convert_ifelse(
fluid.layers.mean(x_v)[0] > 5, true_fn_3, false_fn_3, (x_v, ),
fluid.layers.mean(x_v)[0] > 5, true_fn_2, false_fn_2, (x_v, ),
(x_v, ))
__return_2 = _jst.create_bool_as_type(label is not None, False)

def true_fn_4(__return_2, __return_value_1, label, x_v):
def true_fn_3(__return_value_1, label, x_v):
loss = fluid.layers.cross_entropy(x_v, label)
__return_2 = _jst.create_bool_as_type(label is not None, True)
__return_value_1 = loss
return __return_2, __return_value_1

def false_fn_4(__return_2, __return_value_1):
return __return_2, __return_value_1

__return_2, __return_value_1 = _jst.convert_ifelse(
label is not None, true_fn_4, false_fn_4,
(__return_2, __return_value_1, label, x_v),
(__return_2, __return_value_1))

def true_fn_5(__return_2, __return_value_1, x_v):
__return_3 = _jst.create_bool_as_type(
_jst.convert_logical_not(__return_2), True)
__return_value_1 = x_v
return __return_value_1

def false_fn_5(__return_value_1):
def false_fn_3(__return_value_1, label, x_v):
__return_3 = _jst.create_bool_as_type(label is not None, True)
__return_value_1 = x_v
return __return_value_1

__return_value_1 = _jst.convert_ifelse(
_jst.convert_logical_not(__return_2), true_fn_5, false_fn_5,
(__return_2, __return_value_1, x_v), (__return_value_1, ))
__return_value_1 = _jst.convert_ifelse(label is not None, true_fn_3,
false_fn_3,
(__return_value_1, label, x_v),
(__return_value_1, label, x_v))
return __return_value_1


Expand Down Expand Up @@ -358,6 +334,21 @@ def test_raise_error(self):
net.foo.train()


class TestIfElseEarlyReturn(unittest.TestCase):

def test_ifelse_early_return1(self):
answer = np.zeros([2, 2]) + 1
static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return1)
out = static_func()
self.assertTrue(np.allclose(answer, out.numpy()))

def test_ifelse_early_return2(self):
answer = np.zeros([2, 2]) + 3
static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return2)
out = static_func()
self.assertTrue(np.allclose(answer, out.numpy()))


class TestRemoveCommentInDy2St(unittest.TestCase):

def func_with_comment(self):
Expand Down

0 comments on commit 1950a36

Please sign in to comment.