Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
对于一些 module,经过 oneflow_compile 之后,丢失了 `__dict__` 中的信息。 比如 Conv2d,变成 MixedDualModule 之后,in_channels 实际上存到了 `MixedDualModule._torch_module.__dict__` 里面了。而 `MixedDualModule.extra_repr` 会在当前的 `__dict__` 里面找 in_channels 这些参数,当然找不到,从而 print 时候报错,很影响 debug。 这个 PR 把 DualModule 的 extra_repr 定向到了 torch module,这样就可以找到了。 另外在 repr 返回类名的时候,把被编译的类名也加了进去,能知道是什么 module 被编译了。 修改前后对比如下所示: before: ```python In [1]: import torch In [2]: from onediff.infer_compiler.with_oneflow_compile import oneflow_compile In [3]: c = torch.nn.Conv2d(3, 3, 3) In [4]: oneflow_compile(c) Out[4]: --------------------------------------------------------------------------- KeyError Traceback (most recent call last) File ~/miniconda3/envs/py10/lib/python3.10/site-packages/IPython/core/formatters.py:708, in PlainTextFormatter.__call__(self, obj) 701 stream = StringIO() 702 printer = pretty.RepresentationPrinter(stream, self.verbose, 703 self.max_width, self.newline, 704 max_seq_length=self.max_seq_length, 705 singleton_pprinters=self.singleton_printers, 706 type_pprinters=self.type_printers, 707 deferred_pprinters=self.deferred_printers) --> 708 printer.pretty(obj) 709 printer.flush() 710 return stream.getvalue() File ~/miniconda3/envs/py10/lib/python3.10/site-packages/IPython/lib/pretty.py:410, in RepresentationPrinter.pretty(self, obj) 407 return meth(obj, self, cycle) 408 if cls is not object \ 409 and callable(cls.__dict__.get('__repr__')): --> 410 return _repr_pprint(obj, self, cycle) 412 return _default_pprint(obj, self, cycle) 413 finally: File ~/miniconda3/envs/py10/lib/python3.10/site-packages/IPython/lib/pretty.py:778, in _repr_pprint(obj, p, cycle) 776 """A pprint that just redirects to the normal repr function.""" 777 # Find newlines and replace them with p.break_() --> 778 output = repr(obj) 779 lines = output.splitlines() 780 with p.group(): File ~/miniconda3/envs/py10/lib/python3.10/site-packages/torch/nn/modules/module.py:2378, in Module.__repr__(self) 2375 def __repr__(self): 2376 # We treat the extra repr like the sub-module, one item per line 2377 extra_lines = [] -> 2378 extra_repr = self.extra_repr() 2379 # empty string will be split into list [''] 2380 if extra_repr: File ~/miniconda3/envs/py10/lib/python3.10/site-packages/torch/nn/modules/conv.py:172, in _ConvNd.extra_repr(self) 170 if self.padding_mode != 'zeros': 171 s += ', padding_mode={padding_mode}' --> 172 return s.format(**self.__dict__) KeyError: 'in_channels' In [5]: ``` after: ```python In [1]: import torch In [2]: from onediff.infer_compiler.with_oneflow_compile import oneflow_compile In [3]: c = torch.nn.Conv2d(3, 3, 3) In [4]: oneflow_compile(c) Out[4]: MixedDeployableModule(of Conv2d)( 3, 3, kernel_size=(3, 3), stride=(1, 1) (_deployable_module_model): MixedDualModule(of Conv2d)( 3, 3, kernel_size=(3, 3), stride=(1, 1) (_torch_module): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1)) ) ) In [5]: ``` 可以看到 after 里面的结果还是有些臃肿,如果能再加上 #604 里面的修改,把 _modules 也重定向一下,结果就会变成这样: ```python In [1]: import torch In [2]: from onediff.infer_compiler.with_oneflow_compile import oneflow_compile In [3]: c = torch.nn.Conv2d(3, 3, 3) In [4]: oneflow_compile(c) Out[4]: MixedDeployableModule(of Conv2d)(3, 3, kernel_size=(3, 3), stride=(1, 1)) ``` 看起来就更简洁了。
- Loading branch information