Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
del code status
Browse files Browse the repository at this point in the history
  • Loading branch information
feifei-111 committed Oct 13, 2023
1 parent 5325894 commit 2d2afaa
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 141 deletions.
1 change: 0 additions & 1 deletion sot/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .code_status import CodeStatus # noqa: F401
from .exceptions import ( # noqa: F401
BreakGraphError,
FallbackError,
Expand Down
280 changes: 140 additions & 140 deletions tests/test_code_status.py
Original file line number Diff line number Diff line change
@@ -1,140 +1,140 @@
import unittest

from test_case_base import TestCaseBase, strict_mode_guard

import paddle
import sot
from sot.opcode_translator.skip_files import skip_function
from sot.utils.code_status import CodeState, CodeStatus


class SimpleNet1(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.layers = paddle.nn.LayerList(
[paddle.nn.Linear(10, 10) for _ in range(30)]
)

def forward(self, x):
for i in range(len(self.layers)):
sot.psdb.breakgraph()
x = self.layers[i](x)
x = self.layers[i](x)
x = self.layers[i](x)
x = self.layers[i](x)
return x


class SimpleNet2(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.layers = paddle.nn.LayerList(
[paddle.nn.Linear(10, 10) for _ in range(30)]
)

def forward(self, x):
sot.psdb.fallback()
for i in range(len(self.layers)):
x = self.layers[i](x)
x = self.layers[i](x)
x = self.layers[i](x)
x = self.layers[i](x)
return x


def run_net(net, x):
for i in range(20):
x = net(x)
return x


class TestCodeInfo(TestCaseBase):
def test_case_1(self):
CodeStatus().clear()
net = SimpleNet1()
inp = paddle.rand((10, 10))
self.assert_results(run_net, net, inp)
code_map = CodeStatus().code_map
states = []
for k, v in code_map.items():
if k.co_name.startswith("#") or k.co_name.startswith("$"):
states.append(v)
elif k in CodeStatus().WITH_GRAPH_API:
assert v.state == CodeState.WITH_GRAPH
else:
assert v.state == CodeState.WITHOUT_GRAPH
# run_net, forward, loop body, resumed part2 in loop body
assert len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 4
# resumed part1 in loop body
assert (
len([v for v in states if v.state == CodeState.WITHOUT_GRAPH]) == 1
)

def test_case_2(self):
with strict_mode_guard(0):
CodeStatus().clear()
net = SimpleNet2()
inp = paddle.rand((10, 10))
self.assert_results(run_net, net, inp)
code_map = CodeStatus().code_map
states = []
for k, v in code_map.items():
if k.co_name.startswith("#") or k.co_name.startswith("$"):
states.append(v)
elif k in CodeStatus().WITH_GRAPH_API:
assert v.state == CodeState.WITH_GRAPH
else:
assert v.state == CodeState.WITHOUT_GRAPH
# no graph found because fallback (paddle api will not enter simulate)
assert (
len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 0
)


def no_skip_func_0(x):
return x + 1


def skipped_func_0():
pass


def skipped_func_1(x):
return x + 1


def skipped_func_2(x):
return no_skip_func_0(x)


def call_skipped_func_0(x):
for i in range(15):
skipped_func_0()
x = skipped_func_1(x)
x = skipped_func_2(x)
return x


skip_function(skipped_func_0)
skip_function(skipped_func_1)
skip_function(skipped_func_2)
skip_function(call_skipped_func_0)


class TestDisableSkippedFrame(TestCaseBase):
def test_case_0(self):
CodeStatus().clear()
x = paddle.to_tensor([1])
self.assert_results(call_skipped_func_0, x)
code_map = CodeStatus().code_map
assert (
code_map[skipped_func_0.__code__].state == CodeState.WITHOUT_GRAPH
)
assert (
code_map[skipped_func_1.__code__].state == CodeState.WITHOUT_GRAPH
)
assert code_map[skipped_func_2.__code__].state == CodeState.WITH_GRAPH


if __name__ == "__main__":
unittest.main()
# import unittest

# from test_case_base import TestCaseBase, strict_mode_guard

# import paddle
# import sot
# from sot.opcode_translator.skip_files import skip_function
# from sot.utils.code_status import CodeState, CodeStatus


# class SimpleNet1(paddle.nn.Layer):
# def __init__(self):
# super().__init__()
# self.layers = paddle.nn.LayerList(
# [paddle.nn.Linear(10, 10) for _ in range(30)]
# )

# def forward(self, x):
# for i in range(len(self.layers)):
# sot.psdb.breakgraph()
# x = self.layers[i](x)
# x = self.layers[i](x)
# x = self.layers[i](x)
# x = self.layers[i](x)
# return x


# class SimpleNet2(paddle.nn.Layer):
# def __init__(self):
# super().__init__()
# self.layers = paddle.nn.LayerList(
# [paddle.nn.Linear(10, 10) for _ in range(30)]
# )

# def forward(self, x):
# sot.psdb.fallback()
# for i in range(len(self.layers)):
# x = self.layers[i](x)
# x = self.layers[i](x)
# x = self.layers[i](x)
# x = self.layers[i](x)
# return x


# def run_net(net, x):
# for i in range(20):
# x = net(x)
# return x


# class TestCodeInfo(TestCaseBase):
# def test_case_1(self):
# CodeStatus().clear()
# net = SimpleNet1()
# inp = paddle.rand((10, 10))
# self.assert_results(run_net, net, inp)
# code_map = CodeStatus().code_map
# states = []
# for k, v in code_map.items():
# if k.co_name.startswith("#") or k.co_name.startswith("$"):
# states.append(v)
# elif k in CodeStatus().WITH_GRAPH_API:
# assert v.state == CodeState.WITH_GRAPH
# else:
# assert v.state == CodeState.WITHOUT_GRAPH
# # run_net, forward, loop body, resumed part2 in loop body
# assert len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 4
# # resumed part1 in loop body
# assert (
# len([v for v in states if v.state == CodeState.WITHOUT_GRAPH]) == 1
# )

# def test_case_2(self):
# with strict_mode_guard(0):
# CodeStatus().clear()
# net = SimpleNet2()
# inp = paddle.rand((10, 10))
# self.assert_results(run_net, net, inp)
# code_map = CodeStatus().code_map
# states = []
# for k, v in code_map.items():
# if k.co_name.startswith("#") or k.co_name.startswith("$"):
# states.append(v)
# elif k in CodeStatus().WITH_GRAPH_API:
# assert v.state == CodeState.WITH_GRAPH
# else:
# assert v.state == CodeState.WITHOUT_GRAPH
# # no graph found because fallback (paddle api will not enter simulate)
# assert (
# len([v for v in states if v.state == CodeState.WITH_GRAPH]) == 0
# )


# def no_skip_func_0(x):
# return x + 1


# def skipped_func_0():
# pass


# def skipped_func_1(x):
# return x + 1


# def skipped_func_2(x):
# return no_skip_func_0(x)


# def call_skipped_func_0(x):
# for i in range(15):
# skipped_func_0()
# x = skipped_func_1(x)
# x = skipped_func_2(x)
# return x


# skip_function(skipped_func_0)
# skip_function(skipped_func_1)
# skip_function(skipped_func_2)
# skip_function(call_skipped_func_0)


# class TestDisableSkippedFrame(TestCaseBase):
# def test_case_0(self):
# CodeStatus().clear()
# x = paddle.to_tensor([1])
# self.assert_results(call_skipped_func_0, x)
# code_map = CodeStatus().code_map
# assert (
# code_map[skipped_func_0.__code__].state == CodeState.WITHOUT_GRAPH
# )
# assert (
# code_map[skipped_func_1.__code__].state == CodeState.WITHOUT_GRAPH
# )
# assert code_map[skipped_func_2.__code__].state == CodeState.WITH_GRAPH


# if __name__ == "__main__":
# unittest.main()

0 comments on commit 2d2afaa

Please sign in to comment.