Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
Fix local var created in loop but loop fn is in resume fn (#401)
Browse files Browse the repository at this point in the history
  • Loading branch information
feifei-111 authored Sep 26, 2023
1 parent b55be31 commit e227656
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 2 deletions.
6 changes: 6 additions & 0 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2071,8 +2071,10 @@ def _break_graph_in_for_loop(
self._graph.pycode_gen.gen_store(ret_names[idx], self._code)

# 3. setup vars which is created in loop
undefined_names = set()
for name in loop_body_inputs[:-1]:
if not self.has_var(name, all_used_vars[name]):
undefined_names.add(name)
self._graph.pycode_gen.gen_load_const(SotUndefinedVar())
self._graph.pycode_gen.gen_store(name, self._code)

Expand Down Expand Up @@ -2137,6 +2139,10 @@ def _break_graph_in_for_loop(
for stack_arg in self.stack:
var_loader.load(stack_arg)
for name in fn_inputs:
if not self.has_var(name) and name not in undefined_names:
undefined_names.add(name)
self._graph.pycode_gen.gen_load_const(SotUndefinedVar())
self._graph.pycode_gen.gen_store(name, self._code)
self._graph.pycode_gen.gen_load(name)

self._graph.pycode_gen.gen_call_function(
Expand Down
3 changes: 1 addition & 2 deletions sot/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ def callback(frame):
def impl_sot(*args: P.args, **kwargs: P.kwargs) -> R:
assert hasattr(
fn, "__code__"
), "Target function has not code for simulating."

), "Target function doesn't have code for simulating."
StepInfoManager().sot_step()
GraphLogger().clear()
paddle.framework.core.set_eval_frame(callback)
Expand Down
31 changes: 31 additions & 0 deletions tests/test_12_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,37 @@ def test_run(self):
self.assert_nest_match(InstructionTranslatorCache().translate_count, 1)


# after_loop_fn need zzz, and zzz is created as UndefinedVar when generating loop body
# do not set zzz as UndefinedVar again
def undefined_var_case_0():
for i in [1, 2]:
sot.psdb.breakgraph()
zzz = i

zzz = zzz + 1
return zzz


# after_loop_fn need create zzz as UndefinedVar
def undefined_var_case_1():
for i in [1, 2]:
sot.psdb.breakgraph()
aaa = i

for i in [1, 3]:
zzz = i
zzz = zzz + 1
return zzz


class TestUndefinedVarInRiskyCodes(TestCaseBase):
def test_undefined_var_case_0(self):
self.assert_results(undefined_var_case_0)

def test_undefined_var_case_1(self):
self.assert_results(undefined_var_case_1)


if __name__ == "__main__":
with strict_mode_guard(0):
unittest.main()

0 comments on commit e227656

Please sign in to comment.