Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support more base Instructions and support resnet #41

Merged
merged 29 commits into from
May 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ jobs:

- name: Install dependencies
run: |
pip install --upgrade pip
pip install -r requirements_dev.txt

- name: Build
Expand Down
17 changes: 8 additions & 9 deletions examples/graph_editing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,17 @@ def net(x, y):
graph = traced_layer.graph

print("Before editing:")
graph.print_tabular()
print(traced_layer.get_source())

for node in graph.nodes:
if node.op == 'call_function':
with graph.inserting_after(node):
new_node = graph.create_node(
node.op, paddle.add, args=(node.args[0], node.args[0]), kwargs={}
)
node.replace_all_uses_with(new_node)
graph.erase_node(node)
break

with graph.inserting_after(node):
new_node = graph.create_node(
node.op, paddle.add, args=(node.args[0], node.args[0]), kwargs={}
)
node.replace_all_uses_with(new_node)
graph.erase_node(node)

print("After editing:")
graph.print_tabular()
print(traced_layer.get_source())
27 changes: 27 additions & 0 deletions examples/resnet_dynamo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from __future__ import annotations

import numpy as np
import paddle
import paddle.nn
import paddle.tensor

from paddle.vision.models import resnet18

import paddlefx


def my_compiler(
    gl: paddlefx.GraphLayer,
    example_inputs: list[paddle.Tensor] | None = None,
):
    """Debug compiler callback: dump the captured graph and run it unchanged.

    Args:
        gl: The graph layer handed to the compiler; its generated source and
            its graph (in ``rich`` tabular mode) are printed.
        example_inputs: Accepted for interface compatibility; not used here.

    Returns:
        ``gl.forward``, i.e. no transformation is applied to the graph.
    """
    print("my_compiler() called with FX graph:")
    print(gl.get_source())
    gl.graph.print_tabular(print_mode="rich")
    return gl.forward


# Build a ResNet-18 and wrap it with paddlefx.optimize, passing my_compiler as
# the compiler callback.
net = resnet18()
optimized_net = paddlefx.optimize(my_compiler)(net)

# Run the plain and the wrapped network on the same random input.
x = paddle.rand([1, 3, 224, 224])
out = net(x)
res = optimized_net(x)

# The two outputs must match exactly (assert_equal applies no tolerance).
np.testing.assert_equal(res.numpy(), out.numpy())
3 changes: 1 addition & 2 deletions examples/resnet_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,5 @@
assert paddle.allclose(orig_output, traced_output)

print(f"python IR for {type(net).__name__}")
print(traced_layer.get_source())
traced_layer.graph.print_tabular(print_mode="tabulate")
traced_layer.graph.print_tabular(print_mode="rich")
traced_layer.graph.print_tabular(print_mode="raw")
22 changes: 22 additions & 0 deletions examples/simple_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,25 @@ def inplace(a, b):
optimized_res = optimized_foo(in_a, in_b)

np.testing.assert_equal(original_res.numpy(), optimized_res.numpy())


class ExampleNet(paddle.nn.Layer):
# Example layer whose forward indexes into a Python list of Linear layers
# and adds their outputs.
def __init__(self):
super().__init__()
# NOTE(review): the two Linear(1, 1) layers are kept in a plain list;
# presumably they are not registered as sub-layers this way - confirm
# whether paddle.nn.LayerList is wanted here.
self.fc = [paddle.nn.Linear(1, 1), paddle.nn.Linear(1, 1)]

def forward(self, a, b):
# Apply the first Linear to `a`, the second to `b`, then combine the two
# results with paddle.add.
c = self.fc[0](a)
d = self.fc[1](b)
e = paddle.add(c, d)
return e


# Wrap an ExampleNet instance with paddlefx.optimize (my_compiler as the
# compiler callback) and compare it with the unwrapped layer on in_a / in_b.
net = ExampleNet()

optimized_func = paddlefx.optimize(my_compiler)(net)

original_res = net(in_a, in_b)
optimized_res = optimized_func(in_a, in_b)
# TODO(zrr1999): in the future `optimized_res` should be produced by running
# the converted bytecode.
# NOTE(review): per the PR discussion the original bytecode still runs today,
# so this comparison presumably only confirms the original code path is
# returned - confirm once converted bytecode is executed.
np.testing.assert_equal(original_res.numpy(), optimized_res.numpy())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

据我所知,目前还是跑原来的字节码吧?所以这里对比貌似没啥意义?@gglin001

可以加一个 NOTE 或者 TODO 在这里~

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前还是跑原来的字节码吧?所以这里对比貌似没啥意义?

是的, 目前的对比只是确保 返回了原始的 code, 不是trace 到后转换的 code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

据我所知,目前还是跑原来的字节码吧?所以这里对比貌似没啥意义?@gglin001

可以加一个 NOTE 或者 TODO 在这里~

这块我加了一个# TODO(zrr1999): optimized_res is the result of running the converted bytecode in the future.

21 changes: 18 additions & 3 deletions src/paddlefx/eval_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

import dataclasses
import dis
import inspect
import types

from typing import Callable

import paddle
import paddle.nn

from ._eval_frame import set_eval_frame
from .translator import InstructionTranslator, convert_instruction
Expand All @@ -24,6 +26,7 @@ def __init__(self, callback):

def __enter__(self):
# Install self.callback through set_eval_frame and keep whatever it returns
# (presumably the previously installed callback) so __exit__ can hand it back.
self.old_callback = set_eval_frame(self.callback)
# Return the context object itself so `with ... as ctx` binds it.
return self

def __exit__(self, exc_type, exc_value, traceback):
# Pass the value saved by __enter__ back to set_eval_frame. The exception
# arguments are ignored and nothing is returned, so exceptions propagate.
set_eval_frame(self.old_callback)
Expand All @@ -45,7 +48,22 @@ def _compile(
frame: types.FrameType,
compiler_fn: Callable,
):
# TODO(zrr1999): This part can be removed when running the converted bytecode in the future.
paddle_modules = [
"paddle.nn",
"paddle.fluid",
"paddle.tensor",
# TODO(zrr1999): add more modules
]
module = inspect.getmodule(frame)
if module is None:
raise RuntimeError('Cannot find module for frame')
package_name = module.__name__

code = frame.f_code
for paddle_module in paddle_modules:
if package_name.startswith(paddle_module):
return GuardedCode(code)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在_compile时跳过 paddle模块里的函数,从而可以支持paddle.add 这种,而不进入执行动态图和静态图的分支

同上,这是因为目前只能跑原来的字节码,如果跑转换后的字节码理应是不会进入这些函数的 Eval Frame 里的,不过这个 PR 用于验证 ResNet 所需要的字节码的支持完备性是可以暂时这样的~在之后跑转换后的字节码时这部分逻辑应该可以删掉

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块我也新加了个TODO,This part can be removed when running the converted bytecode in the future.

instructions = list(map(convert_instruction, dis.get_instructions(code)))

tracer = InstructionTranslator(instructions, frame, compiler_fn)
Expand All @@ -64,9 +82,6 @@ def has_tensor_in_frame(frame: types.FrameType) -> bool:
if frame.f_code.co_name == 'in_dygraph_mode':
return False

# print(frame)
# print(dis.disassemble(frame.f_code))

for v in frame.f_locals.values():
# TODO: support containers
if isinstance(v, paddle.Tensor):
Expand Down
34 changes: 28 additions & 6 deletions src/paddlefx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,18 @@ def snake_case(s):


def _qualified_name(func):
if hasattr(func, 'node'):
name = func.node.name
elif hasattr(func, '__name__'):
name = func.__name__
elif hasattr(func, 'name'):
name = func.name
else:
raise NotImplementedError(f'cannot get name of {func}')

# things like getattr just appear in builtins
if getattr(builtins, func.__name__, None) is func:
return func.__name__
name = func.__name__
if getattr(builtins, name, None) is func:
return name
module = _find_module_of_method(func)
return f'{module}.{name}'

Expand All @@ -42,7 +50,10 @@ def _is_illegal_name(name: str, obj: Any) -> bool:


def _find_module_of_method(orig_method):
name = orig_method.__name__
if hasattr(orig_method, '__name__'):
name = orig_method.__name__
else:
name = orig_method.__class__.__name__
module = orig_method.__module__
if module is not None:
return module
Expand Down Expand Up @@ -138,7 +149,7 @@ def create_node(self, op, target=None, args=None, kwargs=None, name=None):
'placeholder',
'output',
)
args = () if args is None else args
args = () if args is None else tuple(args)
kwargs = {} if kwargs is None else kwargs
name = name if name is not None else self._name(target or op)
if name[0].isdigit():
Expand All @@ -161,6 +172,10 @@ def output(self, result):
def _name(self, op):
if hasattr(op, '__name__'):
op = op.__name__
if hasattr(op, 'name'):
op = op.name
if hasattr(op, 'node'):
op = op.node.name

