diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 45aeeec0b..ca60f7563 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -2071,8 +2071,10 @@ def _break_graph_in_for_loop( self._graph.pycode_gen.gen_store(ret_names[idx], self._code) # 3. setup vars which is created in loop + undefined_names = set() for name in loop_body_inputs[:-1]: if not self.has_var(name, all_used_vars[name]): + undefined_names.add(name) self._graph.pycode_gen.gen_load_const(SotUndefinedVar()) self._graph.pycode_gen.gen_store(name, self._code) @@ -2137,6 +2139,10 @@ def _break_graph_in_for_loop( for stack_arg in self.stack: var_loader.load(stack_arg) for name in fn_inputs: + if not self.has_var(name) and name not in undefined_names: + undefined_names.add(name) + self._graph.pycode_gen.gen_load_const(SotUndefinedVar()) + self._graph.pycode_gen.gen_store(name, self._code) self._graph.pycode_gen.gen_load(name) self._graph.pycode_gen.gen_call_function( diff --git a/sot/translate.py b/sot/translate.py index 83719aaa3..7a342e659 100644 --- a/sot/translate.py +++ b/sot/translate.py @@ -81,8 +81,7 @@ def callback(frame): def impl_sot(*args: P.args, **kwargs: P.kwargs) -> R: assert hasattr( fn, "__code__" - ), "Target function has not code for simulating." - + ), "Target function doesn't have code for simulating." StepInfoManager().sot_step() GraphLogger().clear() paddle.framework.core.set_eval_frame(callback) diff --git a/tests/test_12_for_loop.py b/tests/test_12_for_loop.py index 4a90adcb7..710af8476 100644 --- a/tests/test_12_for_loop.py +++ b/tests/test_12_for_loop.py @@ -248,6 +248,37 @@ def test_run(self): self.assert_nest_match(InstructionTranslatorCache().translate_count, 1) +# after_loop_fn need zzz, and zzz is created as UndefinedVar when generating loop body +# do not set zzz as UndefinedVar again +def undefined_var_case_0(): + for i in [1, 2]: + sot.psdb.breakgraph() + zzz = i + + zzz = zzz + 1 + return zzz + + +# after_loop_fn need create zzz as UndefinedVar +def undefined_var_case_1(): + for i in [1, 2]: + sot.psdb.breakgraph() + aaa = i + + for i in [1, 3]: + zzz = i + zzz = zzz + 1 + return zzz + + +class TestUndefinedVarInRiskyCodes(TestCaseBase): + def test_undefined_var_case_0(self): + self.assert_results(undefined_var_case_0) + + def test_undefined_var_case_1(self): + self.assert_results(undefined_var_case_1) + + if __name__ == "__main__": with strict_mode_guard(0): unittest.main()