From 4974b061f136690c98c4a6dba599ec99b65489c5 Mon Sep 17 00:00:00 2001 From: gouzil <66515297+gouzil@users.noreply.github.com> Date: Sun, 24 Sep 2023 11:54:52 +0800 Subject: [PATCH] [opcode][3.11] support closure (#372) Co-authored-by: SigureMo --- docs/compat/python311/closure.md | 134 ++++++++++++++++++ docs/compat/python311/index.md | 2 + .../executor/function_graph.py | 5 +- .../executor/opcode_executor.py | 27 +++- .../executor/pycode_generator.py | 59 +++++++- .../executor/variables/basic.py | 9 ++ .../instruction_utils/instruction_utils.py | 8 ++ tests/run_all.sh | 12 -- tests/run_all_paddle_ci.sh | 2 +- tests/test_19_closure.py | 17 ++- 10 files changed, 250 insertions(+), 25 deletions(-) create mode 100644 docs/compat/python311/closure.md diff --git a/docs/compat/python311/closure.md b/docs/compat/python311/closure.md new file mode 100644 index 000000000..0dae0e2f2 --- /dev/null +++ b/docs/compat/python311/closure.md @@ -0,0 +1,134 @@ +# Closure 适配 + + +## Python 中的闭包示例 + +以下是在新版本中闭包函数处理的demo,以及它对应的字节码 : + +```python +import dis + +def func(): + free_x = 1 + free_y = 2 + + def local(y): + return y + free_x + free_y + return local(1) + +dis.dis(func) +``` + +```bash + 0 MAKE_CELL 1 (free_x) + 2 MAKE_CELL 2 (free_y) + + 9 4 RESUME 0 + + 10 6 LOAD_CONST 1 (1) + 8 STORE_DEREF 1 (free_x) + + 11 10 LOAD_CONST 2 (2) + 12 STORE_DEREF 2 (free_y) + + 13 14 LOAD_CLOSURE 1 (free_x) + 16 LOAD_CLOSURE 2 (free_y) + 18 BUILD_TUPLE 2 + 20 LOAD_CONST 3 () + 22 MAKE_FUNCTION 8 (closure) + 24 STORE_FAST 0 (local) + + 15 26 PUSH_NULL + 28 LOAD_FAST 0 (local) + 30 LOAD_CONST 1 (1) + 32 PRECALL 1 + 36 CALL 1 + 46 RETURN_VALUE + +Disassembly of : + 0 COPY_FREE_VARS 2 + + 13 2 RESUME 0 + + 14 4 LOAD_FAST 0 (y) + 6 LOAD_DEREF 1 (free_x) + 8 BINARY_OP 0 (+) + 12 LOAD_DEREF 2 (free_y) + 14 BINARY_OP 0 (+) + 18 RETURN_VALUE +``` + +## 新版本中对字节码的改动: + +### 首先是语义上的改动 + +LOAD_CLOSURE: 新版本不再是`co_cellvars + co_freevars`长度偏移量, 而是`LOAD_FAST`的一个别名 + +LOAD_DEREF: 加载包含在 locals 中的元素 + +STORE_DEREF: 存储 TOS 到 locals 中 + +### 新增字节码 + +MAKE_CELL: 如果元素不存在于 locals 则从 co_freevars 和 co_cellvars 中加载 + +COPY_FREE_VARS: 复制 co_freevars 和 co_cellvars 中的元素到 locals + +## 分析 + +从字节码上的改动来看,在 python3.11 中, 闭包将数据存储在 locals 中,而不是 cell 中,这样做的好处是可以减少一次间接寻址,提高性能。 + +## 实现 + +LOAD_CLOSURE: 作为`LOAD_FAST`的别名,所以直接调用 + +LOAD_DEREF: 改为从 `self._locals` 中加载元素到 TOS 中 + +STORE_DEREF: 改为存储 TOS 到 `self._locals` 中 + +MAKE_CELL: 从 `self._cells` 中加载元素到 `self._locals` + +COPY_FREE_VARS(闭包内部字节码): 从 `self._code.co_freevars` 拿到 key 在 `self._cells` 中找到元素存储到 `self._locals` + +## codegen + +```bash +[transform] NewCode: #foo_af1a0 + 9 0 MAKE_CELL 0 (x) # 在此处生成存储字节码,将元素存储至 locals + 2 MAKE_CELL 1 (y) + 4 MAKE_CELL 5 (z) + 6 RESUME 0 + 8 LOAD_GLOBAL 1 (NULL + paddle_set_eval_frame_fn) + ... + 104 POP_TOP + 106 RETURN_VALUE + +Disassembly of : + 0 COPY_FREE_VARS 3 # 在此处生成拷贝字节码,将数据拷贝至闭包内部调用 + + 12 2 RESUME 0 + + 13 4 LOAD_FAST 0 (a) + ... + 30 RETURN_VALUE + +``` + + +## 单测 + +新增一项之前未覆盖情况 + +```python +def create_closure(): + x = 1 + + def closure(): + return x + 1 + + return closure +``` + +## 其他更改 + +此次升级还依赖于 eval frame 修改,相关适配链接:[#57490](https://github.com/PaddlePaddle/Paddle/pull/57490)、[#57653](https://github.com/PaddlePaddle/Paddle/pull/57653) diff --git a/docs/compat/python311/index.md b/docs/compat/python311/index.md index 7c23ed85c..659b80a1f 100644 --- a/docs/compat/python311/index.md +++ b/docs/compat/python311/index.md @@ -9,3 +9,5 @@ ## 字节码修改适配 - [CALL 相关字节码](./CALL.md) + +- [closure 相关修改](./closure.md) diff --git a/sot/opcode_translator/executor/function_graph.py b/sot/opcode_translator/executor/function_graph.py index 3cc526691..29706453c 100644 --- a/sot/opcode_translator/executor/function_graph.py +++ b/sot/opcode_translator/executor/function_graph.py @@ -392,7 +392,10 @@ def get_opcode_executor_stack(): source_lines, start_line = inspect.getsourcelines( current_executor._code ) - code_line = source_lines[current_line - start_line] + # TODO(SigureMo): In 3.11, lineno maybe changed after multiple breakgraph, + # We need to find a way to fix this. + line_idx = min(current_line - start_line, len(source_lines) - 1) + code_line = source_lines[line_idx] stack = [] stack.append( ' File "{}", line {}, in {}'.format( diff --git a/sot/opcode_translator/executor/opcode_executor.py b/sot/opcode_translator/executor/opcode_executor.py index 454bdecb3..45aeeec0b 100644 --- a/sot/opcode_translator/executor/opcode_executor.py +++ b/sot/opcode_translator/executor/opcode_executor.py @@ -907,19 +907,32 @@ def LOAD_CONST(self, instr: Instruction): var = self._co_consts[instr.arg] self.stack.push(var) - def LOAD_CLOSURE(self, instr): + def MAKE_CELL(self, instr: Instruction): + self._locals[instr.argval] = self._cells[instr.argval] + + def LOAD_CLOSURE(self, instr: Instruction): + if sys.version_info >= (3, 11): + self.LOAD_FAST(instr) + return namemap = self._code.co_cellvars + self._code.co_freevars name = namemap[instr.arg] self.stack.push(self._cells[name]) - def LOAD_DEREF(self, instr): + def LOAD_DEREF(self, instr: Instruction): + if sys.version_info >= (3, 11): + self.stack.push(self._locals[instr.argval].cell_content()) + return namemap = self._code.co_cellvars + self._code.co_freevars name = namemap[instr.arg] self.stack.push(self._cells[name].cell_content()) + def COPY_FREE_VARS(self, instr: Instruction): + for i in range(instr.arg): + freevar_name = self._code.co_freevars[i] + self._locals[freevar_name] = self._cells[freevar_name] + def LOAD_FAST(self, instr: Instruction): - varname = self._code.co_varnames[instr.arg] - var = self._locals[varname] + var = self._locals[instr.argval] self.stack.push(var) def DELETE_FAST(self, instr: Instruction): @@ -983,7 +996,11 @@ def STORE_ATTR(self, instr): f"STORE_ATTR don't support {type(obj)}.{key}={val}" ) - def STORE_DEREF(self, instr): + def STORE_DEREF(self, instr: Instruction): + if sys.version_info >= (3, 11): + self._cells[instr.argval].set_value(self.stack.pop()) + self._locals[instr.argval] = self._cells[instr.argval] + return namemap = self._code.co_cellvars + self._code.co_freevars name = namemap[instr.arg] self._cells[name].set_value(self.stack.pop()) diff --git a/sot/opcode_translator/executor/pycode_generator.py b/sot/opcode_translator/executor/pycode_generator.py index 05340e515..2e9e51e3d 100644 --- a/sot/opcode_translator/executor/pycode_generator.py +++ b/sot/opcode_translator/executor/pycode_generator.py @@ -418,11 +418,43 @@ def __init__( self._f_globals = frame.f_globals self._instructions = [] self.disable_eval_frame = disable_eval_frame - if sys.version_info >= (3, 11): - self._add_instr("RESUME", arg=0, argval=0) if self.disable_eval_frame: self.gen_disable_eval_frame() + def insert_prefix_instructions(self): + """ + Insert prefix instructions to the instruction list. + In Python 3.11+, we need to insert MAKE_CELL and COPY_FREE_VARS before the + first instruction. + The implementation is based on cpython implementation: + https://github.com/python/cpython/blob/f45ef5edabb1cc0748f3326e7114b8aaa0424392/Python/compile.c#L8177 + """ + prefixes = [] + if sys.version_info >= (3, 11): + if self._code_options["co_cellvars"]: + # Insert MAKE_CELL + name_map = list( + OrderedSet(self._code_options["co_varnames"]) + | OrderedSet(self._code_options["co_cellvars"]) + ) + + for i in self._code_options["co_cellvars"]: + idx: int = name_map.index(i) + prefixes.append(gen_instr("MAKE_CELL", arg=idx, argval=i)) + + if self._code_options["co_freevars"]: + n_freevars = len(self._code_options["co_freevars"]) + # Insert COPY_FREE_VARS + prefixes.append( + gen_instr( + "COPY_FREE_VARS", arg=n_freevars, argval=n_freevars + ) + ) + + # Insert RESUME + prefixes.append(gen_instr("RESUME", arg=0, argval=0)) + self._instructions[:] = prefixes + self._instructions + def update_code_name(self, fn_name, is_resumed_fn): if is_resumed_fn: self._code_options[ @@ -446,6 +478,7 @@ def gen_pycode(self) -> types.CodeType: Returns: CodeType: The generated code object. """ + self.insert_prefix_instructions() modify_instrs(self._instructions) modify_vars(self._instructions, self._code_options) new_code = gen_new_opcode( @@ -576,7 +609,7 @@ def gen_load_const(self, value: Any): self._add_instr("LOAD_CONST", arg=idx, argval=value) def gen_print_log(self, message): - """print a log :""" + """print a log""" import paddle self.gen_load_object( @@ -712,7 +745,15 @@ def gen_load_fast(self, name): def gen_load_deref(self, name): if name not in self.cell_free_storage: self._code_options["co_freevars"].append(name) - idx = self.cell_free_storage.index(name) + if sys.version_info >= (3, 11): + # Because the co_varnames maybe changed after other codegen + # operations, we need re-calculate the index in modify_vars + idx = ( + self._code_options["co_varnames"] + + self._code_options["co_freevars"] + ).index(name) + else: + idx = self.cell_free_storage.index(name) self._add_instr("LOAD_DEREF", arg=idx, argval=name) def gen_load_attr(self, name: str): @@ -768,7 +809,15 @@ def gen_store_global(self, name): def gen_store_deref(self, name): if name not in self.cell_free_storage: self._code_options["co_freevars"].append(name) - idx = self.cell_free_storage.index(name) + if sys.version_info >= (3, 11): + # Because the co_varnames maybe changed after other codegen + # operations, we need re-calculate the index in modify_vars + idx = ( + self._code_options["co_varnames"] + + self._code_options["co_freevars"] + ).index(name) + else: + idx = self.cell_free_storage.index(name) self._add_instr("STORE_DEREF", arg=idx, argval=name) def gen_store_subscr(self): diff --git a/sot/opcode_translator/executor/variables/basic.py b/sot/opcode_translator/executor/variables/basic.py index 9dbbb15d1..610688142 100644 --- a/sot/opcode_translator/executor/variables/basic.py +++ b/sot/opcode_translator/executor/variables/basic.py @@ -757,6 +757,15 @@ def __init__(self, value=None): assert isinstance(value, (VariableBase, type(None))) self.set_value(value) + def reconstruct( + self, + codegen: PyCodeGen, + *, + use_tracker: bool = True, + add_to_global_guarded_vars: bool = True, + ): + raise FallbackError("Break graph in closure is not support.") + def cell_content(self): return self.value diff --git a/sot/opcode_translator/instruction_utils/instruction_utils.py b/sot/opcode_translator/instruction_utils/instruction_utils.py index 8ed8372ce..6bef3db62 100644 --- a/sot/opcode_translator/instruction_utils/instruction_utils.py +++ b/sot/opcode_translator/instruction_utils/instruction_utils.py @@ -300,12 +300,20 @@ def bind_ex_arg_with_instr(ex_arg, instr): def modify_vars(instructions, code_options): co_names = code_options['co_names'] co_varnames = code_options['co_varnames'] + co_freevars = code_options['co_freevars'] for instrs in instructions: if instrs.opname == 'LOAD_FAST' or instrs.opname == 'STORE_FAST': assert ( instrs.argval in co_varnames ), f"`{instrs.argval}` not in {co_varnames}" instrs.arg = co_varnames.index(instrs.argval) + elif instrs.opname == "LOAD_DEREF" or instrs.opname == "STORE_DEREF": + if sys.version_info >= (3, 11): + namemap = co_varnames + co_freevars + assert ( + instrs.argval in namemap + ), f"`{instrs.argval}` not in {namemap}" + instrs.arg = namemap.index(instrs.argval) def calc_offset_from_bytecode_offset( diff --git a/tests/run_all.sh b/tests/run_all.sh index b98964249..af018bf56 100644 --- a/tests/run_all.sh +++ b/tests/run_all.sh @@ -4,23 +4,11 @@ export STRICT_MODE=1 export COST_MODEL=False export MIN_GRAPH_SIZE=0 -IS_PY311=`python -c "import sys; print(sys.version_info >= (3, 11))"` -echo "IS_PY311:" $IS_PY311 - failed_tests=() -py311_skiped_tests=( - ./test_19_closure.py - ./test_tensor_dtype_in_guard.py -) - for file in ./test_*.py; do # 检查文件是否为 python 文件 if [ -f "$file" ]; then - if [[ "$IS_PY311" == "True" && "${py311_skiped_tests[@]}" =~ "$file" ]]; then - echo "skip $file for python3.11" - continue - fi if [[ -n "$GITHUB_ACTIONS" ]]; then echo ::group::Running: PYTHONPATH=$PYTHONPATH " STRICT_MODE=1 python " $file else diff --git a/tests/run_all_paddle_ci.sh b/tests/run_all_paddle_ci.sh index bd1ec96bf..82b75176b 100644 --- a/tests/run_all_paddle_ci.sh +++ b/tests/run_all_paddle_ci.sh @@ -9,7 +9,7 @@ failed_tests=() disabled_tests=( ${PADDLE_TEST_BASE}/test_lac.py # disabled by paddle ${PADDLE_TEST_BASE}/test_sentiment.py # disabled unitcase by paddle - ${PADDLE_TEST_BASE}/test_convert_call.py + ${PADDLE_TEST_BASE}/test_pylayer.py # This ut cannot directly run ) for file in ${PADDLE_TEST_BASE}/*.py; do diff --git a/tests/test_19_closure.py b/tests/test_19_closure.py index d24651f49..e089030c3 100644 --- a/tests/test_19_closure.py +++ b/tests/test_19_closure.py @@ -12,7 +12,7 @@ def foo(x: int, y: paddle.Tensor): def local(a, b=5): return a + x + z + b + y - return local(4) + return local(4) + z def foo2(y: paddle.Tensor, x=1): @@ -146,6 +146,15 @@ def foo7(): return func7(3, 5) +def create_closure(): + x = 1 + + def closure(): + return x + 1 + + return closure + + class TestExecutor(TestCaseBase): def test_closure(self): self.assert_results(foo, 1, paddle.to_tensor(2)) @@ -219,6 +228,12 @@ def test_closure(self): self.assert_results(non_local_test, tx) +class TestCreateClosure(TestCaseBase): + def test_create_closure(self): + closure = create_closure() + self.assert_results(closure) + + if __name__ == "__main__": unittest.main()