support more base Instructions and support resnet #41
examples/simple_dynamo.py
Outdated
@@ -31,6 +31,7 @@ def func(a, b):
in_a = paddle.rand([3, 4])
in_b = paddle.rand([3, 4])
out = paddle.add(in_a, in_b)
# out = paddle.add(out, out)
Unused comments like this should be removed. There are several more elsewhere, along with leftover print calls.
OK, I have removed the print calls and the commented-out code.
original_res = func(in_a, in_b)
optimized_res = optimized_func(in_a, in_b)
np.testing.assert_equal(original_res.numpy(), optimized_res.numpy())
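As far as I know, this still runs the original bytecode for now, right? So the comparison here does not seem to mean much? @gglin001
Could you add a NOTE or TODO here?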
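This still runs the original bytecode for now, right? So the comparison here does not seem to mean much?
Yes. For now the comparison only confirms that the original code is returned, not the converted code produced after tracing.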
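As far as I know, this still runs the original bytecode for now, right? So the comparison here does not seem to mean much? @gglin001
Could you add a NOTE or TODO here?
I added a TODO here: # TODO(zrr1999): optimized_res is the result of running the converted bytecode in the future.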
code = frame.f_code
for paddle_module in paddle_modules:
    if package_name.startswith(paddle_module):
        return GuardedCode(code)
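Skip functions in paddle modules during _compile, so that calls such as paddle.add are supported without entering the branches that choose between dynamic-graph and static-graph execution.
Same as above: this is only needed because we can currently run only the original bytecode. Once the converted bytecode is run, these functions should never reach Eval Frame at all. Since this PR is meant to verify that the bytecode support needed for ResNet is complete, doing it this way for now is fine. This logic can probably be removed once we run the converted bytecode.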
I added a TODO here as well: This part can be removed when running the converted bytecode in the future.
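For readers following the thread, the skip check discussed here can be sketched in isolation. This is a minimal sketch, not the PR's code: `maybe_skip` and `PADDLE_MODULES` are hypothetical names, and `GuardedCode` is a stand-in that only wraps a code object. One deliberate difference from the hunk: the sketch matches on a dotted prefix, because a bare `startswith("paddle")` would also match a package named `paddlefx`. Whether that matters in the PR depends on the contents of `paddle_modules`, which the hunk does not show.

```python
from types import CodeType
from typing import NamedTuple, Optional


class GuardedCode(NamedTuple):
    # Hypothetical stand-in for the GuardedCode in the diff: wraps a code object.
    code: CodeType


# Assumed prefix list; the diff only shows that `paddle_modules` is iterable.
PADDLE_MODULES = ("paddle",)


def maybe_skip(code: CodeType, package_name: str) -> Optional[GuardedCode]:
    """Return the original code untouched if the frame belongs to a Paddle module."""
    for prefix in PADDLE_MODULES:
        # Match "paddle" or "paddle.xxx", but not "paddlefx".
        if package_name == prefix or package_name.startswith(prefix + "."):
            return GuardedCode(code)
    return None  # caller goes on to trace the frame


def sample():
    pass


skipped = maybe_skip(sample.__code__, "paddle.tensor.math")
traced = maybe_skip(sample.__code__, "paddlefx.translator")
```

Here `skipped` wraps the unchanged code object and `traced` is `None`, so only the second frame would be handed to the translator.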
src/paddlefx/translator.py
Outdated
def pop(self):
    return self.stack.pop()

def append(self, item):
Given stack semantics, would push be a better name?
The several self.stack.append calls in the code could then all be replaced with self.push.
OK, these have all been replaced.
src/paddlefx/translator.py
Outdated
def IS_OP(self, inst: Instruction):
    args = list(reversed([self.pop() for _ in range(2)]))
    res = self.output.create_node('call_function', operator.is_, args, {})
    self.stack.append(res)
Could this be folded into BINARY_MAPPER? It looks like _binary_constructor could be reused.
Also, the list(reversed([self.pop() for _ in range(n)])) logic that appears in several places could probably reuse self.popn(n)?
I replaced list(reversed([self.pop() for _ in range(n)])) with self.popn(n, reverse=True) and added IS_OP to BINARY_MAPPER.
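The two refactors agreed on in this thread can be illustrated together. This is a rough sketch under stated assumptions: the `Stack` class, the shape of `BINARY_MAPPER`, and `run_binary` are hypothetical, and the sketch evaluates operators eagerly instead of creating graph nodes the way the translator does. It also ignores the `invert` argument that CPython's `IS_OP` carries for `is not`.

```python
import operator


class Stack:
    """Hypothetical minimal operand stack; method names follow the review."""

    def __init__(self):
        self.stack = []

    def push(self, item):
        self.stack.append(item)

    def pop(self):
        return self.stack.pop()

    def popn(self, n, reverse=True):
        # Pop n items; with reverse=True they come back in the order pushed.
        items = [self.stack.pop() for _ in range(n)]
        return items[::-1] if reverse else items


# Assumed table shape: opcode name -> Python operator.
BINARY_MAPPER = {
    "BINARY_ADD": operator.add,
    "BINARY_SUBTRACT": operator.sub,
    "IS_OP": operator.is_,  # the `is not` variant is ignored in this sketch
}


def run_binary(stack, opname):
    """Shared handler: pop two operands in push order, apply, push the result."""
    lhs, rhs = stack.popn(2)
    stack.push(BINARY_MAPPER[opname](lhs, rhs))


s = Stack()
s.push(10)
s.push(3)
run_binary(s, "BINARY_SUBTRACT")
difference = s.pop()  # 7, because operands are restored to push order
```

With one shared handler, adding an opcode such as `IS_OP` is a one-line table entry instead of a new method.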
src/paddlefx/translator.py
Outdated
if k == "self":
    self.f_locals[k] = self.output._proxy_placeholder(k)
else:
    self.f_locals[k] = self.output._proxy_placeholder(k)
Hmm, is there any difference between these two branches? I can't see one :joy:
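This was originally there because of a conflict with how self is handled in _generate_forward. After many experiments I forgot to change it back. The check for self now lives directly in _generate_forward.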
Hmm, I tried running this locally, and the result of the dynamo trace looks a bit odd?
my_compiler() called with FX graph:
opcode name target args kwargs
------------- ---------- --------------------------------- ----------------------------------- --------
placeholder self self () {}
placeholder x x () {}
call_function getattr_1 <built-in function getattr> (self, 'conv1') {}
call_function getattr_2 <built-in function getattr> (getattr_1, 'forward') {}
call_function getattr_3 <built-in function getattr> (getattr_2, '__name__') {}
call_function getattr_3 Proxy(getattr_2) [Proxy(x)] {}
call_function getattr_4 <built-in function getattr> (self, 'bn1') {}
call_function getattr_5 <built-in function getattr> (getattr_4, 'forward') {}
call_function getattr_6 <built-in function getattr> (getattr_5, '__name__') {}
call_function getattr_6 Proxy(getattr_5) [getattr_3] {}
call_function getattr_7 <built-in function getattr> (self, 'relu') {}
call_function getattr_8 <built-in function getattr> (getattr_7, 'forward') {}
call_function getattr_9 <built-in function getattr> (getattr_8, '__name__') {}
call_function getattr_9 Proxy(getattr_8) [getattr_6] {}
call_function getattr_10 <built-in function getattr> (self, 'maxpool') {}
call_function getattr_11 <built-in function getattr> (getattr_10, 'forward') {}
call_function getattr_12 <built-in function getattr> (getattr_11, '__name__') {}
call_function getattr_12 Proxy(getattr_11) [getattr_9] {}
call_function getattr_13 <built-in function getattr> (self, 'layer1') {}
call_function getattr_14 <built-in function getattr> (getattr_13, 'forward') {}
call_function getattr_15 <built-in function getattr> (getattr_14, '__name__') {}
call_function getattr_15 Proxy(getattr_14) [getattr_12] {}
call_function getattr_16 <built-in function getattr> (self, 'layer2') {}
call_function getattr_17 <built-in function getattr> (getattr_16, 'forward') {}
call_function getattr_18 <built-in function getattr> (getattr_17, '__name__') {}
call_function getattr_18 Proxy(getattr_17) [getattr_15] {}
call_function getattr_19 <built-in function getattr> (self, 'layer3') {}
call_function getattr_20 <built-in function getattr> (getattr_19, 'forward') {}
call_function getattr_21 <built-in function getattr> (getattr_20, '__name__') {}
call_function getattr_21 Proxy(getattr_20) [getattr_18] {}
call_function getattr_22 <built-in function getattr> (self, 'layer4') {}
call_function getattr_23 <built-in function getattr> (getattr_22, 'forward') {}
call_function getattr_24 <built-in function getattr> (getattr_23, '__name__') {}
call_function getattr_24 Proxy(getattr_23) [getattr_21] {}
call_function getattr_25 <built-in function getattr> (self, 'avgpool') {}
call_function getattr_26 <built-in function getattr> (getattr_25, 'forward') {}
call_function getattr_27 <built-in function getattr> (getattr_26, '__name__') {}
call_function getattr_27 Proxy(getattr_26) [getattr_24] {}
call_function gt <built-in function gt> [0, Proxy(getattr_31)] {}
call_function flatten_1 <function flatten at 0x125377280> [1, getattr_27] {}
call_function getattr_28 <built-in function getattr> (self, 'fc') {}
call_function getattr_29 <built-in function getattr> (getattr_28, 'forward') {}
call_function getattr_30 <built-in function getattr> (getattr_29, '__name__') {}
call_function getattr_30 Proxy(getattr_29) [flatten_1] {}
output output output [Proxy(getattr_32), gt, getattr_30] {}
call_function getattr_31 <built-in function getattr> (self, 'num_classes') {}
call_function getattr_32 <built-in function getattr> (self, 'with_pool') {}
This seems to differ quite a bit from what a direct FX trace produces. Adding gl.forward(paddle.rand([2, 3, 224, 224])) inside my_compiler does not seem to run either?
Also, why do the args contain both tuples and lists?
src/paddlefx/translator.py
Outdated
    self.push(None)

def CALL_FUNCTION(self, inst: Instruction):
    args = [self.pop() for _ in range(inst.argval)]
Isn't this reversed? There does not seem to be any place that actually needs reverse=False.
For example, take the following function call:
import dis

def foo():
    bar(1, 2, 3)

dis.dis(foo)
The bytecode is:
24 0 LOAD_GLOBAL 0 (bar)
2 LOAD_CONST 1 (1)
4 LOAD_CONST 2 (2)
6 LOAD_CONST 3 (3)
8 CALL_FUNCTION 3
10 POP_TOP
12 LOAD_CONST 0 (None)
14 RETURN_VALUE
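The arguments 1, 2, 3 are pushed in order, so popping them one by one yields them reversed, and a reverse is required. BUILD_TUPLE and similar instructions behave the same way: values are pushed in order and must be reversed after popping.
For details, see the dis documentation and Python's ceval.c source.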
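The ordering argument can be checked with a small stand-alone emulation. This is a sketch that uses a plain Python list as the operand stack; `call_function` and `bar` are hypothetical names for illustration and are not paddlefx code.

```python
def call_function(stack, argc):
    """Emulate `CALL_FUNCTION argc` on a list used as the operand stack."""
    # Popping yields the last argument first, so reverse to restore call order.
    args = [stack.pop() for _ in range(argc)][::-1]
    fn = stack.pop()  # the callable sits below its arguments
    stack.append(fn(*args))


def bar(a, b, c):
    return (a, b, c)


# LOAD_GLOBAL bar; LOAD_CONST 1; LOAD_CONST 2; LOAD_CONST 3
stack = [bar, 1, 2, 3]
call_function(stack, 3)
result = stack.pop()  # (1, 2, 3)
```

Without the `[::-1]`, `bar` would receive `(3, 2, 1)`, which is the bug pointed out in the hunk above.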
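OK, I have changed all of these to pop in reverse consistently.
The Python codegen still seems to have some issues. I will look tomorrow and see whether I can figure it out; it is odd that there are still other nodes after output. I fixed print_table: it was emitting call_function before and now emits call_module, so it should basically match what resnet_trace shows. For args, I added code in graph that converts them to tuples.
It should work now.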
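The tuple conversion mentioned above is not shown in this thread. As a rough illustration only, one way to keep node args uniform is to normalize them at node creation. `Node` and `create_node` below are hypothetical stand-ins, not the paddlefx graph API.

```python
from typing import Any, NamedTuple


class Node(NamedTuple):
    # Hypothetical node record with the same columns as the printed table.
    op: str
    target: Any
    args: tuple
    kwargs: dict


def create_node(op, target, args=(), kwargs=None):
    """Build a node, coercing args to a tuple so lists and tuples never mix."""
    return Node(op, target, tuple(args), dict(kwargs or {}))


from_list = create_node("call_function", "flatten", [1, "getattr_27"])
from_tuple = create_node("call_function", "getattr", ("self", "fc"))
```

Both nodes end up with tuple args, so a printed table would no longer mix `[...]` and `(...)` in the args column.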
LGTM, 👍
LGTM~
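Added
paddlefx.optimize can now capture a ResNet model and convert it into an FX graph.
Skip functions in paddle modules during _compile, so that calls such as paddle.add are supported without entering the branches that choose between dynamic-graph and static-graph execution.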