Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SOT] Always generate false_fn when POP_JUMP_* breakgraph #62424

Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 16 additions & 18 deletions python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1792,8 +1792,6 @@ def _break_graph_when_if(self, result: TensorVariable, instr: Instruction):

# 2. create true_fn and false_fn
def create_if_branch_fn(start_idx, input_var_names):
if self._instructions[start_idx].opname == "RETURN_VALUE":
return None
pycode_gen = PyCodeGen(self._frame)
origin_instrs = get_instructions(pycode_gen._origin_code)
pycode_gen.set_function_inputs(
Expand Down Expand Up @@ -1849,8 +1847,7 @@ def create_if_branch_fn(start_idx, input_var_names):
var_loader.load(result)
if_code = self._graph.pycode_gen.add_instr(instr.opname)

assert true_fn is not None

# Generate true_fn and false_fn
self._graph.pycode_gen.gen_load_object(
true_fn, true_fn.__code__.co_name
)
Expand All @@ -1865,22 +1862,20 @@ def create_if_branch_fn(start_idx, input_var_names):
)
self._graph.pycode_gen.gen_return()

if false_fn is not None:
false_start_code = self._graph.pycode_gen.gen_load_object(
false_fn, false_fn.__code__.co_name
)
for stack_arg in list(self.stack)[:-1]:
var_loader.load(stack_arg)
for name in false_fn_input_var_names:
var_loader.load(self.get_var(name, allow_undefined=True))
false_start_code = self._graph.pycode_gen.gen_load_object(
false_fn, false_fn.__code__.co_name
)
for stack_arg in list(self.stack)[:-1]:
var_loader.load(stack_arg)
for name in false_fn_input_var_names:
var_loader.load(self.get_var(name, allow_undefined=True))

self._graph.pycode_gen.gen_call_function(
argc=false_fn.__code__.co_argcount,
)
self._graph.pycode_gen.gen_return()
else:
false_start_code = self._graph.pycode_gen.gen_return()
self._graph.pycode_gen.gen_call_function(
argc=false_fn.__code__.co_argcount,
)
self._graph.pycode_gen.gen_return()

# Replace the jump instruction with the new if structure
if_code.jump_to = false_start_code

self.new_code = self._graph.pycode_gen.gen_pycode()
Expand Down Expand Up @@ -2279,3 +2274,6 @@ def create_inline_call_fn():

for name, var in zip(output_var_names[:-1], ret[slice_variable]):
self.set_var(name, var)

for name, var in zip(output_var_names[:-1], ret[slice_variable]):
self.set_var(name, var)