Skip to content

Commit

Permalink
[ Dy2Static ] Fix bugs when select inputs meeting different shape or …
Browse files Browse the repository at this point in the history
…undefined-var (#45916)

* fix select_input with different shape errors:
1. select_input_with_buildin_type directly return non-undefinedvar branch when meeting undefined var
2. the output shape of select_input is inferred from inputs.

* reverse the logic in select_input
  • Loading branch information
2742195759 authored Sep 14, 2022
1 parent 6833ecf commit b85c9b5
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 42 deletions.
15 changes: 0 additions & 15 deletions python/paddle/fluid/dygraph/dygraph_to_static/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,21 +145,6 @@ def data_layer_not_check(name, shape, dtype='float32', lod_level=0):
need_check_feed=False)


def create_undefined_var_like(variable):
""" create a undefined var with the same shape and dtype like varaible.
"""
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
var = data_layer_not_check(unique_name.generate("undefined_var"),
variable.shape, variable.dtype)
var.stop_gradient = False
helper = LayerHelper('create_undefined_var_like', **locals())
saved_block_ids = helper.main_program.current_block_idx
helper.main_program.current_block_idx = 0
assign(RETURN_NO_VALUE_MAGIC_NUM, var)
helper.main_program.current_block_idx = saved_block_ids
return var


def create_undefined_variable():
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
var = data_layer_not_check(unique_name.generate("undefined_var"), [1],
Expand Down
77 changes: 52 additions & 25 deletions python/paddle/fluid/layers/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,25 @@ def select_output(input, outputs, mask):
return outputs


def _select_input_infer_shape(first_shape, second_shape):
"""
This function infer the output shape by following algorithm:
1. if the dims is different, raise a error.
2. compare axis one by one:
if a == b: we set axis to a
if a != b: we set axis to -1
for compatibility,non declarative mode, we just return second_shape.
"""
if len(first_shape) != len(second_shape):
warnings.warn(
f"the input shapes of select_input should have the same rank, but get {first_shape}, {second_shape}"
)
return second_shape
out_shape = list(
map(lambda a, b: a if a == b else -1, first_shape, second_shape))
return out_shape


def select_input(inputs, mask):
"""
**select_input**
Expand All @@ -89,13 +108,15 @@ def select_input(inputs, mask):
check_type(inputs, 'inputs', (list, tuple), 'select_input')
check_variable_and_dtype(mask, 'mask', ['int32'], 'select_input')

input_dtype = inputs[1].dtype
input_shape = inputs[1].shape
input_type = inputs[1].type
# Select input should expand the shape. If it is - 1 and valid number, use - 1 first. If the dim is different, an error will be reported directly
#assert inputs[0].dtype == inputs[1].dtype, f"Expect the inputs should have the same dtype, but get {inputs[0].dtype} and {inputs[1].dtype}"
output_shape = _select_input_infer_shape(inputs[0].shape, inputs[1].shape)
output_dtype = inputs[1].dtype
output_type = inputs[1].type

out = helper.create_variable(dtype=input_dtype,
shape=input_shape,
type=input_type)
out = helper.create_variable(dtype=output_dtype,
shape=output_shape,
type=output_type)
helper.append_op(type='select_input',
inputs={
'X': inputs,
Expand All @@ -105,9 +126,9 @@ def select_input(inputs, mask):
return out


def select_input_with_buildin_type(inputs, mask):
def select_input_with_buildin_type(inputs, mask, name):
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import to_static_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar, create_undefined_var_like
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar
false_var, true_var = inputs

if isinstance(false_var, UndefinedVar) and isinstance(
Expand All @@ -117,7 +138,11 @@ def select_input_with_buildin_type(inputs, mask):
return None

if isinstance(false_var, Variable) and isinstance(true_var, Variable):
return select_input(inputs, mask)
try:
return select_input(inputs, mask)
except Exception as e:
raise RuntimeError(
f"Exceptions throwed while doing select_input on {name}:\n{e}")

elif (isinstance(false_var, (support_ret_buildin_type))
and isinstance(false_var, type(true_var))):
Expand Down Expand Up @@ -148,24 +173,19 @@ def create_var_if_not_undefined_var(a):
if isinstance(a, UndefinedVar): return a
return to_static_variable(a)

def create_like_if_undefined_var(a, b):
if isinstance(a, UndefinedVar): return create_undefined_var_like(b)
return a

# TODO(xiongkun): add warning here.
true_var, false_var = create_var_if_not_undefined_var(
true_var), create_var_if_not_undefined_var(false_var)
inputs = [
create_like_if_undefined_var(false_var, true_var),
create_like_if_undefined_var(true_var, false_var)
]
true_var, false_var = to_static_variable(true_var), to_static_variable(
false_var)
inputs = [false_var, true_var]
else:
raise TypeError(
"Unsupported return type of true_fn and false_fn in cond: false_var "
"returned by fasle_fn is '{}' and true_var of true_fn is '{}'".
format(type(false_var), type(true_var)))

return select_input(inputs, mask)
try:
return select_input(inputs, mask)
except Exception as e:
raise RuntimeError(
f"Exceptions throwed while doing select_input on {name}:\n{e}")


def split_lod_tensor(input, mask, level=0):
Expand Down Expand Up @@ -2658,9 +2678,16 @@ def false_func():
.format(return_name, e))

mask = cast(pred, dtype='int32')
merge_func = lambda false_var, true_var: select_input_with_buildin_type(
[false_var, true_var], mask)
merged_output = map_structure(merge_func, false_output, true_output)
merge_func = lambda name, false_var, true_var: select_input_with_buildin_type(
[false_var, true_var], mask, name)

def merge_every_var_list(false_vars, true_vars, name):
return map_structure(partial(merge_func, name), false_vars, true_vars)

merged_output = list(
map(merge_every_var_list, to_sequence(false_output),
to_sequence(true_output), to_sequence(return_names)))
merged_output = pack_sequence_as(false_output, flatten(merged_output))
return merged_output


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def nested_if_else(x_v):
if paddle.mean(y).numpy()[0] < batch_size:
y = fluid.layers.abs(y)
else:
tmp = fluid.layers.fill_constant([feat_size],
tmp = fluid.layers.fill_constant(y.shape,
dtype='float32',
value=-1)
y = y - tmp
Expand Down Expand Up @@ -273,7 +273,7 @@ def forward(self, input):
[hidden_dim], dtype='float32', value=9)
y = fluid.layers.abs(y)
else:
tmp = fluid.layers.fill_constant([5],
tmp = fluid.layers.fill_constant(y.shape,
dtype='float32',
value=-1)
y = y - tmp
Expand Down

0 comments on commit b85c9b5

Please sign in to comment.