
Official LoRA compatible with OneDiff #507

Merged: 25 commits into main from dev_wy_lora_compatible_for_onediff, Jan 13, 2024

Conversation

@marigoold (Collaborator) commented Jan 9, 2024

  • How does sd-webui implement LoRA?
    sd-webui's official LoRA implementation inserts an apply_weight function into the forward of the corresponding module (e.g. nn.Linear, nn.Conv2d); that function adds the LoRA weights onto weight (see the sketch after this list).

  • Why is this approach incompatible with OneDiff?
    Because it modifies the torch module's forward, it can only affect torch modules; it cannot affect OneFlow modules.

  • How was it adapted for OneDiff?
    Hijack LoRA's activate and check the unet: if it is a unet compiled by OneDiff, manually apply the LoRA weights there. The functions that modify and restore the weights come from the apply_weight-related functions mentioned above, used without any changes.

  • On the time overhead of switching LoRA
    In activate, cost_cnt is used to measure the cost of manually applying the weights. Since LoRA fusing is essentially just matrix multiplication and addition, the overhead scales with the number of LoRAs to load; only the single-LoRA case is measured here. The time is roughly 700 ms per LoRA (oddly, switching LoRA is much slower in the first few inferences, then stabilizes at around 700 ms; this timing excludes reading the LoRA from disk).

@marigoold marked this pull request as ready for review on January 9, 2024 14:27
@marigoold requested review from doombeaker and strint on January 9, 2024 14:28
@strint (Collaborator) commented Jan 10, 2024

```
@@ -24,6 +24,10 @@ It is recommended to create a Python virtual environment in advance. For example
git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
git clone https://github.com/siliconflow/onediff.git
cp -r onediff/onediff_sd_webui_extensions stable-diffusion-webui/extensions/

# Enable LoRA compatible with OneDiff (ignore this line if you do not need LoRA compatible with OneDiff)
cp onediff/onediff_sd_webui_extensions/_lora_patcher/extra_network_lora.py stable-diffusion-webui/extensions-builtin/Lora
```

Hijack ExtraNetworkLora.activate:

```python
from onediff.infer_compiler.with_oneflow_compile import DeployableModule


class HijackLoraActivate:
```
Collaborator:

This context manager also looks fine.

Was the reason you didn't use that hijack function here that it felt too cumbersome?

marigoold (Author):

> This context manager also looks fine.
>
> Was the reason you didn't use that hijack function here that it felt too cumbersome?

Two considerations:
1. Permanently replacing a function via Hijack changes the function's meta information, which is unfriendly for debugging; I personally dislike that kind of operation.
2. With CondFunc, I wasn't sure what the cond should be keyed on.
Using a context manager here acts directly around the forward pass and restores everything as soon as the forward finishes, so the side effects are minimal, and it is also consistent with other hijack-style operations (such as swapping the SD model). A sketch of the pattern follows.
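For illustration, a minimal sketch of the context-manager pattern under discussion. The __enter__ swap mirrors the diff excerpt further down in this thread; the __exit__ restore logic is an assumption that it simply reverses the swap, and hijacked_activate is the wrapper discussed below:

```python
class HijackLoraActivate:
    """Temporarily replace lora_class.activate for the duration of a block."""

    def __init__(self, lora_class):
        self.lora_class = lora_class
        self.orig_func = None

    def __enter__(self):
        if self.lora_class is None:
            return self
        # Swap in the wrapper only while the `with` block runs.
        self.orig_func = self.lora_class.activate
        self.lora_class.activate = hijacked_activate(self.lora_class.activate)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.orig_func is not None:
            # Put the original activate back, so nothing leaks outside the block.
            self.lora_class.activate = self.orig_func
            self.orig_func = None
```

Used as `with HijackLoraActivate(ExtraNetworkLora): ...` around the forward pass, the patched activate exists only while the block runs, which is the "minimal side effect" property argued for above.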

```python
    return activate_func

def activate(self, p, params_list):
    activate_func(self, p, params_list)
```
Collaborator:

Is this order right? Shouldn't network_apply_weights come first, and activate_func after?

marigoold (Author):

> Is this order right? Shouldn't network_apply_weights come first, and activate_func after?

Yes, this order is correct. activate_func is ExtraNetworkLora's original activate, which loads the LoRA model, so the weights should be applied only after activate_func has run. A sketch of the resulting order is below.
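A sketch of what that ordering implies. manually_apply_lora_weights is a hypothetical stand-in for the PR's reuse of sd-webui's apply_weight helpers, and the early-return guard is only a guess at why the excerpt above returns activate_func unchanged:

```python
from onediff.infer_compiler.with_oneflow_compile import DeployableModule

def hijacked_activate(activate_func):
    if getattr(activate_func, "_onediff_hijacked", False):
        return activate_func  # already wrapped, nothing to do

    def activate(self, p, params_list):
        # 1) The original activate runs first: it loads the LoRA networks.
        activate_func(self, p, params_list)
        # 2) Only once they are loaded can the weights be fused into the
        #    OneDiff-compiled unet (attribute path as in sd-webui).
        unet = p.sd_model.model.diffusion_model
        if isinstance(unet, DeployableModule):
            manually_apply_lora_weights(unet)  # hypothetical helper

    activate._onediff_hijacked = True
    return activate
```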

```python
if self.lora_class is None:
    return
self.orig_func = self.lora_class.activate
self.lora_class.activate = hijacked_activate(self.lora_class.activate)
```
Collaborator:

Also, don't we need to hijack deactivate as well?

marigoold (Author):

> Also, don't we need to hijack deactivate as well?

Not needed here. I originally hijacked it because I misunderstood the code and assumed unloading also had to be done manually. In fact, unloading a LoRA model is just another form of loading: it is equivalent to loading an empty LoRA, and it is likewise handled inside apply weights. The original deactivate only clears LoRA-related errors. The sketch below illustrates why.
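A sketch of why that works, assuming the backup-then-reapply pattern sd-webui uses; the attribute and parameter names here are illustrative, not the real identifiers:

```python
import torch

def apply_weights(module: torch.nn.Linear, loras) -> None:
    # Back up the pristine weight the first time any LoRA touches this module
    # (sd-webui keeps a similar per-module backup; the attribute name is assumed).
    backup = getattr(module, "weights_backup", None)
    if backup is None:
        backup = module.weight.detach().clone()
        module.weights_backup = backup
    with torch.no_grad():
        # Always start from the backup, then add whatever LoRAs are requested.
        module.weight.copy_(backup)
        for lora_down, lora_up, alpha in loras:
            module.weight += alpha * (lora_up @ lora_down)
    # With an empty `loras` list this degenerates to a pure restore:
    # "unloading" really is loading an empty LoRA.
```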

@strint (Collaborator) commented Jan 13, 2024

Still to be added:

  • performance numbers;
  • README.

Updated on Jan 13, 2024. Device: RTX 3090. Resolution: 1024x1024

|          | torch (Baseline) | TensorRT-v9.0.1 | onediff (Optimized) | Percentage improvement |
| -------- | ---------------- | --------------- | ------------------- | ---------------------- |
| w/o LoRA | 2.99 it/s        | 6.40 it/s       | 7.10 it/s           | 237.46%                |

(The last column is onediff relative to the torch baseline, e.g. 7.10 / 2.99 ≈ 237.46%.)
Collaborator:

How about labeling this as SDXL?

|          | torch (Baseline) | TensorRT-v9.0.1 | onediff (Optimized) | Percentage improvement |
| -------- | ---------------- | --------------- | ------------------- | ---------------------- |
| w/o LoRA | 2.99 it/s        | 6.40 it/s       | 7.10 it/s           | 237.46%                |
| w/ LoRA  | 2.95 it/s        | N/A             | 7.09 it/s           | 240.34%                |
Collaborator:

Label it as "SDXL with LoRA".

@strint merged commit 57489bb into main on Jan 13, 2024 (2 of 4 checks passed)
@strint deleted the dev_wy_lora_compatible_for_onediff branch on January 13, 2024 14:44