Skip to content

Commit

Permalink
fix ci error - v2
Browse files Browse the repository at this point in the history
  • Loading branch information
2742195759 committed Jul 5, 2022
1 parent e2138f7 commit 15fb43a
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,17 @@ def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,

def new_true_fn():
set_args(init_args)
true_fn()
return get_args()
ret = true_fn()
# IfExpr will return a non-None return value, so we just return ret.
# We assume normal return has no return value.
if ret is None: return get_args()
else: return ret

def new_false_fn():
set_args(init_args)
false_fn()
return get_args()
ret = false_fn()
if ret is None: return get_args()
else: return ret

try:
cond_outs = control_flow.cond(pred, new_true_fn, new_false_fn, None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@
# solve it in dy2stat, we put float64 value with this magic number at Static
# graph as a place holder to indicate the returning placeholder means no value
# should return.
RETURN_NO_VALUE_MAGIC_NUM = 1.77113e+279

# Assign not support float64, use float32 value as magic number.
RETURN_NO_VALUE_MAGIC_NUM = 1.77113e+27
RETURN_NO_VALUE_VAR_NAME = "__no_value_return_var"


Expand Down
35 changes: 25 additions & 10 deletions python/paddle/fluid/layers/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ def select_input_with_buildin_type(inputs, mask):
support_ret_buildin_type = (bool, float, six.integer_types)
false_var, true_var = inputs

if isinstance(false_var, UndefinedVar) and isinstance(
true_var, UndefinedVar):
""" None -> UndefinedVar, so the real value is a [None, UndefinedVar] or [None, None], we just return None.
"""
return None

if isinstance(false_var, Variable) and isinstance(true_var, Variable):
return select_input(inputs, mask)

Expand Down Expand Up @@ -2562,16 +2568,21 @@ def false_func():
"true_fn returns non-None while false_fn returns None")

# Merge ture and false output if they are not None
true_output, false_output = expand_undefined_var(true_output, false_output)
true_output, false_output = change_none_to_undefinedvar(
true_output, false_output)
if return_name_ids is None:
return_name_ids = ["no name"] * len(to_sequence(true_output))
else:
"""
dy2static will set the return_name_ids and expand the return values to UndefinedVar.
"""
true_output, false_output = expand_undefined_var(
true_output, false_output, return_name_ids)
true_output, false_output = change_none_to_undefinedvar(
true_output, false_output)
if len(to_sequence(true_output)) != len(to_sequence(false_output)):
raise ValueError(
"true fn returns {} vars, but false fn returns {} vars, which is not equals"
.format(len(to_sequence(true_output)),
len(to_sequence(false_output))))
if return_name_ids is None:
return_name_ids = ["no name"] * len(to_sequence(true_output))
for true_out, false_out, return_name in zip(to_sequence(true_output),
to_sequence(false_output),
to_sequence(return_name_ids)):
Expand Down Expand Up @@ -2601,20 +2612,24 @@ def map_fn(x):
return nest1_out, nest2_out


def expand_undefined_var(nest1, nest2):
def expand_undefined_var(nest1, nest2, names):
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_VALUE_PREFIX

def pack_undefined_var_as(seq):
return pack_sequence_as(seq,
[UndefinedVar("padding") for i in flatten(seq)])

def map_fn(n1, n2):
if isinstance(n1, UndefinedVar) or n1 is None:
def map_fn(n1, n2, name):
if not name.startswith(RETURN_VALUE_PREFIX) and (isinstance(
n1, UndefinedVar) or n1 is None):
return pack_undefined_var_as(n2)
return n1

nest1_out = list(map(map_fn, to_sequence(nest1), to_sequence(nest2)))
nest2_out = list(map(map_fn, to_sequence(nest2), to_sequence(nest1)))
nest1_out = list(
map(map_fn, to_sequence(nest1), to_sequence(nest2), to_sequence(names)))
nest2_out = list(
map(map_fn, to_sequence(nest2), to_sequence(nest1), to_sequence(names)))
if not is_sequence(nest1): nest1_out = nest1_out[0]
if not is_sequence(nest2): nest2_out = nest2_out[0]
return nest1_out, nest2_out
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import print_function

import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import ParamAttr
from paddle.fluid import layers
Expand Down Expand Up @@ -360,7 +361,7 @@ def beam_search(self, inputs):
predicted_ids = []
parent_ids = []

for step_idx in range(self.beam_max_step_num):
for step_idx in range(paddle.to_tensor(self.beam_max_step_num)):
if fluid.layers.reduce_sum(1 - beam_finished).numpy()[0] == 0:
break
step_input = self._merge_batch_beams(step_input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,29 @@
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static.utils import Dygraph2StaticException

SEED = 2020
np.random.seed(SEED)


class TestDy2staticException(unittest.TestCase):

def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = None
self.error = "Your if/else have different number of return value."

def test_error(self):
if self.dyfunc:
with self.assertRaisesRegex(Dygraph2StaticException, self.error):
ProgramTranslator().enable(True)
self.assertTrue(declarative(self.dyfunc)(self.x))
paddle.fluid.dygraph.base._in_declarative_mode_ = False
ProgramTranslator().enable(False)


def test_continue_in_for(x):
x = fluid.dygraph.to_variable(x)
for i in range(10):
Expand Down Expand Up @@ -265,10 +283,12 @@ def init_dygraph_func(self):
self.dygraph_func = while_loop_class_var


class TestOptimBreakInFor(TestContinueInWhile):
class TestOptimBreakInFor(TestDy2staticException):

def init_dygraph_func(self):
self.dygraph_func = test_optim_break_in_for
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = test_optim_break_in_for
self.error = "python while pred change from bool to variable."


class TestOptimBreakInWhile(TestContinueInWhile):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from paddle.fluid.dygraph import Embedding, Layer, LayerNorm, Linear, to_variable
from paddle.fluid.dygraph.jit import dygraph_to_static_func
from paddle.fluid.layers.utils import map_structure
from paddle.fluid.layers.tensor import range as pd_range


def position_encoding_init(n_position, d_pos_vec):
Expand Down Expand Up @@ -633,7 +634,7 @@ def gather(input, indices, batch_pos):
value=0),
} for i in range(self.n_layer)]

