Skip to content

Commit

Permalink
add repr of dualmodule (#610)
Browse files Browse the repository at this point in the history
对于一些 module,经过 oneflow_compile 之后,丢失了 `__dict__` 中的信息。
比如 Conv2d,变成 MixedDualModule 之后,in_channels 实际上存到了
`MixedDualModule._torch_module.__dict__` 里面了。而
`MixedDualModule.extra_repr` 会在当前的 `__dict__` 里面找 in_channels
这些参数,当然找不到,从而 print 时候报错,很影响 debug。

这个 PR 把 DualModule 的 extra_repr 定向到了 torch module,这样就可以找到了。
另外在 repr 返回类名的时候,把被编译的类名也加了进去,能知道是什么 module 被编译了。
修改前后对比如下所示:

before:
```python
In [1]: import torch

In [2]: from onediff.infer_compiler.with_oneflow_compile import oneflow_compile

In [3]: c = torch.nn.Conv2d(3, 3, 3)

In [4]: oneflow_compile(c)
Out[4]: ---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~/miniconda3/envs/py10/lib/python3.10/site-packages/IPython/core/formatters.py:708, in PlainTextFormatter.__call__(self, obj)
    701 stream = StringIO()
    702 printer = pretty.RepresentationPrinter(stream, self.verbose,
    703     self.max_width, self.newline,
    704     max_seq_length=self.max_seq_length,
    705     singleton_pprinters=self.singleton_printers,
    706     type_pprinters=self.type_printers,
    707     deferred_pprinters=self.deferred_printers)
--> 708 printer.pretty(obj)
    709 printer.flush()
    710 return stream.getvalue()

File ~/miniconda3/envs/py10/lib/python3.10/site-packages/IPython/lib/pretty.py:410, in RepresentationPrinter.pretty(self, obj)
    407                         return meth(obj, self, cycle)
    408                 if cls is not object \
    409                         and callable(cls.__dict__.get('__repr__')):
--> 410                     return _repr_pprint(obj, self, cycle)
    412     return _default_pprint(obj, self, cycle)
    413 finally:

File ~/miniconda3/envs/py10/lib/python3.10/site-packages/IPython/lib/pretty.py:778, in _repr_pprint(obj, p, cycle)
    776 """A pprint that just redirects to the normal repr function."""
    777 # Find newlines and replace them with p.break_()
--> 778 output = repr(obj)
    779 lines = output.splitlines()
    780 with p.group():

File ~/miniconda3/envs/py10/lib/python3.10/site-packages/torch/nn/modules/module.py:2378, in Module.__repr__(self)
   2375 def __repr__(self):
   2376     # We treat the extra repr like the sub-module, one item per line
   2377     extra_lines = []
-> 2378     extra_repr = self.extra_repr()
   2379     # empty string will be split into list ['']
   2380     if extra_repr:

File ~/miniconda3/envs/py10/lib/python3.10/site-packages/torch/nn/modules/conv.py:172, in _ConvNd.extra_repr(self)
    170 if self.padding_mode != 'zeros':
    171     s += ', padding_mode={padding_mode}'
--> 172 return s.format(**self.__dict__)

KeyError: 'in_channels'

In [5]:
```

after:
```python
In [1]: import torch

In [2]: from onediff.infer_compiler.with_oneflow_compile import oneflow_compile

In [3]: c = torch.nn.Conv2d(3, 3, 3)

In [4]: oneflow_compile(c)
Out[4]:
MixedDeployableModule(of Conv2d)(
  3, 3, kernel_size=(3, 3), stride=(1, 1)
  (_deployable_module_model): MixedDualModule(of Conv2d)(
    3, 3, kernel_size=(3, 3), stride=(1, 1)
    (_torch_module): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  )
)

In [5]:
```

可以看到 after 里面的结果还是有些臃肿,如果能再加上
#604 里面的修改,把 _modules
也重定向一下,结果就会变成这样:
```python
In [1]: import torch

In [2]: from onediff.infer_compiler.with_oneflow_compile import oneflow_compile

In [3]: c = torch.nn.Conv2d(3, 3, 3)

In [4]: oneflow_compile(c)
Out[4]: MixedDeployableModule(of Conv2d)(3, 3, kernel_size=(3, 3), stride=(1, 1))
```
看起来就更简洁了。
  • Loading branch information
marigoold authored Feb 1, 2024
1 parent 63f2810 commit d53d4ef
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/onediff/infer_compiler/with_oneflow_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ def __setattr__(self, name: str, value: Any) -> None:
setattr(self._oneflow_module, name, v)
setattr(self._torch_module, name, value)

def extra_repr(self) -> str:
return self._torch_module.extra_repr()


class DualModuleList(torch.nn.ModuleList):
def __init__(self, torch_modules, oneflow_modules):
Expand Down Expand Up @@ -158,6 +161,9 @@ class MixedDualModule(DualModule, module_cls):
def __init__(self, torch_module, oneflow_module):
DualModule.__init__(self, torch_module, oneflow_module)

def _get_name(self) -> str:
return f"{self.__class__.__name__}(of {module_cls.__name__})"

return MixedDualModule


Expand Down Expand Up @@ -314,6 +320,9 @@ def load_graph(self, file_path, device=None, run_warmup=True):
def save_graph(self, file_path):
self.get_graph().save_graph(file_path)

def extra_repr(self) -> str:
return self._deployable_module_model.extra_repr()


class OneflowGraph(flow.nn.Graph):
@flow.nn.Graph.with_dynamic_input_shape()
Expand Down Expand Up @@ -382,6 +391,9 @@ def from_existing(
)
return instance

def _get_name(self):
return f"{self.__class__.__name__}(of {module_cls.__name__})"

return MixedDeployableModule


Expand Down

0 comments on commit d53d4ef

Please sign in to comment.