This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[opcode][3.11] support closure (#372)
Co-authored-by: SigureMo <sigure.qaq@gmail.com>
gouzil and SigureMo authored Sep 24, 2023
1 parent 7f80cf2 commit 4974b06
Showing 10 changed files with 250 additions and 25 deletions.
134 changes: 134 additions & 0 deletions docs/compat/python311/closure.md
@@ -0,0 +1,134 @@
# Closure adaptation


## A closure example in Python

The following demo shows how a closure is handled in the new version (Python 3.11), along with its bytecode:

```python
import dis

def func():
    free_x = 1
    free_y = 2

    def local(y):
        return y + free_x + free_y
    return local(1)

dis.dis(func)
```

```bash
0 MAKE_CELL 1 (free_x)
2 MAKE_CELL 2 (free_y)

9 4 RESUME 0

10 6 LOAD_CONST 1 (1)
8 STORE_DEREF 1 (free_x)

11 10 LOAD_CONST 2 (2)
12 STORE_DEREF 2 (free_y)

13 14 LOAD_CLOSURE 1 (free_x)
16 LOAD_CLOSURE 2 (free_y)
18 BUILD_TUPLE 2
20 LOAD_CONST 3 (<code object local at 0x1022e0100, file "/Users/gouzi/Documents/git/paddle-symbolic-trace/tests/demo2.py", line 13>)
22 MAKE_FUNCTION 8 (closure)
24 STORE_FAST 0 (local)

15 26 PUSH_NULL
28 LOAD_FAST 0 (local)
30 LOAD_CONST 1 (1)
32 PRECALL 1
36 CALL 1
46 RETURN_VALUE

Disassembly of <code object local at 0x1022e0100, file "/Users/gouzi/Documents/git/paddle-symbolic-trace/tests/demo2.py", line 13>:
0 COPY_FREE_VARS 2

13 2 RESUME 0

14 4 LOAD_FAST 0 (y)
6 LOAD_DEREF 1 (free_x)
8 BINARY_OP 0 (+)
12 LOAD_DEREF 2 (free_y)
14 BINARY_OP 0 (+)
18 RETURN_VALUE
```
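The instruction arguments in the listing above refer to names recorded on the code objects. As a minimal sketch (a variant of the demo that returns `local` instead of calling it, which is an assumption made here so the inner function can be inspected), the relevant attributes can be read directly:

```python
def func():
    free_x = 1
    free_y = 2

    def local(y):
        return y + free_x + free_y

    # Variant of the demo: return the inner function so it can be inspected.
    return local


inner = func()
outer_code = func.__code__
inner_code = inner.__code__

# Plain locals of the outer function.
print(outer_code.co_varnames)
# Variables of the outer function that are captured by an inner function.
print(outer_code.co_cellvars)
# Variables the inner function takes from its enclosing scope.
print(inner_code.co_freevars)
# The captured values, one cell per free variable.
print([cell.cell_contents for cell in inner.__closure__])
```

`co_cellvars` of the outer function and `co_freevars` of the inner function list the same names, which is why `MAKE_CELL` in `func` and `COPY_FREE_VARS` in `local` deal with the same two variables.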

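## Bytecode changes in the new version

### Semantic changes

LOAD_CLOSURE: the argument is no longer an index into `co_cellvars + co_freevars`; it indexes the frame's locals storage directly, which makes the instruction an alias of `LOAD_FAST`

LOAD_DEREF: loads the cell stored in the given locals slot and pushes the object that the cell contains

STORE_DEREF: stores TOS into the cell held in the given locals slot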
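The change in how these arguments are indexed can be checked against the running interpreter. The sketch below is only illustrative: `deref_name_map` and `deref_pairs` are hypothetical helper names, and the 3.11+ layout (locals first, then cells that are not also locals, then free variables) is stated here as an assumption about CPython's locals storage.

```python
import dis
import sys


def deref_name_map(code):
    """Names that a *_DEREF argument indexes into on this interpreter."""
    if sys.version_info >= (3, 11):
        # Assumed 3.11+ layout: locals, then extra cells, then free variables.
        extra_cells = [n for n in code.co_cellvars if n not in code.co_varnames]
        return list(code.co_varnames) + extra_cells + list(code.co_freevars)
    # Before 3.11 the argument indexes co_cellvars + co_freevars.
    return list(code.co_cellvars) + list(code.co_freevars)


def deref_pairs(code):
    """Return (opname, arg, argval, resolved_name) for each *_DEREF instruction."""
    names = deref_name_map(code)
    return [
        (instr.opname, instr.arg, instr.argval, names[instr.arg])
        for instr in dis.get_instructions(code)
        if instr.opname in ("LOAD_DEREF", "STORE_DEREF")
    ]


def func():
    free_x = 1
    free_y = 2

    def local(y):
        return y + free_x + free_y

    return local


for code in (func.__code__, func().__code__):
    for opname, arg, argval, resolved in deref_pairs(code):
        print(opname, arg, argval, resolved)
```

If the assumed layout holds, the last two columns of every printed row are identical.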

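### New bytecode

MAKE_CELL: creates a new cell in the given locals slot; if the slot already holds a value (for example an argument), that value is stored into the new cell

COPY_FREE_VARS: copies the n free variables from the closure into the frame's locals, so the caller no longer needs special handling when calling a closure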

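## Analysis

Judging from the bytecode changes, Python 3.11 keeps the cells of a closure in the frame's locals storage instead of in a separate cell and free-variable area. The values are still wrapped in cell objects, but locals, cell variables and free variables now share one index space. According to this analysis, the benefit is one less level of indirection and therefore better performance.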
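Whatever the storage layout inside the frame, the captured values remain observable from Python as cell objects. A minimal sketch with a hypothetical `make_counter` example:

```python
def make_counter():
    count = 0

    def increment():
        nonlocal count
        count += 1
        return count

    def current():
        return count

    return increment, current


increment, current = make_counter()
increment()
increment()

# Both inner functions capture the same cell object for `count`.
shared = increment.__closure__[0] is current.__closure__[0]
# The cell holds the latest value written through `nonlocal`.
value = current.__closure__[0].cell_contents
print(shared, value)
```

Because both functions share one cell, a write through `increment` is visible through `current`, which is the behavior the executor has to preserve when it moves cells into its locals.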

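## Implementation

LOAD_CLOSURE: an alias of `LOAD_FAST`, so it calls `LOAD_FAST` directly

LOAD_DEREF: now loads the cell from `self._locals` and pushes its content to TOS

STORE_DEREF: now stores TOS into the cell and records that cell in `self._locals`

MAKE_CELL: loads the cell from `self._cells` into `self._locals`

COPY_FREE_VARS (bytecode inside the closure): takes each key from `self._code.co_freevars`, looks up the cell in `self._cells`, and stores it into `self._locals`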
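The handlers described above can be modelled with two dictionaries. This is a toy sketch, not the SOT executor: `ToyCell` and `ToyExecutor` are hypothetical names, instructions are replaced by plain variable names, and the hand-over of the cell from the outer to the inner frame is done by hand.

```python
class ToyCell:
    """Stand-in for a cell variable: a box holding one value."""

    def __init__(self, value=None):
        self.value = value

    def set_value(self, value):
        self.value = value

    def cell_content(self):
        return self.value


class ToyExecutor:
    """Models only the closure-related handlers (3.11 semantics)."""

    def __init__(self, cell_names, freevars=()):
        self.stack = []
        self._locals = {}
        self._cells = {name: ToyCell() for name in cell_names}
        self._freevars = tuple(freevars)

    def MAKE_CELL(self, name):
        self._locals[name] = self._cells[name]

    def COPY_FREE_VARS(self, n):
        for name in self._freevars[:n]:
            self._locals[name] = self._cells[name]

    def STORE_DEREF(self, name):
        self._cells[name].set_value(self.stack.pop())
        self._locals[name] = self._cells[name]

    def LOAD_DEREF(self, name):
        self.stack.append(self._locals[name].cell_content())

    def LOAD_CLOSURE(self, name):
        # Alias of LOAD_FAST: push whatever the locals slot holds (the cell).
        self.stack.append(self._locals[name])


# Outer frame: create the cell, store 1 into it, then load it for the closure.
outer = ToyExecutor(cell_names=["free_x"])
outer.MAKE_CELL("free_x")
outer.stack.append(1)
outer.STORE_DEREF("free_x")
outer.LOAD_CLOSURE("free_x")
cell = outer.stack.pop()

# Inner frame: receives the cell, copies it into locals, then reads it.
inner = ToyExecutor(cell_names=[], freevars=["free_x"])
inner._cells["free_x"] = cell
inner.COPY_FREE_VARS(1)
inner.LOAD_DEREF("free_x")
result = inner.stack.pop()
print(result)
```

The inner frame never copies the value itself; it only places the shared cell into its own locals, which mirrors what `COPY_FREE_VARS` is described as doing.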

## codegen

```bash
[transform] NewCode: #foo_af1a0
9 0 MAKE_CELL 0 (x) # cell-creation bytecode is generated here to store the elements into locals
2 MAKE_CELL 1 (y)
4 MAKE_CELL 5 (z)
6 RESUME 0
8 LOAD_GLOBAL 1 (NULL + paddle_set_eval_frame_fn)
...
104 POP_TOP
106 RETURN_VALUE

Disassembly of <code object local at 0x10cb216f0, file "/Users/gouzi/Documents/git/paddle-symbolic-trace/tests/test_19_closure.py", line 12>:
0 COPY_FREE_VARS 3 # copy bytecode is generated here to copy the data for use inside the closure

12 2 RESUME 0

13 4 LOAD_FAST 0 (a)
...
30 RETURN_VALUE

```


## Unit test

Adds one case that was not covered before:

```python
def create_closure():
    x = 1

    def closure():
        return x + 1

    return closure
```
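For reference, the plain-Python behavior that a translated version of this case would need to reproduce can be checked as follows. This is only a sketch of the expected result outside SOT; how the actual test drives the function is not shown here.

```python
def create_closure():
    x = 1

    def closure():
        return x + 1

    return closure


closure = create_closure()
result = closure()
# The returned function carries `x` as a free variable.
freevars = closure.__code__.co_freevars
print(result, freevars)
```

Unlike the earlier demo, the inner function here escapes its defining frame, so the captured cell has to stay valid after `create_closure` has returned.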

## Other changes

This upgrade also depends on eval frame changes. Related adaptation PRs: [#57490](https://github.com/PaddlePaddle/Paddle/pull/57490), [#57653](https://github.com/PaddlePaddle/Paddle/pull/57653)
2 changes: 2 additions & 0 deletions docs/compat/python311/index.md
@@ -9,3 +9,5 @@
 ## Bytecode change adaptation

 - [CALL-related bytecode](./CALL.md)
+
+- [closure-related changes](./closure.md)
5 changes: 4 additions & 1 deletion sot/opcode_translator/executor/function_graph.py
@@ -392,7 +392,10 @@ def get_opcode_executor_stack():
         source_lines, start_line = inspect.getsourcelines(
             current_executor._code
         )
-        code_line = source_lines[current_line - start_line]
+        # TODO(SigureMo): In 3.11, lineno maybe changed after multiple breakgraph,
+        # We need to find a way to fix this.
+        line_idx = min(current_line - start_line, len(source_lines) - 1)
+        code_line = source_lines[line_idx]
         stack = []
         stack.append(
             ' File "{}", line {}, in {}'.format(
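The effect of the clamp in this hunk can be seen in isolation. In the sketch below, `pick_source_line` is a hypothetical helper and the source lines are made up; it only reproduces the `min(...)` expression from the change.

```python
def pick_source_line(source_lines, start_line, current_line):
    # Same expression as in the hunk: never index past the last source line.
    line_idx = min(current_line - start_line, len(source_lines) - 1)
    return source_lines[line_idx]


lines = ["def f():\n", "    x = 1\n", "    return x\n"]

# A line number inside the function maps to the matching source line.
in_range = pick_source_line(lines, start_line=10, current_line=11)
# A line number past the end falls back to the last line instead of raising.
clamped = pick_source_line(lines, start_line=10, current_line=25)
print(repr(in_range), repr(clamped))
```

The clamp avoids an `IndexError` but can report the wrong line, which is consistent with the TODO left in the change.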
27 changes: 22 additions & 5 deletions sot/opcode_translator/executor/opcode_executor.py
@@ -907,19 +907,32 @@ def LOAD_CONST(self, instr: Instruction):
         var = self._co_consts[instr.arg]
         self.stack.push(var)

-    def LOAD_CLOSURE(self, instr):
+    def MAKE_CELL(self, instr: Instruction):
+        self._locals[instr.argval] = self._cells[instr.argval]
+
+    def LOAD_CLOSURE(self, instr: Instruction):
+        if sys.version_info >= (3, 11):
+            self.LOAD_FAST(instr)
+            return
         namemap = self._code.co_cellvars + self._code.co_freevars
         name = namemap[instr.arg]
         self.stack.push(self._cells[name])

-    def LOAD_DEREF(self, instr):
+    def LOAD_DEREF(self, instr: Instruction):
+        if sys.version_info >= (3, 11):
+            self.stack.push(self._locals[instr.argval].cell_content())
+            return
         namemap = self._code.co_cellvars + self._code.co_freevars
         name = namemap[instr.arg]
         self.stack.push(self._cells[name].cell_content())

+    def COPY_FREE_VARS(self, instr: Instruction):
+        for i in range(instr.arg):
+            freevar_name = self._code.co_freevars[i]
+            self._locals[freevar_name] = self._cells[freevar_name]
+
     def LOAD_FAST(self, instr: Instruction):
-        varname = self._code.co_varnames[instr.arg]
-        var = self._locals[varname]
+        var = self._locals[instr.argval]
         self.stack.push(var)

     def DELETE_FAST(self, instr: Instruction):
@@ -983,7 +996,11 @@ def STORE_ATTR(self, instr):
                 f"STORE_ATTR don't support {type(obj)}.{key}={val}"
             )

-    def STORE_DEREF(self, instr):
+    def STORE_DEREF(self, instr: Instruction):
+        if sys.version_info >= (3, 11):
+            self._cells[instr.argval].set_value(self.stack.pop())
+            self._locals[instr.argval] = self._cells[instr.argval]
+            return
         namemap = self._code.co_cellvars + self._code.co_freevars
         name = namemap[instr.arg]
         self._cells[name].set_value(self.stack.pop())
59 changes: 54 additions & 5 deletions sot/opcode_translator/executor/pycode_generator.py
@@ -418,11 +418,43 @@ def __init__(
         self._f_globals = frame.f_globals
         self._instructions = []
         self.disable_eval_frame = disable_eval_frame
-        if sys.version_info >= (3, 11):
-            self._add_instr("RESUME", arg=0, argval=0)
         if self.disable_eval_frame:
             self.gen_disable_eval_frame()

+    def insert_prefix_instructions(self):
+        """
+        Insert prefix instructions to the instruction list.
+        In Python 3.11+, we need to insert MAKE_CELL and COPY_FREE_VARS before the
+        first instruction.
+        The implementation is based on cpython implementation:
+        https://github.com/python/cpython/blob/f45ef5edabb1cc0748f3326e7114b8aaa0424392/Python/compile.c#L8177
+        """
+        prefixes = []
+        if sys.version_info >= (3, 11):
+            if self._code_options["co_cellvars"]:
+                # Insert MAKE_CELL
+                name_map = list(
+                    OrderedSet(self._code_options["co_varnames"])
+                    | OrderedSet(self._code_options["co_cellvars"])
+                )
+
+                for i in self._code_options["co_cellvars"]:
+                    idx: int = name_map.index(i)
+                    prefixes.append(gen_instr("MAKE_CELL", arg=idx, argval=i))
+
+            if self._code_options["co_freevars"]:
+                n_freevars = len(self._code_options["co_freevars"])
+                # Insert COPY_FREE_VARS
+                prefixes.append(
+                    gen_instr(
+                        "COPY_FREE_VARS", arg=n_freevars, argval=n_freevars
+                    )
+                )
+
+            # Insert RESUME
+            prefixes.append(gen_instr("RESUME", arg=0, argval=0))
+        self._instructions[:] = prefixes + self._instructions
+
     def update_code_name(self, fn_name, is_resumed_fn):
         if is_resumed_fn:
             self._code_options[
@@ -446,6 +478,7 @@ def gen_pycode(self) -> types.CodeType:
         Returns:
             CodeType: The generated code object.
         """
+        self.insert_prefix_instructions()
         modify_instrs(self._instructions)
         modify_vars(self._instructions, self._code_options)
         new_code = gen_new_opcode(
@@ -576,7 +609,7 @@ def gen_load_const(self, value: Any):
         self._add_instr("LOAD_CONST", arg=idx, argval=value)

     def gen_print_log(self, message):
-        """print a log :"""
+        """print a log"""
         import paddle

         self.gen_load_object(
@@ -712,7 +745,15 @@ def gen_load_fast(self, name):
     def gen_load_deref(self, name):
         if name not in self.cell_free_storage:
             self._code_options["co_freevars"].append(name)
-        idx = self.cell_free_storage.index(name)
+        if sys.version_info >= (3, 11):
+            # Because the co_varnames maybe changed after other codegen
+            # operations, we need re-calculate the index in modify_vars
+            idx = (
+                self._code_options["co_varnames"]
+                + self._code_options["co_freevars"]
+            ).index(name)
+        else:
+            idx = self.cell_free_storage.index(name)
         self._add_instr("LOAD_DEREF", arg=idx, argval=name)

     def gen_load_attr(self, name: str):
@@ -768,7 +809,15 @@ def gen_store_global(self, name):
     def gen_store_deref(self, name):
         if name not in self.cell_free_storage:
             self._code_options["co_freevars"].append(name)
-        idx = self.cell_free_storage.index(name)
+        if sys.version_info >= (3, 11):
+            # Because the co_varnames maybe changed after other codegen
+            # operations, we need re-calculate the index in modify_vars
+            idx = (
+                self._code_options["co_varnames"]
+                + self._code_options["co_freevars"]
+            ).index(name)
+        else:
+            idx = self.cell_free_storage.index(name)
         self._add_instr("STORE_DEREF", arg=idx, argval=name)

     def gen_store_subscr(self):
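The prefix computation added in `insert_prefix_instructions` can be sketched without the code generator. The helper below is hypothetical: it models only the 3.11+ branch, returns `(opname, arg, argval)` tuples instead of instruction objects, and uses `dict.fromkeys` as a stand-in for the ordered-set union.

```python
def prefix_instructions(co_varnames, co_cellvars, co_freevars):
    """Return the (opname, arg, argval) prefixes for a code object (3.11+ model)."""
    prefixes = []
    # Ordered union: locals first, then cell variables not already among them.
    name_map = list(dict.fromkeys(list(co_varnames) + list(co_cellvars)))
    for name in co_cellvars:
        prefixes.append(("MAKE_CELL", name_map.index(name), name))
    if co_freevars:
        n_freevars = len(co_freevars)
        prefixes.append(("COPY_FREE_VARS", n_freevars, n_freevars))
    prefixes.append(("RESUME", 0, 0))
    return prefixes


# Names taken from the closure demo in docs/compat/python311/closure.md.
outer = prefix_instructions(("local",), ("free_x", "free_y"), ())
inner = prefix_instructions(("y",), (), ("free_x", "free_y"))
print(outer)
print(inner)
```

For the demo's names this yields `MAKE_CELL 1` and `MAKE_CELL 2` for the outer function and `COPY_FREE_VARS 2` for the inner one, matching the arguments shown in the disassembly in the doc.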
9 changes: 9 additions & 0 deletions sot/opcode_translator/executor/variables/basic.py
@@ -757,6 +757,15 @@ def __init__(self, value=None):
         assert isinstance(value, (VariableBase, type(None)))
         self.set_value(value)

+    def reconstruct(
+        self,
+        codegen: PyCodeGen,
+        *,
+        use_tracker: bool = True,
+        add_to_global_guarded_vars: bool = True,
+    ):
+        raise FallbackError("Break graph in closure is not support.")
+
     def cell_content(self):
         return self.value

8 changes: 8 additions & 0 deletions sot/opcode_translator/instruction_utils/instruction_utils.py
@@ -300,12 +300,20 @@ def bind_ex_arg_with_instr(ex_arg, instr):
 def modify_vars(instructions, code_options):
     co_names = code_options['co_names']
     co_varnames = code_options['co_varnames']
+    co_freevars = code_options['co_freevars']
     for instrs in instructions:
         if instrs.opname == 'LOAD_FAST' or instrs.opname == 'STORE_FAST':
             assert (
                 instrs.argval in co_varnames
             ), f"`{instrs.argval}` not in {co_varnames}"
             instrs.arg = co_varnames.index(instrs.argval)
+        elif instrs.opname == "LOAD_DEREF" or instrs.opname == "STORE_DEREF":
+            if sys.version_info >= (3, 11):
+                namemap = co_varnames + co_freevars
+                assert (
+                    instrs.argval in namemap
+                ), f"`{instrs.argval}` not in {namemap}"
+                instrs.arg = namemap.index(instrs.argval)


 def calc_offset_from_bytecode_offset(
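The re-indexing pass in `modify_vars` exists because `co_varnames` can grow after an instruction was generated, which shifts the position of every free variable. A small sketch of that effect, using a hypothetical `ToyInstr` record and `reindex_deref` helper rather than the real instruction class:

```python
from dataclasses import dataclass


@dataclass
class ToyInstr:
    opname: str
    arg: int
    argval: str


def reindex_deref(instructions, co_varnames, co_freevars):
    """Recompute the arg of *_DEREF instructions from their argval (3.11 model)."""
    namemap = list(co_varnames) + list(co_freevars)
    for instr in instructions:
        if instr.opname in ("LOAD_DEREF", "STORE_DEREF"):
            if instr.argval not in namemap:
                raise ValueError(f"`{instr.argval}` not in {namemap}")
            instr.arg = namemap.index(instr.argval)


# LOAD_DEREF was generated with a stale index; `tmp` was added to the locals
# later, so the free variable `x` has moved from slot 1 to slot 2.
instrs = [ToyInstr("LOAD_FAST", 0, "a"), ToyInstr("LOAD_DEREF", 1, "x")]
reindex_deref(instrs, co_varnames=["a", "tmp"], co_freevars=["x"])
print(instrs[1].arg)
```

Keying the fix-up on `argval` means the index stored at generation time only has to be a placeholder.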
12 changes: 0 additions & 12 deletions tests/run_all.sh
@@ -4,23 +4,11 @@ export STRICT_MODE=1
 export COST_MODEL=False
 export MIN_GRAPH_SIZE=0

-IS_PY311=`python -c "import sys; print(sys.version_info >= (3, 11))"`
-echo "IS_PY311:" $IS_PY311
-
 failed_tests=()

-py311_skiped_tests=(
-    ./test_19_closure.py
-    ./test_tensor_dtype_in_guard.py
-)
-
 for file in ./test_*.py; do
     # Check whether the file is a Python file
     if [ -f "$file" ]; then
-        if [[ "$IS_PY311" == "True" && "${py311_skiped_tests[@]}" =~ "$file" ]]; then
-            echo "skip $file for python3.11"
-            continue
-        fi
         if [[ -n "$GITHUB_ACTIONS" ]]; then
             echo ::group::Running: PYTHONPATH=$PYTHONPATH " STRICT_MODE=1 python " $file
         else
2 changes: 1 addition & 1 deletion tests/run_all_paddle_ci.sh
@@ -9,7 +9,7 @@ failed_tests=()
 disabled_tests=(
     ${PADDLE_TEST_BASE}/test_lac.py # disabled by paddle
     ${PADDLE_TEST_BASE}/test_sentiment.py # disabled unitcase by paddle
-    ${PADDLE_TEST_BASE}/test_convert_call.py
+    ${PADDLE_TEST_BASE}/test_pylayer.py # This ut cannot directly run
 )

 for file in ${PADDLE_TEST_BASE}/*.py; do