Skip to content

Commit

Permalink
[Dy2Stat] Fix undefined var used in For (#32153)
Browse files Browse the repository at this point in the history
* fix undefind var in For

* fix code style
  • Loading branch information
Aurelius84 authored Apr 9, 2021
1 parent 95122eb commit 4636d13
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,16 @@ def is_required_ctx(ctxs, required_ctx):
return new_name_ids

def _is_call_func_name_node(self, node):
white_func_names = set(['append', 'extend'])
if len(self.ancestor_nodes) > 1:
assert self.ancestor_nodes[-1] == node
parent_node = self.ancestor_nodes[-2]
if isinstance(parent_node, gast.Call) and parent_node.func == node:
return True
# e.g: var_list.append(elem), var_list is also a name_id.
should_skip = isinstance(
node, gast.Attribute) and node.attr in white_func_names
if not should_skip:
return True
return False

def _update_name_ids(self, new_name_ids):
Expand Down Expand Up @@ -398,10 +403,13 @@ def _modified_vars(child_dict, parent_dict):
])

def _vars_loaded_before_store(ids_dict):
"""
gast.Param is also a kind of `load` semantic.
"""
new_dict = defaultdict(list)
for k, ctxs in six.iteritems(ids_dict):
for ctx in ctxs:
if isinstance(ctx, gast.Load):
if isinstance(ctx, (gast.Load, gast.Param)):
new_dict[k].append(ctx)
elif isinstance(ctx, gast.Store):
break
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,5 +342,28 @@ def init_net(self):
self.Net = DiffModeNet2


class TestNewVarCreateInOneBranch(unittest.TestCase):
def test_var_used_in_another_for(self):
def case_func(training):
# targets and targets_list is dynamically defined by training
if training:
targets = [1, 2, 3]
targets_list = [targets]

num_step = 3
for i in range(num_step):
if i > 0:
rois, rosi_num = 1, 2
# targets is in loop_vars.
if training:
ros, rosi_num, targets = -1, -2, [-1, -2, -3]
targets_list.append(targets)

return rosi_num

self.assertEqual(paddle.jit.to_static(case_func)(False), 2)
self.assertEqual(paddle.jit.to_static(case_func)(True), -2)


if __name__ == '__main__':
unittest.main()

0 comments on commit 4636d13

Please sign in to comment.