Skip to content

Commit

Permalink
only skip in non-POP_JUMP_*
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo committed Mar 5, 2024
1 parent e5ea35c commit a5ffa60
Showing 1 changed file with 26 additions and 13 deletions.
39 changes: 26 additions & 13 deletions python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1791,7 +1791,14 @@ def _break_graph_when_if(self, result: TensorVariable, instr: Instruction):
stack_size_after_if = len(self.stack) - 1

# 2. create true_fn and false_fn
def create_if_branch_fn(start_idx, input_var_names):
def create_if_branch_fn(start_idx, input_var_names, is_pop_jump_branch):
# JUMP_IF_* maybe jump to the RETURN_VALUE, we should skip this case
# We shouldn't skip POP_JUMP_* case, because it will cause the stack size to be incorrect
if (
self._instructions[start_idx].opname == "RETURN_VALUE"
and not is_pop_jump_branch
):
return None
pycode_gen = PyCodeGen(self._frame)
origin_instrs = get_instructions(pycode_gen._origin_code)
pycode_gen.set_function_inputs(
Expand All @@ -1813,6 +1820,7 @@ def create_if_branch_fn(start_idx, input_var_names):
true_fn = create_if_branch_fn(
start_idx=true_fn_start_index,
input_var_names=true_fn_input_var_names,
is_pop_jump_branch=False,
)

false_fn_read_names, _ = analysis_used_names(
Expand All @@ -1825,6 +1833,7 @@ def create_if_branch_fn(start_idx, input_var_names):
false_fn = create_if_branch_fn(
start_idx=false_fn_start_index,
input_var_names=false_fn_input_var_names,
is_pop_jump_branch=instr.opname.startswith("POP_JUMP"),
)

# 4. setup vars which is created in loop as Undefind
Expand All @@ -1847,7 +1856,8 @@ def create_if_branch_fn(start_idx, input_var_names):
var_loader.load(result)
if_code = self._graph.pycode_gen.add_instr(instr.opname)

# Generate true_fn and false_fn
assert true_fn is not None

self._graph.pycode_gen.gen_load_object(
true_fn, true_fn.__code__.co_name
)
Expand All @@ -1862,18 +1872,21 @@ def create_if_branch_fn(start_idx, input_var_names):
)
self._graph.pycode_gen.gen_return()

false_start_code = self._graph.pycode_gen.gen_load_object(
false_fn, false_fn.__code__.co_name
)
for stack_arg in list(self.stack)[:-1]:
var_loader.load(stack_arg)
for name in false_fn_input_var_names:
var_loader.load(self.get_var(name, allow_undefined=True))
if false_fn is not None:
false_start_code = self._graph.pycode_gen.gen_load_object(
false_fn, false_fn.__code__.co_name
)
for stack_arg in list(self.stack)[:-1]:
var_loader.load(stack_arg)
for name in false_fn_input_var_names:
var_loader.load(self.get_var(name, allow_undefined=True))

self._graph.pycode_gen.gen_call_function(
argc=false_fn.__code__.co_argcount,
)
self._graph.pycode_gen.gen_return()
self._graph.pycode_gen.gen_call_function(
argc=false_fn.__code__.co_argcount,
)
self._graph.pycode_gen.gen_return()
else:
false_start_code = self._graph.pycode_gen.gen_return()

# Replace the jump instruction with the new if structure
if_code.jump_to = false_start_code
Expand Down

0 comments on commit a5ffa60

Please sign in to comment.