-
Notifications
You must be signed in to change notification settings - Fork 15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
support more base Instructions and support resnet #41
Changes from all commits
ce34e9b
2ac5907
9504f12
8d22965
f082955
6f806e5
5f21cb1
5598d96
590ebae
75b8aac
b561392
c7283d9
dca017f
c0c7606
b0c3759
245ce49
e584690
76c9739
f4c3e1e
53bde89
4e4ff49
c644f49
0e2c2e6
d92f263
ae09773
385ef3b
f838b77
71a76e0
fbcb26b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from __future__ import annotations | ||
|
||
import numpy as np | ||
import paddle | ||
import paddle.nn | ||
import paddle.tensor | ||
|
||
from paddle.vision.models import resnet18 | ||
|
||
import paddlefx | ||
|
||
|
||
def my_compiler(gl: paddlefx.GraphLayer, example_inputs: list[paddle.Tensor] = None): | ||
print("my_compiler() called with FX graph:") | ||
print(gl.get_source()) | ||
gl.graph.print_tabular(print_mode="rich") | ||
return gl.forward | ||
|
||
|
||
net = resnet18() | ||
optimized_net = paddlefx.optimize(my_compiler)(net) | ||
|
||
x = paddle.rand([1, 3, 224, 224]) | ||
out = net(x) | ||
res = optimized_net(x) | ||
|
||
np.testing.assert_equal(res.numpy(), out.numpy()) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,11 +2,13 @@ | |
|
||
import dataclasses | ||
import dis | ||
import inspect | ||
import types | ||
|
||
from typing import Callable | ||
|
||
import paddle | ||
import paddle.nn | ||
|
||
from ._eval_frame import set_eval_frame | ||
from .translator import InstructionTranslator, convert_instruction | ||
|
@@ -24,6 +26,7 @@ def __init__(self, callback): | |
|
||
def __enter__(self): | ||
self.old_callback = set_eval_frame(self.callback) | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_value, traceback): | ||
set_eval_frame(self.old_callback) | ||
|
@@ -45,7 +48,22 @@ def _compile( | |
frame: types.FrameType, | ||
compiler_fn: Callable, | ||
): | ||
# TODO(zrr1999): This part can be removed when running the converted bytecode in the future. | ||
paddle_modules = [ | ||
"paddle.nn", | ||
"paddle.fluid", | ||
"paddle.tensor", | ||
# TODO(zrr1999): add more modules | ||
] | ||
module = inspect.getmodule(frame) | ||
if module is None: | ||
raise RuntimeError('Cannot find module for frame') | ||
package_name = module.__name__ | ||
|
||
code = frame.f_code | ||
for paddle_module in paddle_modules: | ||
if package_name.startswith(paddle_module): | ||
return GuardedCode(code) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
同上,这是因为目前只能跑原来的字节码,如果跑转换后的字节码理应是不会进入这些函数的 Eval Frame 里的,不过这个 PR 用于验证 ResNet 所需要的字节码的支持完备性是可以暂时这样的~在之后跑转换后的字节码时这部分逻辑应该可以删掉 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这块我也新加了个TODO,This part can be removed when running the converted bytecode in the future. |
||
instructions = list(map(convert_instruction, dis.get_instructions(code))) | ||
|
||
tracer = InstructionTranslator(instructions, frame, compiler_fn) | ||
|
@@ -64,9 +82,6 @@ def has_tensor_in_frame(frame: types.FrameType) -> bool: | |
if frame.f_code.co_name == 'in_dygraph_mode': | ||
return False | ||
|
||
# print(frame) | ||
# print(dis.disassemble(frame.f_code)) | ||
|
||
for v in frame.f_locals.values(): | ||
# TODO: supprt containers | ||
if isinstance(v, paddle.Tensor): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
据我所知,目前还是跑原来的字节码吧?所以这里对比貌似没啥意义?@gglin001
可以加一个 NOTE 或者 TODO 在这里~
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是的, 目前的对比只是确保 返回了原始的 code, 不是trace 到后转换的 code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这块我加了一个# TODO(zrr1999):
optimized_res
is the result of running the converted bytecode in the future.