From 261eb1466235b85657826ddf26bbc2dc1e31a7bd Mon Sep 17 00:00:00 2001 From: SigureMo Date: Mon, 4 Sep 2023 19:39:12 +0800 Subject: [PATCH 1/3] [Compat][3.11] gen `KW_NAMES` when call breakgraph, enable `test_constant_graph.py` --- sot/opcode_translator/executor/opcode_executor.py | 4 +++- sot/opcode_translator/executor/pycode_generator.py | 10 ++++++++++ tests/run_all.sh | 1 - 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 2a84a0c45..96f6e03a8 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -1225,7 +1225,6 @@ def CALL(self, instr: Instruction): is_method = not isinstance(self.stack.peek[instr.arg + 2], NullVariable) total_args = instr.arg + int(is_method) kwnames = self._call_shape if self._call_shape is not None else [] - self._call_shape = None n_kwargs = len(kwnames) n_positional_args = total_args - n_kwargs kwargs_list = self.stack.pop_n(n_kwargs) @@ -1236,6 +1235,7 @@ def CALL(self, instr: Instruction): # pop the NULL variable self.stack.pop() self.stack.push(fn(*args, **kwargs)) + self._call_shape = None def CALL_FUNCTION(self, instr: Instruction): assert isinstance(instr.arg, int) @@ -1870,6 +1870,8 @@ def _break_graph_in_call( var_loader.load(stack_arg) # gen call resume fn opcode + # NOTE(SigureMo): In Python 3.11,we need generate KW_NAMES if the call shape is not None. + self._graph.pycode_gen.gen_kw_names(self._call_shape) self._graph.pycode_gen.add_pure_instructions([instr]) self.stack.pop_n(pop_n) stack_size = len(self.stack) + push_n diff --git a/sot/opcode_translator/executor/pycode_generator.py b/sot/opcode_translator/executor/pycode_generator.py index 2b56b5f02..d69dd9238 100644 --- a/sot/opcode_translator/executor/pycode_generator.py +++ b/sot/opcode_translator/executor/pycode_generator.py @@ -786,6 +786,16 @@ def gen_call_method(self, argc=0): else: self._add_instr("CALL_METHOD", arg=argc, argval=argc) + def gen_kw_names(self, kw_names: tuple[str, ...] | None): + if kw_names is None: + return + if sys.version_info < (3, 11): + raise InnerError("gen_kwnames is not supported before python3.11") + if kw_names not in self._code_options["co_consts"]: + self._code_options["co_consts"].append(kw_names) + idx = self._code_options["co_consts"].index(kw_names) + self._add_instr("KW_NAMES", arg=idx, argval=kw_names) + def gen_pop_top(self): self._add_instr("POP_TOP") diff --git a/tests/run_all.sh b/tests/run_all.sh index 2d8d8fb2b..eea54763e 100644 --- a/tests/run_all.sh +++ b/tests/run_all.sh @@ -13,7 +13,6 @@ py311_skiped_tests=( ./test_15_slice.py ./test_19_closure.py ./test_21_global.py - ./test_constant_graph.py ./test_enumerate.py ./test_guard_user_defined_fn.py ./test_inplace_api.py From ffaee0b44eaa0f886ad7b0f7998304ab61abe5e8 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Mon, 4 Sep 2023 19:52:01 +0800 Subject: [PATCH 2/3] fix name --- sot/opcode_translator/executor/pycode_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sot/opcode_translator/executor/pycode_generator.py b/sot/opcode_translator/executor/pycode_generator.py index d69dd9238..be511e8b4 100644 --- a/sot/opcode_translator/executor/pycode_generator.py +++ b/sot/opcode_translator/executor/pycode_generator.py @@ -790,7 +790,7 @@ def gen_kw_names(self, kw_names: tuple[str, ...] | None): if kw_names is None: return if sys.version_info < (3, 11): - raise InnerError("gen_kwnames is not supported before python3.11") + raise InnerError("gen_kw_names is not supported before python3.11") if kw_names not in self._code_options["co_consts"]: self._code_options["co_consts"].append(kw_names) idx = self._code_options["co_consts"].index(kw_names) From e373bb1f36a1d346dd2f40bf42c1e2d1f4903616 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Mon, 4 Sep 2023 20:01:19 +0800 Subject: [PATCH 3/3] minor fix --- sot/opcode_translator/executor/opcode_executor.py | 2 +- tests/test_constant_graph.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 96f6e03a8..6fa8093ef 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -107,7 +107,7 @@ @dataclass class Stop: - state: bool + state: str @Singleton diff --git a/tests/test_constant_graph.py b/tests/test_constant_graph.py index 7cbb8edce..0d42bb08c 100644 --- a/tests/test_constant_graph.py +++ b/tests/test_constant_graph.py @@ -24,7 +24,7 @@ def func_2(format_str, tensor): return str, tensor -class TestExecutor(TestCaseBase): +class TestConstantGraph(TestCaseBase): def test_case_1(self): x = "{xx} is xx" tensor = paddle.to_tensor(1)