-
Notifications
You must be signed in to change notification settings - Fork 113
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix bug of dual module (setattr, and compatible with DualModule input) #613
Changes from all commits
d78c84f
4e56a68
46a6a3f
ffdf6ed
9877fb1
9aa6af4
e252c38
7ede9af
ef1eb8c
cead728
54903ce
55c65b9
7892a5b
93388a4
cb05e3e
a964b23
46b54c2
2cf586e
8743725
cf0452f
628a608
c773ac4
6009d1a
4a43b0b
57e8ca1
bf57ec6
0e7b951
e2c3991
573a1e5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -93,9 +93,8 @@ def __getattr__(self, name): | |
else getattr(self._oneflow_module, name) | ||
) | ||
if isinstance(torch_attr, torch.nn.ModuleList): | ||
oneflow_attr = ( | ||
[None] * len(torch_attr) if oneflow_attr is None else oneflow_attr | ||
) | ||
if oneflow_attr is None: | ||
oneflow_attr = flow.nn.ModuleList([None] * len(torch_attr)) | ||
return DualModuleList(torch_attr, oneflow_attr) | ||
|
||
elif isinstance(torch_attr, torch.nn.Module): | ||
|
@@ -159,15 +158,24 @@ def __setattr__(self, key, value): | |
|
||
|
||
def get_mixed_dual_module(module_cls): | ||
if issubclass(module_cls, DualModule) and "MixedDualModule" in module_cls.__name__: | ||
return module_cls | ||
|
||
class MixedDualModule(DualModule, module_cls): | ||
def __init__(self, torch_module, oneflow_module): | ||
while isinstance(torch_module, DualModule): | ||
torch_module = torch_module._torch_module | ||
DualModule.__init__(self, torch_module, oneflow_module) | ||
|
||
def _get_name(self) -> str: | ||
return f"{self.__class__.__name__}(of {module_cls.__name__})" | ||
|
||
return MixedDualModule | ||
|
||
@torch2oflow.register | ||
def _(mod: DualModule, verbose=False): | ||
return torch2oflow(mod._torch_module, verbose) | ||
|
||
Comment on lines
+175
to
+178
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 转换 torch model 在那种情况会出现 DualModule ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
PEFT 这里有一处逻辑是,找到 LoRA 对应的 Linear,用自己定义的 LoRALinear 包一下。 |
||
|
||
def handle_deployable_exception(func): | ||
@wraps(func) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里原来的方法创建的 oneflow_attr 是个 list,在下面 DualModule 的 setattr 中,setattr(self._oneflow_module, key, value) 对于 key='0' 会挂掉。这里本意应该是也创建一个 flow.nn.ModuleList