if _is_magic(op):
op = op[2:-2]
Expand All @@ -185,6 +200,11 @@ def get_param(self, target):
def placeholder(self, name):
# Create a 'placeholder' node whose target is `name`; the node's own name is
# `name` with every '*' removed.
return self.create_node('placeholder', target=name, name=name.replace('*', ''))

def call_module(self, target, args=None, kwargs=None):
    """Add a ``call_module`` node for ``target`` to the graph.

    Args:
        target: String identifying the module to call (presumably a dotted
            attribute path); it is stored as the node's target.
        args: Positional arguments of the call, forwarded unchanged to
            ``create_node``. May be omitted.
        kwargs: Keyword arguments of the call, forwarded unchanged to
            ``create_node``. May be omitted.

    Returns:
        Whatever ``self.create_node`` returns for the new node.
    """
    # The node name is the target with every '.' replaced by '_'.
    node_name = target.replace('.', '_')
    return self.create_node('call_module', target, args, kwargs, name=node_name)

def erase_node(self, to_erase: Node) -> None:
if len(to_erase.users) > 0:
raise RuntimeError(
Expand Down Expand Up @@ -281,7 +301,9 @@ def print_tabular(self, print_mode="tabulate"):
"""Prints the intermediate representation of the graph in tabular
format.

Note that this API requires the ``tabulate`` module to be installed.
Note that this API lets the caller choose between the ``raw``,
``tabulate`` and ``rich`` modes. If the package required by the chosen
mode is not installed, the API automatically falls back to ``raw`` mode.
"""
assert print_mode in ["raw", "tabulate", "rich"]
if print_mode == "raw":
Expand Down
9 changes: 8 additions & 1 deletion src/paddlefx/graph_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,11 @@ def __init__(self, root, graph: Graph):

def _generate_forward(self):
body, free_variables = self.graph.python_code(root_module='self')
if "self" not in free_variables:
free_variables.insert(0, "self")
body = '\n'.join(' ' + line for line in body.split('\n')) + '\n'
self.src = f"""\
def forward(self, {', '.join(free_variables)}):
def forward({', '.join(free_variables)}):
self = self.root
{body}
"""
Expand All @@ -82,6 +84,11 @@ def forward(self, {', '.join(free_variables)}):
for k, v in gbls.items():
setattr(cls, k, v)

def get_source(self, update: bool = True) -> str:
    """Return the source text of the generated ``forward`` function.

    Args:
        update: When true (the default), regenerate the source via
            ``_generate_forward`` first so that it reflects the current
            graph. When false, return the stored text.

    Returns:
        The contents of ``self.src``.
    """
    # Generate on demand as well: with update=False and no source stored yet,
    # reading self.src would otherwise raise AttributeError.
    if update or not hasattr(self, 'src'):
        self._generate_forward()
    return self.src


# copy an attribute value with qualified name 'target' from 'from_module' to 'to_module'
# This installs empty Modules where none exist yet if they are subpaths of target
Expand Down
7 changes: 5 additions & 2 deletions src/paddlefx/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, node: Node, tracer: Tracer):
self.tracer = tracer

def __repr__(self):
return f'Proxy({self.node.name})'
return f'{self.node.name}'

def __getattr__(self, k):
# note: not added to the graph yet, if this is a method call
Expand All @@ -45,7 +45,7 @@ def __iter__(self):
if current_instruction.opname == "UNPACK_SEQUENCE":
return (self[i] for i in range(current_instruction.argval))
elif current_instruction.opname == "GET_ITER":
raise NotImplementedError()
return (self[i] for i in range(current_instruction.argval))
raise ValueError("Cannot find UNPACK_SEQUENCE instruction")


Expand All @@ -66,6 +66,9 @@ def node(self):
).node
return self._node

def __str__(self):
# Render as "<root>.<node name>".
# NOTE(review): `self.node` appears to be produced lazily by the accessor
# defined just before this method, so formatting may create the node as a
# side effect - confirm.
return f'{self.root}.{self.node.name}'

def __call__(self, *args, **kwargs):
return _create_proxy(
self.tracer, 'call_method', self.attr, (self.root,) + args, kwargs
Expand Down
2 changes: 1 addition & 1 deletion src/paddlefx/symbolic_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def _proxy_placeholder(self, name):
n = self.graph.create_node('placeholder', name, (), {})
return Proxy(n, self)

def create_node(self, op, target, args, kwargs, name=None):
def create_node(self, op, target, args=None, kwargs=None, name=None):
return self.graph.create_node(op, target, args, kwargs, name)

def create_arg(self, a):
Expand Down
Loading