Refactor dualmodule._modules #604

Merged
merged 5 commits into main from dev_wy_refactor_dualmodule on Feb 1, 2024

Conversation

@marigoold (Contributor)

Background: https://github.com/siliconflow/sd-team/issues/191#issue-2063207008

Test code:

```python
import oneflow as flow
from onediff.infer_compiler.with_oneflow_compile import oneflow_compile
from onediff.infer_compiler.transform import register
import torch

class Modulelist(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.module_list = torch.nn.Sequential(
            *[torch.nn.Linear(4, 4, bias=None) for _ in range(5)]
        )

    def forward(self, x):
        return self.module_list(x)

class OfModulelist(flow.nn.Module):
    def __init__(self):
        super().__init__()
        self.module_list = flow.nn.Sequential(
            *[flow.nn.Linear(4, 4, bias=None) for _ in range(5)]
        )

    def forward(self, x):
        return self.module_list(x)

register(torch2oflow_class_map={Modulelist: OfModulelist})

model = Modulelist()
of_model = oneflow_compile(model, use_graph=False)
print("torch enumerate: ", [type(x) for _, x in enumerate(model.module_list)])
print("oneflow enumerate: ", [type(x) for _, x in enumerate(of_model.module_list)])

Before:

```
torch enumerate:  [<class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>]
oneflow enumerate:  [<class 'torch.nn.modules.container.Sequential'>]
```

After:

```
torch enumerate:  [<class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>]
oneflow enumerate:  [<class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>, <class 'torch.nn.modules.linear.Linear'>]
```

@marigoold changed the title from "Dev wy refactor dualmodule" to "Refactor dualmodule._modules" on Feb 1, 2024
@marigoold (Contributor, Author)

The reason for not copying the dict, but pointing directly at the torch module's dict, is that changes to the torch module can then be synced straight through to the dual module.

@hjchen2 (Contributor) commented Feb 1, 2024

> The reason for not copying the dict, but pointing directly at the torch module's dict, is that changes to the torch module can then be synced straight through to the dual module.

Changes to the torch module can't really be synced to the oneflow module though, can they?

@marigoold (Contributor, Author)

> > The reason for not copying the dict, but pointing directly at the torch module's dict, is that changes to the torch module can then be synced straight through to the dual module.
>
> Changes to the torch module can't really be synced to the oneflow module though, can they?

If the oneflow module has not been created yet, the changes can be synced over when it is created; if it has already been created, modifying the torch module means the computation graph has changed, so a recompile is needed anyway.
The main point of this change is to make iterating deployable_module_model.named_modules() behave the same as iterating the torch module.
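A minimal sketch of that behavior as a check, reusing model and of_model from the test snippet above (illustrative only, not a script from the PR):

```python
# Hypothetical sanity check: after the refactor, walking the inner dual module's
# submodules should yield the same names as walking the original torch module.
torch_names = [name for name, _ in model.named_modules()]
compiled_names = [name for name, _ in of_model._deployable_module_model.named_modules()]
assert torch_names == compiled_names, (torch_names, compiled_names)
```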

The review thread below is attached to this diff line:

```python
    torch_module, oneflow_module
)
))
object.__setattr__(self, "_modules", torch_module._modules)
```
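As an aside, a minimal sketch of what sharing the dict buys (class and variable names here are illustrative, not the actual onediff classes):

```python
import torch

class SharedModulesSketch(torch.nn.Module):
    """Illustrative wrapper whose _modules points at the wrapped torch module's
    _modules instead of holding a copy."""
    def __init__(self, torch_module):
        super().__init__()
        object.__setattr__(self, "_modules", torch_module._modules)

seq = torch.nn.Sequential(torch.nn.Linear(4, 4))
wrapper = SharedModulesSketch(seq)
seq.add_module("extra", torch.nn.ReLU())   # mutate the torch module afterwards
print(list(wrapper._modules.keys()))       # ['0', 'extra'] -- the change is visible
```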
@ccssu (Contributor) commented Feb 1, 2024

With this set-up, what does model = oneflow_compile(model); model.state_dict() look like? Is it compatible with before?

@marigoold (Contributor, Author)

> With this set-up, what does model = oneflow_compile(model); model.state_dict() look like? Is it compatible with before?

In theory the _deployable_module_model._torch_module level will no longer exist, so that hook won't be needed. I'll test it anyway, though.

@marigoold (Contributor, Author)

> With this set-up, what does model = oneflow_compile(model); model.state_dict() look like? Is it compatible with before?

I compared the keys; they are the same.
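One way to sanity-check that, reusing model and of_model from the test snippet above (a hedged sketch, not necessarily the exact comparison that was run):

```python
# Hypothetical check that the compiled wrapper still exposes the same
# state_dict keys as the plain torch module, so existing checkpoints load.
torch_keys = sorted(model.state_dict().keys())
compiled_keys = sorted(of_model.state_dict().keys())
assert torch_keys == compiled_keys, set(torch_keys) ^ set(compiled_keys)
```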

@marigoold mentioned this pull request on Feb 1, 2024
@marigoold merged commit 63f2810 into main on Feb 1, 2024
4 of 5 checks passed
@marigoold deleted the dev_wy_refactor_dualmodule branch on February 1, 2024 07:31
marigoold added a commit that referenced this pull request Feb 1, 2024
For some modules, information in `__dict__` is lost after oneflow_compile.
For example, once a Conv2d becomes a MixedDualModule, in_channels actually ends up stored in
`MixedDualModule._torch_module.__dict__`, while `MixedDualModule.extra_repr` looks for parameters
such as in_channels in the current `__dict__`, naturally fails to find them, and printing the module
then raises an error, which really hurts debugging.

This PR redirects DualModule's extra_repr to the torch module, so those attributes can be found again.
In addition, when repr returns the class name, the name of the compiled class is included as well, so you can tell which module was compiled.
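A rough sketch of the redirection idea (illustrative class name, not the actual onediff implementation):

```python
import torch

class DualModuleSketch(torch.nn.Module):
    """Illustrative wrapper: delegates extra_repr to the wrapped torch module and
    includes the wrapped class name in its own printed name."""
    def __init__(self, torch_module):
        super().__init__()
        self._torch_module = torch_module

    def extra_repr(self):
        # look up in_channels etc. on the torch module, where they actually live
        return self._torch_module.extra_repr()

    def _get_name(self):
        return f"{type(self).__name__}(of {type(self._torch_module).__name__})"

print(DualModuleSketch(torch.nn.Conv2d(3, 3, 3)))
# DualModuleSketch(of Conv2d)(
#   3, 3, kernel_size=(3, 3), stride=(1, 1)
#   (_torch_module): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
# )
```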
The before/after comparison is shown below:

before:
```python
In [1]: import torch

In [2]: from onediff.infer_compiler.with_oneflow_compile import oneflow_compile

In [3]: c = torch.nn.Conv2d(3, 3, 3)

In [4]: oneflow_compile(c)
Out[4]: ---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~/miniconda3/envs/py10/lib/python3.10/site-packages/IPython/core/formatters.py:708, in PlainTextFormatter.__call__(self, obj)
    701 stream = StringIO()
    702 printer = pretty.RepresentationPrinter(stream, self.verbose,
    703     self.max_width, self.newline,
    704     max_seq_length=self.max_seq_length,
    705     singleton_pprinters=self.singleton_printers,
    706     type_pprinters=self.type_printers,
    707     deferred_pprinters=self.deferred_printers)
--> 708 printer.pretty(obj)
    709 printer.flush()
    710 return stream.getvalue()

File ~/miniconda3/envs/py10/lib/python3.10/site-packages/IPython/lib/pretty.py:410, in RepresentationPrinter.pretty(self, obj)
    407                         return meth(obj, self, cycle)
    408                 if cls is not object \
    409                         and callable(cls.__dict__.get('__repr__')):
--> 410                     return _repr_pprint(obj, self, cycle)
    412     return _default_pprint(obj, self, cycle)
    413 finally:

File ~/miniconda3/envs/py10/lib/python3.10/site-packages/IPython/lib/pretty.py:778, in _repr_pprint(obj, p, cycle)
    776 """A pprint that just redirects to the normal repr function."""
    777 # Find newlines and replace them with p.break_()
--> 778 output = repr(obj)
    779 lines = output.splitlines()
    780 with p.group():

File ~/miniconda3/envs/py10/lib/python3.10/site-packages/torch/nn/modules/module.py:2378, in Module.__repr__(self)
   2375 def __repr__(self):
   2376     # We treat the extra repr like the sub-module, one item per line
   2377     extra_lines = []
-> 2378     extra_repr = self.extra_repr()
   2379     # empty string will be split into list ['']
   2380     if extra_repr:

File ~/miniconda3/envs/py10/lib/python3.10/site-packages/torch/nn/modules/conv.py:172, in _ConvNd.extra_repr(self)
    170 if self.padding_mode != 'zeros':
    171     s += ', padding_mode={padding_mode}'
--> 172 return s.format(**self.__dict__)

KeyError: 'in_channels'

In [5]:
```

after:
```python
In [1]: import torch

In [2]: from onediff.infer_compiler.with_oneflow_compile import oneflow_compile

In [3]: c = torch.nn.Conv2d(3, 3, 3)

In [4]: oneflow_compile(c)
Out[4]:
MixedDeployableModule(of Conv2d)(
  3, 3, kernel_size=(3, 3), stride=(1, 1)
  (_deployable_module_model): MixedDualModule(of Conv2d)(
    3, 3, kernel_size=(3, 3), stride=(1, 1)
    (_torch_module): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  )
)

In [5]:
```

As you can see, the "after" result is still somewhat bloated. If the change from
#604 is applied on top, redirecting _modules
as well, the result becomes:
```python
In [1]: import torch

In [2]: from onediff.infer_compiler.with_oneflow_compile import oneflow_compile

In [3]: c = torch.nn.Conv2d(3, 3, 3)

In [4]: oneflow_compile(c)
Out[4]: MixedDeployableModule(of Conv2d)(3, 3, kernel_size=(3, 3), stride=(1, 1))
```
which looks much cleaner.