Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[Compat][3.11] gen KW_NAMES when call breakgraph, enable test_constant_graph.py #375

Merged
merged 3 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@

@dataclass
class Stop:
state: bool
state: str


@Singleton
Expand Down Expand Up @@ -1225,7 +1225,6 @@ def CALL(self, instr: Instruction):
is_method = not isinstance(self.stack.peek[instr.arg + 2], NullVariable)
total_args = instr.arg + int(is_method)
kwnames = self._call_shape if self._call_shape is not None else []
self._call_shape = None
n_kwargs = len(kwnames)
n_positional_args = total_args - n_kwargs
kwargs_list = self.stack.pop_n(n_kwargs)
Expand All @@ -1236,6 +1235,7 @@ def CALL(self, instr: Instruction):
# pop the NULL variable
self.stack.pop()
self.stack.push(fn(*args, **kwargs))
self._call_shape = None

def CALL_FUNCTION(self, instr: Instruction):
assert isinstance(instr.arg, int)
Expand Down Expand Up @@ -1870,6 +1870,8 @@ def _break_graph_in_call(
var_loader.load(stack_arg)

# gen call resume fn opcode
# NOTE(SigureMo): In Python 3.11,we need generate KW_NAMES if the call shape is not None.
self._graph.pycode_gen.gen_kw_names(self._call_shape)
self._graph.pycode_gen.add_pure_instructions([instr])
self.stack.pop_n(pop_n)
stack_size = len(self.stack) + push_n
Expand Down
10 changes: 10 additions & 0 deletions sot/opcode_translator/executor/pycode_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,16 @@ def gen_call_method(self, argc=0):
else:
self._add_instr("CALL_METHOD", arg=argc, argval=argc)

def gen_kw_names(self, kw_names: tuple[str, ...] | None):
if kw_names is None:
return
if sys.version_info < (3, 11):
raise InnerError("gen_kw_names is not supported before python3.11")
if kw_names not in self._code_options["co_consts"]:
self._code_options["co_consts"].append(kw_names)
idx = self._code_options["co_consts"].index(kw_names)
self._add_instr("KW_NAMES", arg=idx, argval=kw_names)

def gen_pop_top(self):
self._add_instr("POP_TOP")

Expand Down
1 change: 0 additions & 1 deletion tests/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ py311_skiped_tests=(
./test_15_slice.py
./test_19_closure.py
./test_21_global.py
./test_constant_graph.py
./test_enumerate.py
./test_guard_user_defined_fn.py
./test_inplace_api.py
Expand Down
2 changes: 1 addition & 1 deletion tests/test_constant_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def func_2(format_str, tensor):
return str, tensor


class TestExecutor(TestCaseBase):
class TestConstantGraph(TestCaseBase):
def test_case_1(self):
x = "{xx} is xx"
tensor = paddle.to_tensor(1)
Expand Down