Skip to content

Commit

Permalink
[warning] added warning message in cond block when one branch returns…
Browse files Browse the repository at this point in the history
… variable and another returns None (#46031)

* [cherry-pick] Allow manaully set py_reader name in standalone executor (#45898) (#45931)

* Allow manaully set py_reader name in standalone executor
  • Loading branch information
feifei-111 authored Sep 19, 2022
1 parent 97cdc7c commit 1a8c969
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 5 deletions.
54 changes: 49 additions & 5 deletions python/paddle/fluid/layers/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2653,15 +2653,22 @@ def false_func():

# Merge ture and false output if they are not None
if return_names is None:
is_dy2staic = False
return_names = ["no name"] * len(to_sequence(true_output))
else:
"""
dy2static will set the return_names and expand the return values to UndefinedVar.
"""
is_dy2staic = True

# TODO: expand_undefined_var will replace None to Undefinedvar(), to fix cases like:
# a = None
# if condition:
# a = 1
# Because we can not use variable to express 'None'
true_output, false_output = expand_undefined_var(
true_output, false_output, return_names)
true_output, false_output = change_none_to_undefinedvar(
true_output, false_output)

if len(to_sequence(true_output)) != len(to_sequence(false_output)):
raise ValueError(
"true fn returns {} vars, but false fn returns {} vars, which is not equals"
Expand All @@ -2677,6 +2684,28 @@ def false_func():
"Incompatible return values of `{}` in true_fn and false_fn in cond: {}"
.format(return_name, e))

def check_ret_none(seq_true, seq_false, seq_names):
length = len(seq_true)
for i in range(length):
f_true = flatten(seq_true[i])
f_false = flatten(seq_false[i])
for idx in range(len(f_true)):
if f_true[idx] is None and f_false[idx] is not None or f_false[
idx] is None and f_true[idx] is not None:
warnings.warn(
"In cond : Var '{}' or part of it is set differently in ifelse branchs, "
"<{}, {}> in true branch and <{}, {}> in false branch. Set var to "
"'None' in ifelse block might lead to error.".format(
seq_names[i], type(f_true[idx]), f_true[idx],
type(f_false[idx]), f_false[idx]))

check_ret_none(to_sequence(true_output), to_sequence(false_output),
to_sequence(return_names))

if is_dy2staic:
true_output, false_output = change_none_to_undefinedvar(
true_output, false_output)

mask = cast(pred, dtype='int32')
merge_func = lambda name, false_var, true_var: select_input_with_buildin_type(
[false_var, true_var], mask, name)
Expand Down Expand Up @@ -2716,16 +2745,31 @@ def pack_undefined_var_as(seq):
return pack_sequence_as(seq,
[UndefinedVar("padding") for i in flatten(seq)])

def map_fn(n1, n2, name):
def map_fn(n1, n2, name, order):
if not name.startswith(RETURN_VALUE_PREFIX) and (isinstance(
n1, UndefinedVar) or n1 is None):
if n1 is None and n2 is not None:
if order == 0:
warnings.warn(
"In cond : Var '{}' or part of it is set differently in ifelse branchs, "
"<{}, {}> in true branch and <{}, {}> in false branch. Set var to "
"'None' in ifelse block might lead to error.".format(
name, type(n1), n1, type(n2), n2))
else:
warnings.warn(
"In cond : Var '{}' or part of it is set differently in ifelse branchs, "
"<{}, {}> in true branch and <{}, {}> in false branch. Set var to "
"'None' in ifelse block might lead to error.".format(
name, type(n2), n2, type(n1), n1))
return pack_undefined_var_as(n2)
return n1

nest1_out = list(
map(map_fn, to_sequence(nest1), to_sequence(nest2), to_sequence(names)))
map(map_fn, to_sequence(nest1), to_sequence(nest2), to_sequence(names),
[0 for i in to_sequence(names)]))
nest2_out = list(
map(map_fn, to_sequence(nest2), to_sequence(nest1), to_sequence(names)))
map(map_fn, to_sequence(nest2), to_sequence(nest1), to_sequence(names),
[1 for i in to_sequence(names)]))
if not is_sequence(nest1): nest1_out = nest1_out[0]
if not is_sequence(nest2): nest2_out = nest2_out[0]
return nest1_out, nest2_out
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import paddle
import warnings
from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
from paddle.fluid.layers.control_flow import cond


@paddle.jit.to_static
def fun1():
a = paddle.to_tensor(1)
b = paddle.to_tensor(2)
if a > b:
b = paddle.to_tensor(3)
else:
b = None


def true_fn():
return [paddle.to_tensor(1), [paddle.to_tensor(2), paddle.to_tensor(3)]]


def false_fn():
return [paddle.to_tensor(3), [None, paddle.to_tensor(4)]]


class TestReturnNoneInIfelse(unittest.TestCase):

def test_dy2static_warning(self):
paddle.disable_static()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
fun1()
flag = False
for warn in w:
if (
issubclass(warn.category, UserWarning)
) and "Set var to 'None' in ifelse block might lead to error." in str(
warn.message):
flag = True
break
self.assertTrue(flag)

def test_cond_warning(self):
paddle.enable_static()
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
a = paddle.to_tensor(1)
b = paddle.to_tensor(2)
cond(a < b, true_fn, false_fn, return_names=['ret1', 'ret2'])
flag = False
for warn in w:
if (
issubclass(warn.category, UserWarning)
) and "Set var to 'None' in ifelse block might lead to error." in str(
warn.message):
flag = True
break
self.assertTrue(flag)


if __name__ == '__main__':
unittest.main()

0 comments on commit 1a8c969

Please sign in to comment.