From 0ba700ba3bf84616e246fbaaf5958072bd6b13fc Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Tue, 20 Feb 2024 20:06:54 +0800 Subject: [PATCH 1/2] suppor py312 gen_load_attr --- .../sot/opcode_translator/executor/opcode_executor.py | 1 + .../sot/opcode_translator/executor/pycode_generator.py | 6 ++++-- test/sot/skip_files_py312 | 10 ---------- 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index bf399949d26f9..02b01a25a6aa5 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -1075,6 +1075,7 @@ def KW_NAMES(self, instr: Instruction): assert isinstance(instr.arg, int) self._call_shape = self._co_consts[instr.arg].get_py_value() + @call_break_graph_decorator(push_n=1) def CALL(self, instr: Instruction): assert isinstance(instr.arg, int) assert instr.arg + 2 <= len(self.stack) diff --git a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py index f5737beb0947f..0ec780d4b9bda 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py +++ b/python/paddle/jit/sot/opcode_translator/executor/pycode_generator.py @@ -797,6 +797,8 @@ def gen_load_attr(self, name: str): if name not in self._code_options["co_names"]: self._code_options["co_names"].append(name) idx = self._code_options["co_names"].index(name) + if sys.version_info >= (3, 12): + idx <<= 1 self._add_instr("LOAD_ATTR", arg=idx, argval=name) def gen_store_attr(self, name: str): @@ -878,7 +880,7 @@ def gen_unpack_sequence(self, count): def gen_call_function(self, argc=0): if sys.version_info >= (3, 11): - if sys.version_info >= (3, 11) and sys.version_info < (3, 12): + if sys.version_info < (3, 12): self._add_instr("PRECALL", arg=argc, argval=argc) self._add_instr("CALL", arg=argc, argval=argc) else: @@ -892,7 +894,7 @@ def gen_call_function_ex(self, has_kwargs): def gen_call_method(self, argc=0): if sys.version_info >= (3, 11): - if sys.version_info >= (3, 11) and sys.version_info < (3, 12): + if sys.version_info < (3, 12): self._add_instr("PRECALL", arg=argc, argval=argc) self._add_instr("CALL", arg=argc, argval=argc) else: diff --git a/test/sot/skip_files_py312 b/test/sot/skip_files_py312 index 49ac001429c1f..5012816657ada 100644 --- a/test/sot/skip_files_py312 +++ b/test/sot/skip_files_py312 @@ -1,8 +1,3 @@ -./test_01_basic.py -./test_03_tuple.py -./test_04_list.py -./test_05_dict.py -./test_07_unpack.py ./test_09_f_string.py ./test_10_build_unpack.py ./test_11_jumps.py @@ -10,18 +5,14 @@ ./test_14_operators.py ./test_15_slice.py ./test_17_paddle_layer.py -./test_18_tensor_method.py -./test_19_closure.py ./test_20_string.py ./test_21_global.py ./test_analysis_inputs.py ./test_binary_operator_tracker.py ./test_break_graph.py -./test_builtin_dispatch.py ./test_builtin_map.py ./test_builtin_range.py ./test_builtin_zip.py -./test_constant_graph.py ./test_dup_top.py ./test_enumerate.py ./test_guard_user_defined_fn.py @@ -33,7 +24,6 @@ ./test_simulate_initialize.py ./test_sir_rollback.py ./test_sot_cost_model.py -./test_sot_exception.py ./test_sot_export.py ./test_sot_resnet.py ./test_sot_resnet50_backward.py From f6e72434422b92cb06613c80031b0ec965a9f3f3 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Tue, 20 Feb 2024 22:00:58 +0800 Subject: [PATCH 2/2] fix `py312-` --- .../sot/opcode_translator/executor/opcode_executor.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py index 02b01a25a6aa5..8614356ce3a85 100644 --- a/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py +++ b/python/paddle/jit/sot/opcode_translator/executor/opcode_executor.py @@ -1075,8 +1075,7 @@ def KW_NAMES(self, instr: Instruction): assert isinstance(instr.arg, int) self._call_shape = self._co_consts[instr.arg].get_py_value() - @call_break_graph_decorator(push_n=1) - def CALL(self, instr: Instruction): + def call(self, instr: Instruction): assert isinstance(instr.arg, int) assert instr.arg + 2 <= len(self.stack) is_method = not isinstance(self.stack.peek[instr.arg + 2], NullVariable) @@ -1094,6 +1093,12 @@ def CALL(self, instr: Instruction): self.stack.push(fn(*args, **kwargs)) self._call_shape = None + CALL = ( + call_break_graph_decorator(push_n=1)(call) + if sys.version_info >= (3, 12) + else call + ) + @call_break_graph_decorator(push_n=1) def CALL_FUNCTION(self, instr: Instruction): assert isinstance(instr.arg, int)