-
Notifications
You must be signed in to change notification settings - Fork 109
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor dualmodule._modules #604
Conversation
之所以不复制 dict 而是直接指向 torch module 的 dict,是因为考虑到 torch module 改动可以直接同步到 dual module 这里 |
对torch module的改动没法同步到oneflow module上吧 |
如果 oneflow module 还没有创建,是可以在创建的时候同步过去的;如果已经创建了,修改 torch module 意味着计算图改变,这个时候还是得重新编译吧。 |
torch_module, oneflow_module | ||
) | ||
)) | ||
object.__setattr__(self, "_modules", torch_module._modules) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这么设置 model=oneflow_compile(model) ; model.state_dict() 是怎样的? 和以前兼容吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这么设置 model=oneflow_compile(model) ; model.state_dict() 是怎样的? 和以前兼容吗
理论上就不会有 _deployable_module_model._torch_module 这一层了,那个 hook 就不需要了。不过我还是测试一下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这么设置 model=oneflow_compile(model) ; model.state_dict() 是怎样的? 和以前兼容吗
对比了一下 key,是一样的
对于一些 module,经过 oneflow_compile 之后,丢失了 `__dict__` 中的信息。 比如 Conv2d,变成 MixedDualModule 之后,in_channels 实际上存到了 `MixedDualModule._torch_module.__dict__` 里面了。而 `MixedDualModule.extra_repr` 会在当前的 `__dict__` 里面找 in_channels 这些参数,当然找不到,从而 print 时候报错,很影响 debug。 这个 PR 把 DualModule 的 extra_repr 定向到了 torch module,这样就可以找到了。 另外在 repr 返回类名的时候,把被编译的类名也加了进去,能知道是什么 module 被编译了。 修改前后对比如下所示: before: ```python In [1]: import torch In [2]: from onediff.infer_compiler.with_oneflow_compile import oneflow_compile In [3]: c = torch.nn.Conv2d(3, 3, 3) In [4]: oneflow_compile(c) Out[4]: --------------------------------------------------------------------------- KeyError Traceback (most recent call last) File ~/miniconda3/envs/py10/lib/python3.10/site-packages/IPython/core/formatters.py:708, in PlainTextFormatter.__call__(self, obj) 701 stream = StringIO() 702 printer = pretty.RepresentationPrinter(stream, self.verbose, 703 self.max_width, self.newline, 704 max_seq_length=self.max_seq_length, 705 singleton_pprinters=self.singleton_printers, 706 type_pprinters=self.type_printers, 707 deferred_pprinters=self.deferred_printers) --> 708 printer.pretty(obj) 709 printer.flush() 710 return stream.getvalue() File ~/miniconda3/envs/py10/lib/python3.10/site-packages/IPython/lib/pretty.py:410, in RepresentationPrinter.pretty(self, obj) 407 return meth(obj, self, cycle) 408 if cls is not object \ 409 and callable(cls.__dict__.get('__repr__')): --> 410 return _repr_pprint(obj, self, cycle) 412 return _default_pprint(obj, self, cycle) 413 finally: File ~/miniconda3/envs/py10/lib/python3.10/site-packages/IPython/lib/pretty.py:778, in _repr_pprint(obj, p, cycle) 776 """A pprint that just redirects to the normal repr function.""" 777 # Find newlines and replace them with p.break_() --> 778 output = repr(obj) 779 lines = output.splitlines() 780 with p.group(): File ~/miniconda3/envs/py10/lib/python3.10/site-packages/torch/nn/modules/module.py:2378, in Module.__repr__(self) 2375 def __repr__(self): 2376 # We treat the extra repr like the sub-module, one item per line 2377 extra_lines = [] -> 2378 extra_repr = self.extra_repr() 2379 # empty string will be split into list [''] 2380 if extra_repr: File ~/miniconda3/envs/py10/lib/python3.10/site-packages/torch/nn/modules/conv.py:172, in _ConvNd.extra_repr(self) 170 if self.padding_mode != 'zeros': 171 s += ', padding_mode={padding_mode}' --> 172 return s.format(**self.__dict__) KeyError: 'in_channels' In [5]: ``` after: ```python In [1]: import torch In [2]: from onediff.infer_compiler.with_oneflow_compile import oneflow_compile In [3]: c = torch.nn.Conv2d(3, 3, 3) In [4]: oneflow_compile(c) Out[4]: MixedDeployableModule(of Conv2d)( 3, 3, kernel_size=(3, 3), stride=(1, 1) (_deployable_module_model): MixedDualModule(of Conv2d)( 3, 3, kernel_size=(3, 3), stride=(1, 1) (_torch_module): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1)) ) ) In [5]: ``` 可以看到 after 里面的结果还是有些臃肿,如果能再加上 #604 里面的修改,把 _modules 也重定向一下,结果就会变成这样: ```python In [1]: import torch In [2]: from onediff.infer_compiler.with_oneflow_compile import oneflow_compile In [3]: c = torch.nn.Conv2d(3, 3, 3) In [4]: oneflow_compile(c) Out[4]: MixedDeployableModule(of Conv2d)(3, 3, kernel_size=(3, 3), stride=(1, 1)) ``` 看起来就更简洁了。
background: https://github.com/siliconflow/sd-team/issues/191#issue-2063207008
test code:
before:
after: