Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

[Compat][3.11] enable test_11_jumps.py #361

Merged
merged 12 commits into from
Aug 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 79 additions & 31 deletions sot/opcode_translator/executor/opcode_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,45 @@ def inner(self: OpcodeExecutorBase, instr: Instruction):
return inner


def jump_break_graph_decorator(normal_jump):
def pop_jump_if_op_wrapper(fn: Callable[[VariableBase], bool]):
"""
A decorator function that wraps a POP_JUMP_*_IF_* opcode operation and applies certain functionality to it.

Args:
fn: The condition function.

Returns:
The wrapped POP_JUMP_*_IF_* opcode operation.

"""

@jump_break_graph_decorator
def inner(self: OpcodeExecutorBase, instr: Instruction):
"""
Inner function that represents the wrapped POP_JUMP_IF opcode operation.

Args:
self: The instance of the OpcodeExecutorBase class.
instr: The instruction to be executed.

"""
pred_obj = self.stack.pop()

if isinstance(pred_obj, (ConstantVariable, ContainerVariable)):
self._graph.add_global_guarded_variable(pred_obj)
is_jump = fn(pred_obj)
if is_jump:
assert instr.jump_to is not None
self.jump_to(instr.jump_to)
return
raise NotImplementException(
f"Currently don't support predicate a non-const / non-tensor obj, but got {pred_obj}"
)

return inner


def jump_break_graph_decorator(normal_jump: Callable):
"""
A decorator function that breaks off the graph when a JUMP-related instruction is encountered.

Expand Down Expand Up @@ -726,6 +764,16 @@ def indexof(self, instr: Instruction):
"""
return self._instructions.index(instr)

def jump_to(self, instr: Instruction):
"""
Jumps to the given instruction.

Args:
instr: The instruction to jump to.

"""
self._lasti = self.indexof(instr)

def COPY(self, instr: Instruction):
assert isinstance(instr.arg, int)
self.stack.push(self.stack.peek[instr.arg])
Expand Down Expand Up @@ -1361,11 +1409,19 @@ def GET_ITER(self, instr: Instruction):
)
)

def JUMP_FORWARD(self, instr):
self._lasti = self.indexof(instr.jump_to)

def JUMP_ABSOLUTE(self, instr: Instruction):
self._lasti = self.indexof(instr.jump_to)
assert instr.jump_to is not None
self.jump_to(instr.jump_to)

def JUMP_FORWARD(self, instr: Instruction):
self.JUMP_ABSOLUTE(instr)

def JUMP_BACKWARD(self, instr: Instruction):
# TODO: check interrupt
self.JUMP_ABSOLUTE(instr)

def JUMP_BACKWARD_NO_INTERRUPT(self, instr: Instruction):
self.JUMP_ABSOLUTE(instr)

def CONTAINS_OP(self, instr: Instruction):
# It will only be 0 or 1
Expand All @@ -1385,7 +1441,8 @@ def JUMP_IF_FALSE_OR_POP(self, instr: Instruction):
self._graph.add_global_guarded_variable(pred_obj)
is_jump = not bool(pred_obj)
if is_jump:
self._lasti = self.indexof(instr.jump_to)
assert instr.jump_to is not None
self.jump_to(instr.jump_to)
else:
self.stack.pop()
return
Expand All @@ -1400,39 +1457,30 @@ def JUMP_IF_TRUE_OR_POP(self, instr: Instruction):
self._graph.add_global_guarded_variable(pred_obj)
is_jump = bool(pred_obj)
if is_jump:
self._lasti = self.indexof(instr.jump_to)
assert instr.jump_to is not None
self.jump_to(instr.jump_to)
else:
self.stack.pop()
return
raise NotImplementException(
"Currently don't support predicate a non-const / non-tensor obj."
)

@jump_break_graph_decorator
def POP_JUMP_IF_FALSE(self, instr: Instruction):
pred_obj = self.stack.pop()
if isinstance(pred_obj, (ConstantVariable, ContainerVariable)):
self._graph.add_global_guarded_variable(pred_obj)
is_jump = not bool(pred_obj)
if is_jump:
self._lasti = self.indexof(instr.jump_to)
return
raise NotImplementException(
"Currently don't support predicate a non-const / non-tensor obj."
)
POP_JUMP_IF_FALSE = pop_jump_if_op_wrapper(lambda x: not bool(x))
POP_JUMP_FORWARD_IF_FALSE = POP_JUMP_IF_FALSE
POP_JUMP_BACKWARD_IF_FALSE = POP_JUMP_IF_FALSE

@jump_break_graph_decorator
def POP_JUMP_IF_TRUE(self, instr: Instruction):
pred_obj = self.stack.pop()
if isinstance(pred_obj, (ConstantVariable, ContainerVariable)):
self._graph.add_global_guarded_variable(pred_obj)
is_jump = bool(pred_obj)
if is_jump:
self._lasti = self.indexof(instr.jump_to)
return
raise NotImplementException(
"Currently don't support predicate a non-const / non-tensor obj."
)
POP_JUMP_IF_TRUE = pop_jump_if_op_wrapper(bool)
POP_JUMP_FORWARD_IF_TRUE = POP_JUMP_IF_TRUE
POP_JUMP_BACKWARD_IF_TRUE = POP_JUMP_IF_TRUE

POP_JUMP_FORWARD_IF_NONE = pop_jump_if_op_wrapper(lambda x: x.is_none())
POP_JUMP_BACKWARD_IF_NONE = POP_JUMP_FORWARD_IF_NONE

POP_JUMP_FORWARD_IF_NOT_NONE = pop_jump_if_op_wrapper(
lambda x: not x.is_none()
)
POP_JUMP_BACKWARD_IF_NOT_NONE = POP_JUMP_FORWARD_IF_NOT_NONE
zrr1999 marked this conversation as resolved.
Show resolved Hide resolved

def UNPACK_SEQUENCE(self, instr: Instruction):
sequence = self.stack.pop()
Expand Down
38 changes: 22 additions & 16 deletions sot/opcode_translator/executor/pycode_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def gen_resume_fn_at(self, index: int, stack_size: int = 0):
gen_instr('LOAD_FAST', argval=stack_arg_str.format(i))
for i in range(stack_size)
]
+ [gen_instr('JUMP_ABSOLUTE', jump_to=self._instructions[index])]
+ [gen_instr('JUMP_FORWARD', jump_to=self._instructions[index])]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里用 JUMP_FORWARD 应该需要确保一点,它是向前跳的

这里应该没问题,因为是函数初始跳转到中间 breakgraph 处

+ self._instructions
)

Expand Down Expand Up @@ -844,23 +844,29 @@ def gen_pop_top(self):
def gen_rot_n(self, n):
if n <= 1:
return
if n <= 4:
self._add_instr("ROT_" + ["TWO", "THREE", "FOUR"][n - 2])
if sys.version_info >= (3, 11):
for i in range(n, 1, -1):
self._add_instr("SWAP", arg=i)
elif sys.version_info >= (3, 10):
self._add_instr("ROT_N", arg=n)
else:
if n <= 4:
self._add_instr("ROT_" + ["TWO", "THREE", "FOUR"][n - 2])
else:

def rot_n_fn(n):
vars = [f"var{i}" for i in range(n)]
rotated = reversed(vars[-1:] + vars[:-1])
fn = eval(f"lambda {','.join(vars)}: ({','.join(rotated)})")
fn = no_eval_frame(fn)
fn.__name__ = f"rot_{n}_fn"
return fn

self.gen_build_tuple(n)
self.gen_load_const(rot_n_fn(n))
self.gen_rot_n(2)
self._add_instr("CALL_FUNCTION_EX", arg=0)
self.gen_unpack_sequence(n)
def rot_n_fn(n):
vars = [f"var{i}" for i in range(n)]
rotated = reversed(vars[-1:] + vars[:-1])
fn = eval(f"lambda {','.join(vars)}: ({','.join(rotated)})")
fn = no_eval_frame(fn)
fn.__name__ = f"rot_{n}_fn"
return fn

self.gen_build_tuple(n)
self.gen_load_const(rot_n_fn(n))
self.gen_rot_n(2)
self._add_instr("CALL_FUNCTION_EX", arg=0)
self.gen_unpack_sequence(n)

