Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2Stat]Refactor convert_shape transformer logic #43846

Merged
merged 2 commits into from
Jun 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 16 additions & 74 deletions python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,88 +338,30 @@ def convert_zip(*args):
return zip(*args)


def convert_var_shape(x, idx=None, in_control_flow=False):
def convert_shape(x):
"""
A function representation of the shape of variable.
"""

def has_negative(list_shape, idx=None):
if idx is not None:
return list_shape[idx] < 0

num_negative = sum([1 if i < 0 else 0 for i in list_shape])
return num_negative > 0

# When `x` is Variable, call nn.shape(x) in following cases:
# (1) The shape of `x` is used in control flow condition.
# ```
# if x.shape[0] == 1:
# y = XX
# ```
# (2) The dim to be used is negative
# ```
# # Assume x.shape=[3, -1] in static mode
# y = paddle.reshape(x, shape=[1, x.shape[1]])
# ```
if isinstance(x, Variable) and has_negative(x.shape, idx):
return nn.shape(x) if idx is None else nn.shape(x)[idx]
else:
return list(x.shape) if idx is None else x.shape[idx]
def has_negative(list_shape):
return any([x < 0 for x in list_shape])

# When `x` is Variable:
# (1) if x.shape contains -1, such as [2, -1, 64], returns [2, var, 64],
# where var = paddle.shape(x)[1]

# (2) if x.shape does not contains -1, return lsit(x.shape) directly

def convert_var_shape_simple(x):
"""
A function representation of the shape of variable.
"""
if isinstance(x, Variable):
return nn.shape(x)
values = list(x.shape)
if has_negative(values):
shape_tensor = nn.shape(x)
for i, v in enumerate(values):
if v is None or v < 0:
values[i] = shape_tensor[i]
return values
else:
# Use list() to make returned type consistant with dygraph
return list(x.shape)


def eval_if_exist_else_none(name, global_symbol_table):
"""
Args:
name([str]): Expression passed into `eval`.
local_symbol_table(dict): Specified from `globals()`. DO NOT use `locals()`,
because all STATIC_CONVERT_VAR_SHAPE_SUFFIX vars is
declared with keyword `global`.

Returns:
Return the variable if found in global_symbol_table else None.
"""
try:
return eval(name, global_symbol_table)
except:
return None


def choose_shape_attr_or_api(attr_shape, api_shape, idx=None):
"""
Input can be attribute `x.shape` or api `shape(x)`, this function
chooses which one to return to use in dy2stat.

Note: sometimes users write `x.shape[3]`, so attr_shape can be an integer.
"""
if api_shape is None:
return attr_shape if idx is None else attr_shape[idx]
if not isinstance(attr_shape, (list, tuple)):
# some variables like x.shape[0] is no longer a list or tuple
if isinstance(attr_shape, int) and attr_shape < 0:
return api_shape if idx is None else api_shape[idx]
return attr_shape if idx is None else attr_shape[idx]

def has_negative(list_shape, idx=None):
if idx is not None:
return list_shape[idx] < 0

num_negative = sum([1 if i < 0 else 0 for i in list_shape])
return num_negative > 0

if has_negative(attr_shape, idx):
return api_shape if idx is None else api_shape[idx]
return attr_shape if idx is None else attr_shape[idx]
return x.shape


def convert_shape_compare(left, *args):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,28 +63,6 @@ def visit_UnaryOp(self, node):
return new_node
return node

def visit_Compare(self, node):
self.generic_visit(node)
left_str = ast_to_source_code(node.left).strip()
if left_str.startswith("_jst.convert_var_shape"):
# check left and comparators are all converted var shape
compare_arg_strs = left_str
for i, comparator in enumerate(node.comparators):
comparator_str = ast_to_source_code(comparator).strip()
if not comparator_str.startswith("_jst.convert_var_shape"):
return node
op_str = cmpop_node_to_str(node.ops[i])
compare_arg_strs += (", '" + op_str + "', " + comparator_str)

# Now all left and comparators are converted shape
# Replace some comparsion operation because of difference between
# Python and Paddle
new_node_str = "_jst.convert_shape_compare({})".format(
compare_arg_strs)
new_node = gast.parse(new_node_str).body[0].value
return new_node
return node

def visit_BoolOp(self, node):
self.generic_visit(node)
if isinstance(node.op, gast.And):
Expand Down
Loading