Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
[Compat][3.11] gen KW_NAMES when call breakgraph, enable `test_cons…
Browse files Browse the repository at this point in the history
…tant_graph.py` (#375)
  • Loading branch information
SigureMo authored Sep 4, 2023
1 parent c165095 commit c64e674
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 4 deletions.
6 changes: 4 additions & 2 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@

@dataclass
class Stop:
state: bool
state: str


@Singleton
Expand Down Expand Up @@ -1225,7 +1225,6 @@ def CALL(self, instr: Instruction):
is_method = not isinstance(self.stack.peek[instr.arg + 2], NullVariable)
total_args = instr.arg + int(is_method)
kwnames = self._call_shape if self._call_shape is not None else []
self._call_shape = None
n_kwargs = len(kwnames)
n_positional_args = total_args - n_kwargs
kwargs_list = self.stack.pop_n(n_kwargs)
Expand All @@ -1236,6 +1235,7 @@ def CALL(self, instr: Instruction):
# pop the NULL variable
self.stack.pop()
self.stack.push(fn(*args, **kwargs))
self._call_shape = None

def CALL_FUNCTION(self, instr: Instruction):
assert isinstance(instr.arg, int)
Expand Down Expand Up @@ -1870,6 +1870,8 @@ def _break_graph_in_call(
var_loader.load(stack_arg)

# gen call resume fn opcode
# NOTE(SigureMo): In Python 3.11,we need generate KW_NAMES if the call shape is not None.
self._graph.pycode_gen.gen_kw_names(self._call_shape)
self._graph.pycode_gen.add_pure_instructions([instr])
self.stack.pop_n(pop_n)
stack_size = len(self.stack) + push_n
Expand Down
10 changes: 10 additions & 0 deletions sot/opcode_translator/executor/pycode_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,16 @@ def gen_call_method(self, argc=0):
else:
self._add_instr("CALL_METHOD", arg=argc, argval=argc)

def gen_kw_names(self, kw_names: tuple[str, ...] | None):
if kw_names is None:
return
if sys.version_info < (3, 11):
raise InnerError("gen_kw_names is not supported before python3.11")
if kw_names not in self._code_options["co_consts"]:
self._code_options["co_consts"].append(kw_names)
idx = self._code_options["co_consts"].index(kw_names)
self._add_instr("KW_NAMES", arg=idx, argval=kw_names)

def gen_pop_top(self):
self._add_instr("POP_TOP")

Expand Down
1 change: 0 additions & 1 deletion tests/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ py311_skiped_tests=(
./test_15_slice.py
./test_19_closure.py
./test_21_global.py
./test_constant_graph.py
./test_enumerate.py
./test_guard_user_defined_fn.py
./test_inplace_api.py
Expand Down
2 changes: 1 addition & 1 deletion tests/test_constant_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def func_2(format_str, tensor):
return str, tensor


class TestExecutor(TestCaseBase):
class TestConstantGraph(TestCaseBase):
def test_case_1(self):
x = "{xx} is xx"
tensor = paddle.to_tensor(1)
Expand Down

0 comments on commit c64e674

Please sign in to comment.