fix bug of dual module (setattr, and compatible with DualModule input) #613

Merged
merged 29 commits on Feb 5, 2024
Changes from all commits
Commits
29 commits
d78c84f
support text encoder
marigoold Jan 30, 2024
4e56a68
refine code
marigoold Jan 30, 2024
46a6a3f
compatible for prev diffusers version
marigoold Jan 30, 2024
ffdf6ed
Update __init__.py
marigoold Jan 30, 2024
9877fb1
refine
marigoold Jan 30, 2024
9aa6af4
Merge branch 'dev_wy_lora_support_textencoder' of github.com:Oneflow-…
marigoold Jan 30, 2024
e252c38
update readme
marigoold Jan 31, 2024
7ede9af
refine doc
marigoold Jan 31, 2024
ef1eb8c
remove unfuse in fuse func
marigoold Jan 31, 2024
cead728
Merge branch 'main' into dev_wy_lora_support_textencoder
marigoold Jan 31, 2024
54903ce
refine
marigoold Jan 31, 2024
55c65b9
rename
marigoold Jan 31, 2024
7892a5b
remove out dated lora.py
marigoold Jan 31, 2024
93388a4
update readme
marigoold Jan 31, 2024
cb05e3e
refine
marigoold Jan 31, 2024
a964b23
Update lora.py
marigoold Jan 31, 2024
46b54c2
fix bug
marigoold Jan 31, 2024
2cf586e
Merge branch 'dev_wy_lora_support_textencoder' of github.com:Oneflow-…
marigoold Jan 31, 2024
8743725
refine
marigoold Jan 31, 2024
cf0452f
Merge branch 'main' into dev_wy_lora_support_textencoder
marigoold Feb 1, 2024
628a608
dual modulelist setattr fix bug, compatible with DualModule input
marigoold Feb 1, 2024
c773ac4
remove utils/__init__.py
marigoold Feb 2, 2024
6009d1a
modify examples
marigoold Feb 2, 2024
4a43b0b
update doc, and var name
marigoold Feb 2, 2024
57e8ca1
Merge branch 'dev_wy_lora_support_textencoder' into fix_wy_dualmodule…
marigoold Feb 2, 2024
bf57ec6
compatible for PEFT
marigoold Feb 2, 2024
0e7b951
Merge branch 'main' into fix_wy_dualmodulelist_setattr
marigoold Feb 3, 2024
e2c3991
Merge branch 'main' into fix_wy_dualmodulelist_setattr
strint Feb 4, 2024
573a1e5
Merge branch 'main' into fix_wy_dualmodulelist_setattr
strint Feb 5, 2024
1 change: 0 additions & 1 deletion onediff_diffusers_extensions/README.md
@@ -239,7 +239,6 @@ The results are shown below
| watercolor_v1_sdxl_lora.safetensors | 12M | 1.54 s | 2.01 s | **0.35 s** | |



### Note

1. OneDiff extensions for LoRA is currently not supported for PEFT, and only supports diffusers of at least version 0.21.0.
4 changes: 2 additions & 2 deletions onediff_diffusers_extensions/onediffx/lora/lora.py
@@ -172,7 +172,7 @@ def load_and_fuse_lora(
# or add_{k,v,q,out_proj}_proj_lora layers.
rank = value_dict["lora.down.weight"].shape[0]

- if isinstance(attn_processor, LoRACompatibleConv):
+ if isinstance(attn_processor, (LoRACompatibleConv, torch.nn.Conv2d)):
conv_fuse_lora(
attn_processor,
value_dict,
@@ -182,7 +182,7 @@
offload_device=offload_device,
offload_weight=offload_weight,
)
- elif isinstance(attn_processor, LoRACompatibleLinear):
+ elif isinstance(attn_processor, (LoRACompatibleLinear, torch.nn.Linear)):
linear_fuse_lora(
attn_processor,
value_dict,
14 changes: 11 additions & 3 deletions src/onediff/infer_compiler/with_oneflow_compile.py
@@ -93,9 +93,8 @@ def __getattr__(self, name):
else getattr(self._oneflow_module, name)
)
if isinstance(torch_attr, torch.nn.ModuleList):
- oneflow_attr = (
-     [None] * len(torch_attr) if oneflow_attr is None else oneflow_attr
- )
+ if oneflow_attr is None:
+     oneflow_attr = flow.nn.ModuleList([None] * len(torch_attr))
Collaborator Author

The original code here created oneflow_attr as a plain Python list. In DualModule's __setattr__ below, setattr(self._oneflow_module, key, value) then fails for key='0'. The intent here should also be to create a flow.nn.ModuleList.
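A minimal sketch of the failure mode described above, shown with torch only for illustration (the actual fix builds a flow.nn.ModuleList on the oneflow side):

```python
import torch

plain = [None]
module_list = torch.nn.ModuleList([torch.nn.Identity()])

# A plain Python list does not support attribute-style assignment, so the
# setattr(self._oneflow_module, "0", value) call in DualModuleList.__setattr__
# raises when the oneflow side is a bare list:
try:
    setattr(plain, "0", torch.nn.Linear(2, 2))
except AttributeError as e:
    print("plain list:", e)

# An nn.ModuleList is an nn.Module, so setattr with the string index "0"
# registers the layer as a submodule -- which is why the fix creates a
# ModuleList instead of a bare list:
setattr(module_list, "0", torch.nn.Linear(2, 2))
print(module_list[0])  # the Linear layer replaced the Identity at index 0
```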

return DualModuleList(torch_attr, oneflow_attr)

elif isinstance(torch_attr, torch.nn.Module):
@@ -159,15 +158,24 @@ def __setattr__(self, key, value):


def get_mixed_dual_module(module_cls):
if issubclass(module_cls, DualModule) and "MixedDualModule" in module_cls.__name__:
return module_cls

class MixedDualModule(DualModule, module_cls):
def __init__(self, torch_module, oneflow_module):
while isinstance(torch_module, DualModule):
torch_module = torch_module._torch_module
DualModule.__init__(self, torch_module, oneflow_module)

def _get_name(self) -> str:
return f"{self.__class__.__name__}(of {module_cls.__name__})"

return MixedDualModule

@torch2oflow.register
def _(mod: DualModule, verbose=False):
return torch2oflow(mod._torch_module, verbose)

Comment on lines +175 to +178
Contributor

In which case does converting a torch model run into a DualModule?

Collaborator Author

> In which case does converting a torch model run into a DualModule?

PEFT has a piece of logic that finds the Linear behind a LoRA and wraps it with its own LoRALinear.
"Finding the Linear behind the LoRA" makes DualModule's getattr return a DualModule.
"Wrapping it with its own LoRALinear" produces a torch module whose submodules contain that DualModule, so when a DualModule is created for it again, torch2oflow here ends up receiving a DualModule.
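A rough sketch of that wrapping pattern, using stand-in classes (DualModuleStub and LoRALinearWrapper are illustrative names, not PEFT's or onediff's real classes):

```python
import torch

class DualModuleStub(torch.nn.Module):
    """Stand-in for onediff's DualModule: pairs a torch module with a oneflow twin."""
    def __init__(self, torch_module):
        super().__init__()
        self._torch_module = torch_module

class LoRALinearWrapper(torch.nn.Module):
    """PEFT-style wrapper that keeps the original Linear as a submodule."""
    def __init__(self, base):
        super().__init__()
        self.base = base  # if `base` came from DualModule.__getattr__, it is a DualModule

# Step 1: "find the Linear behind the LoRA" -- getattr on the compiled model
# hands back a DualModule rather than the bare torch.nn.Linear.
base = DualModuleStub(torch.nn.Linear(4, 4))

# Step 2: "wrap it with a custom LoRALinear" -- the wrapper itself is a plain
# torch module, but one of its submodules is now a DualModule, so a later
# torch-to-oneflow conversion walks into a DualModule. The PR handles this by
# registering a torch2oflow overload for DualModule that simply unwraps
# mod._torch_module (see the diff above).
wrapped = LoRALinearWrapper(base)
print(isinstance(wrapped.base, DualModuleStub))  # True
```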


def handle_deployable_exception(func):
@wraps(func)