diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py
index 0d832c3b5cf85..40a4c3ae62460 100644
--- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py
+++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py
@@ -1791,8 +1791,13 @@ def _break_graph_when_if(self, result: TensorVariable, instr: Instruction):
         stack_size_after_if = len(self.stack) - 1
 
         # 2. create true_fn and false_fn
-        def create_if_branch_fn(start_idx, input_var_names):
-            if self._instructions[start_idx].opname == "RETURN_VALUE":
+        def create_if_branch_fn(start_idx, input_var_names, is_pop_jump_branch):
+            # JUMP_IF_* may jump directly to a RETURN_VALUE; we should skip this case.
+            # We shouldn't skip the POP_JUMP_* case, because skipping it would make
+            # the stack size incorrect.
+            if (
+                self._instructions[start_idx].opname == "RETURN_VALUE"
+                and not is_pop_jump_branch
+            ):
                 return None
             pycode_gen = PyCodeGen(self._frame)
             origin_instrs = get_instructions(pycode_gen._origin_code)
@@ -1815,6 +1820,7 @@ def create_if_branch_fn(start_idx, input_var_names):
         true_fn = create_if_branch_fn(
             start_idx=true_fn_start_index,
             input_var_names=true_fn_input_var_names,
+            is_pop_jump_branch=False,
         )
 
         false_fn_read_names, _ = analysis_used_names(
@@ -1827,6 +1833,7 @@ def create_if_branch_fn(start_idx, input_var_names):
         false_fn = create_if_branch_fn(
             start_idx=false_fn_start_index,
             input_var_names=false_fn_input_var_names,
+            is_pop_jump_branch=instr.opname.startswith("POP_JUMP"),
         )
 
         # 4. setup vars which is created in loop as Undefind
@@ -1881,6 +1888,7 @@ def create_if_branch_fn(start_idx, input_var_names):
         else:
             false_start_code = self._graph.pycode_gen.gen_return()
 
+        # Replace the jump instruction with the new if structure
         if_code.jump_to = false_start_code
 
         self.new_code = self._graph.pycode_gen.gen_pycode()
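
For context, the reason the patch threads an `is_pop_jump_branch` flag through `create_if_branch_fn` is that the two jump families leave the stack in different states: `POP_JUMP_*` instructions always consume the condition, while `JUMP_IF_*_OR_POP` instructions keep it on the stack on the taken branch. Skipping branch-function generation for a `POP_JUMP_*` target would therefore miscount the stack, which is what the new guard avoids. The sketch below is not part of the patch; it assumes CPython 3.8-3.10 opcode names (3.11+ renames some of these instructions) and uses `dis.stack_effect` to make the difference visible.

```python
import dis

# Compare the stack effect of the two conditional-jump families.
# jump=True asks for the effect on the taken branch, jump=False
# for the fall-through branch.
for name in ("POP_JUMP_IF_FALSE", "JUMP_IF_FALSE_OR_POP"):
    op = dis.opmap[name]
    taken = dis.stack_effect(op, 0, jump=True)
    not_taken = dis.stack_effect(op, 0, jump=False)
    print(f"{name}: taken={taken:+d}, not taken={not_taken:+d}")

# Expected output (CPython 3.8-3.10):
#   POP_JUMP_IF_FALSE: taken=-1, not taken=-1    (condition always popped)
#   JUMP_IF_FALSE_OR_POP: taken=+0, not taken=-1 (condition kept when jumping)
```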