def gen_return(self):
self._add_instr("RETURN_VALUE")
Expand Down
6 changes: 6 additions & 0 deletions sot/opcode_translator/executor/variables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,12 @@ def get_py_type(self):
"""
return type(self.get_py_value())

def is_none(self) -> bool:
"""
Method to check if the variable's value is None
"""
return self.get_py_value() is None

def reconstruct(
self,
codegen: PyCodeGen,
Expand Down
2 changes: 1 addition & 1 deletion tests/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ py311_skiped_tests=(
# ./test_04_list.py There are some case need to be fixed
# ./test_05_dict.py There are some case need to be fixed
./test_10_build_unpack.py
./test_11_jumps.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

部分完成的像上面加一个注释吧

# ./test_11_jumps.py There are some case need to be fixed
./test_12_for_loop.py
./test_14_operators.py
./test_15_slice.py
Expand Down
56 changes: 45 additions & 11 deletions tests/test_11_jumps.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from __future__ import annotations

import sys
import unittest

from test_case_base import TestCaseBase

import paddle
from sot.psdb import check_no_breakgraph


@check_no_breakgraph
def pop_jump_if_false(x: bool, y: paddle.Tensor):
if x:
y += 1
Expand All @@ -15,25 +18,47 @@ def pop_jump_if_false(x: bool, y: paddle.Tensor):
return y


@check_no_breakgraph
def pop_jump_if_true(x: bool, y: bool, z: paddle.Tensor):
return (x or y) and z


@check_no_breakgraph
def jump_if_false_or_pop(x: bool, y: paddle.Tensor):
return x and (y + 1)


@check_no_breakgraph
def jump_if_true_or_pop(x: bool, y: paddle.Tensor):
return x or (y + 1)


def pop_jump_if_true(x: bool, y: bool, z: paddle.Tensor):
return (x or y) and z


@check_no_breakgraph
def jump_absolute(x: int, y: paddle.Tensor):
while x > 0:
y += 1
x -= 1
return y


@check_no_breakgraph
def pop_jump_if_none(x: bool, y: paddle.Tensor):
if x is not None:
y += 1
else:
y -= 1
return y


@check_no_breakgraph
def pop_jump_if_not_none(x: bool, y: paddle.Tensor):
if x is None:
y += 1
else:
y -= 1
return y


a = paddle.to_tensor(1)
b = paddle.to_tensor(2)
c = paddle.to_tensor(3)
Expand All @@ -45,18 +70,26 @@ def jump_absolute(x: int, y: paddle.Tensor):

class TestExecutor(TestCaseBase):
def test_simple(self):
self.assert_results(pop_jump_if_false, True, a)
self.assert_results(jump_if_false_or_pop, True, a)
self.assert_results(jump_if_true_or_pop, False, a)
self.assert_results(pop_jump_if_true, True, False, a)
self.assert_results(jump_absolute, 5, a)

self.assert_results(pop_jump_if_false, True, a)
self.assert_results(pop_jump_if_false, False, a)
self.assert_results(jump_if_false_or_pop, True, a)
self.assert_results(jump_if_false_or_pop, False, a)
self.assert_results(jump_if_true_or_pop, True, a)
self.assert_results(jump_if_true_or_pop, False, a)
self.assert_results(pop_jump_if_true, True, False, a)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里和上面是不是有一部分重复了?按理说两个单测 case 应该测不同的分支

self.assert_results(pop_jump_if_true, False, False, a)

def test_fallback(self):
self.assert_results(pop_jump_if_none, None, a)
self.assert_results(pop_jump_if_none, True, a)
self.assert_results(pop_jump_if_not_none, None, a)
self.assert_results(pop_jump_if_not_none, True, a)

@unittest.skipIf(
sys.version_info >= (3, 11), "Python 3.11+ not support breakgraph."
)
def test_breakgraph(self):
self.assert_results(pop_jump_if_false, true_tensor, a)
self.assert_results(jump_if_false_or_pop, true_tensor, a)
self.assert_results(jump_if_true_or_pop, false_tensor, a)
Expand All @@ -67,8 +100,9 @@ def test_fallback(self):
self.assert_results(jump_if_true_or_pop, false_tensor, a)
self.assert_results(pop_jump_if_true, true_tensor, false_tensor, a)

self.assert_results(pop_jump_if_none, true_tensor, a)
self.assert_results(pop_jump_if_not_none, true_tensor, a)


if __name__ == "__main__":
unittest.main()

# TODO: JUMP_FORWARD