for i in range(max_len):
for i in pd_range(0, max_len, 1, dtype="int32"):
trg_pos = layers.fill_constant(shape=trg_word.shape,
dtype="int64",
value=i)
Expand Down
25 changes: 23 additions & 2 deletions python/paddle/fluid/tests/unittests/test_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Program, program_guard
from simple_nets import simple_fc_net_with_inputs, batchnorm_fc_with_inputs
import paddle

np.random.seed(123)

Expand All @@ -41,6 +42,8 @@ def test_return_single_var(self):
return -1
"""

paddle.enable_static()

def true_func():
return layers.fill_constant(shape=[2, 3], dtype='int32', value=2)

Expand Down Expand Up @@ -73,6 +76,8 @@ def test_return_var_tuple(self):
return 3, 2
"""

paddle.enable_static()

def true_func():
return layers.fill_constant(shape=[1, 2], dtype='int32',
value=1), layers.fill_constant(
Expand Down Expand Up @@ -114,6 +119,8 @@ def test_pass_and_modify_var(self):
a = a - (i - 1)
"""

paddle.enable_static()

def true_func(a, i):
a = a * (i + 1)
return a
Expand Down Expand Up @@ -152,6 +159,8 @@ def test_return_none(self):
pass
"""

paddle.enable_static()

def true_func():
pass

Expand Down Expand Up @@ -181,6 +190,8 @@ def test_wrong_structure_exception(self):
test returning different number of tensors cannot merge into output
"""

paddle.enable_static()

def func_return_none():
return None

Expand Down Expand Up @@ -223,10 +234,11 @@ def func_return_two_tensors():
out = layers.cond(pred, func_return_one_tensor,
func_return_two_tensors)
self.assertTrue(
"Incompatible return values of true_fn and false_fn in cond" in
str(e.exception))
"true fn returns 1 vars, but false fn returns 2 vars, which is not equals"
in str(e.exception))

def test_extremely_simple_net_with_op_in_condition(self):
paddle.enable_static()
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
Expand Down Expand Up @@ -272,6 +284,8 @@ def test_cond_inside_cond(self):
return a / a
"""

paddle.enable_static()

def less_than_branch(i, a):
return layers.cond(i >= 3.0, lambda: layers.elementwise_add(a, a),
lambda: layers.elementwise_sub(a, a))
Expand Down Expand Up @@ -308,6 +322,7 @@ def greater_equal_branch(i, a):
self.assertEqual(ret[1][0], expected_a_grad)

def test_cond_op_in_condition(self):
paddle.enable_static()
main_program = fluid.Program()
startup_program = fluid.Program()

Expand Down Expand Up @@ -344,6 +359,7 @@ def backward_value_helper(self, cond_func, use_cuda, use_parallel_exe):
"""
Helper function that compares calculated backward value is close to dy/dx
"""
paddle.enable_static()
main_program = Program()
main_program.random_seed = 123
startup_program = Program()
Expand Down Expand Up @@ -474,6 +490,8 @@ def add_optimizer_helper(self, cond_func, use_cuda, use_parallel_exe):

def test_cond_backward(self):

paddle.enable_static()

def cond_func(i, img, label):
predicate = ((i % 2) == 0)
return layers.cond(
Expand All @@ -494,6 +512,7 @@ def cond_func(i, img, label):
use_parallel_exe)

def test_half_nested_cond_backward(self):
paddle.enable_static()

def branch(i, img, label):
return layers.cond(
Expand Down Expand Up @@ -530,6 +549,7 @@ def cond_func_simple_net_at_false(i, img, label):
use_parallel_exe)

def test_nested_cond_backward(self):
paddle.enable_static()

def branch(i, img, label, mod_two):
if mod_two:
Expand Down Expand Up @@ -560,6 +580,7 @@ def cond_func(i, img, label):
class TestCondWithError(unittest.TestCase):

def test_input_type_error(self):
paddle.enable_static()
main_program = framework.Program()
startup_program = framework.Program()
with framework.program_guard(main_program, startup_program):
Expand Down

1 comment on commit 15fb43a

@paddle-bot-old
Copy link

@paddle-bot-old paddle-bot-old bot commented on 15fb43a Jul 5, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕵️ CI failures summary

🔍 PR: #43967 Commit ID: 15fb43a contains failed CI.

🔹 Failed: PR-CI-Static-Check

Unknown Failed
Unknown Failed

🔹 Failed: PR-CI-Model-benchmark

Unknown Failed
Unknown Failed

Please sign in to comment.