Official LoRA compatible with OneDiff #507
Conversation
@@ -24,6 +24,10 @@ It is recommended to create a Python virtual environment in advance. For example
git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
git clone https://github.com/siliconflow/onediff.git
cp -r onediff/onediff_sd_webui_extensions stable-diffusion-webui/extensions/

# Enable LoRA compatible with OneDiff (Ignore this line if you do not need LoRA compatible with OneDiff)
cp onediff/onediff_sd_webui_extensions/_lora_patcher/extra_network_lora.py stable-diffusion-webui/extensions-builtin/Lora
Looks like you could use:
model_management_hijacker.register(
Just hijack ExtraNetworkLora.activate.
from onediff.infer_compiler.with_oneflow_compile import DeployableModule


class HijackLoraActivate:
This context manager also looks fine.
Did you decide not to use that hijack function here because it felt too cumbersome?
> This context manager also looks fine.
> Did you decide not to use that hijack function here because it felt too cumbersome?
There were two considerations.
First, permanently replacing the function with Hijack changes the function's metadata, which is unfriendly for debugging; I personally don't like that kind of operation.
Second, with CondFunc I wasn't sure what the cond should be based on.
The context manager here acts directly around the forward pass and restores the original as soon as the forward finishes, so the side effects are minimal, and it is also consistent with the other hijack-like operations (replacing the SD model).
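For reference, a minimal sketch of what such a context manager can look like (consistent with the diff fragments in this thread but not the exact PR code; `hijacked_activate` is the wrapper introduced elsewhere in this PR):

```python
class HijackLoraActivate:
    # A context manager keeps the patch scoped: activate is swapped on enter
    # and restored on exit, so no permanent monkey patch remains.
    def __init__(self, lora_class=None):
        self.lora_class = lora_class  # sd-webui's ExtraNetworkLora class, or None
        self.orig_func = None

    def __enter__(self):
        if self.lora_class is not None:
            self.orig_func = self.lora_class.activate
            self.lora_class.activate = hijacked_activate(self.lora_class.activate)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.lora_class is not None and self.orig_func is not None:
            # Restore the original activate once the forward pass is done.
            self.lora_class.activate = self.orig_func
            self.orig_func = None
```

Wrapping the forward call in `with HijackLoraActivate(...)` keeps the patch alive only for the duration of one generation, which is the "minimal side effect" behavior described above.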
        return activate_func

    def activate(self, p, params_list):
        activate_func(self, p, params_list)
Is this really the right order? Shouldn't network_apply_weights come first and activate_func after?
> Is this really the right order? Shouldn't network_apply_weights come first and activate_func after?
Yes, the order is correct. activate_func is ExtraNetworkLora's native activate, which is where the LoRA models are loaded, so the weights should be applied only after activate_func.
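In sketch form, the ordering being discussed looks like this (`apply_lora_weights_to_compiled_unet` is a hypothetical helper name, not a function from this PR; a possible body for it is sketched further below):

```python
def hijacked_activate(activate_func):
    # Wrap ExtraNetworkLora.activate so the LoRA load happens first and the
    # manual weight fusion happens second.
    def activate(self, p, params_list):
        activate_func(self, p, params_list)              # native activate: loads the LoRA networks
        apply_lora_weights_to_compiled_unet(p.sd_model)  # hypothetical helper; fuses weights afterwards
    return activate
```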
        if self.lora_class is None:
            return
        self.orig_func = self.lora_class.activate
        self.lora_class.activate = hijacked_activate(self.lora_class.activate)
Also, doesn't deactivate need to be hijacked as well?
> Also, doesn't deactivate need to be hijacked as well?
It isn't needed here. I hijacked it at first only because I didn't understand the code well enough and assumed that unloading also had to be done manually. In fact, unloading a LoRA model is just another form of loading, equivalent to loading an empty LoRA, and it is likewise handled in apply weights. The native deactivate only clears the LoRA-related errors.
This still needs to be added.
Updated on Jan 13, 2024. Device: RTX 3090. Resolution: 1024x1024

|          | torch(Baseline) | TensorRT-v9.0.1 | onediff(Optimized) | Percentage improvement |
| -------- | --------------- | --------------- | ------------------ | ---------------------- |
| w/o LoRA | 2.99it/s        | 6.40it/s        | 7.10it/s           | 237.46%                |
Maybe write this as SDXL?
|          | torch(Baseline) | TensorRT-v9.0.1 | onediff(Optimized) | Percentage improvement |
| -------- | --------------- | --------------- | ------------------ | ---------------------- |
| w/o LoRA | 2.99it/s        | 6.40it/s        | 7.10it/s           | 237.46%                |
| w/ LoRA  | 2.95it/s        | N/A             | 7.09it/s           | 240.34%                |
Write this as SDXL with LoRA.
How does sd-webui implement LoRA?
sd-webui's official LoRA implementation inserts an apply_weight function into the forward of the corresponding modules (e.g. nn.Linear, nn.Conv2d); inside that function, the LoRA weights are added onto the module's weight.
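Schematically, that pattern looks roughly like the following (a simplified illustration, not sd-webui's actual code; the `loras` attribute and the fusion details are placeholders):

```python
import torch

def apply_weight(module: torch.nn.Module):
    # Illustrative only: fuse each LoRA delta (scale * up @ down) into the
    # module's weight in place. The real extension also keeps weight backups
    # and tracks which networks are currently loaded.
    for lora_up, lora_down, scale in getattr(module, "loras", []):  # "loras" is a placeholder attribute
        module.weight.data += scale * (lora_up @ lora_down).reshape(module.weight.shape)

# The official extension swaps the torch module's forward for a wrapper like
# this, so the LoRA weights are applied before every torch forward call.
orig_linear_forward = torch.nn.Linear.forward

def lora_linear_forward(self, x):
    apply_weight(self)
    return orig_linear_forward(self, x)

torch.nn.Linear.forward = lora_linear_forward
```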
Why doesn't this approach work with OneDiff?
Because this approach modifies the forward of the torch module, it can only affect torch modules and has no effect on OneFlow modules.
How was it modified to make it work with OneDiff?
LoRA's activate is hijacked and the unet is checked: if it is a unet compiled by OneDiff, the LoRA weights are applied to it manually at that point. The functions that modify and restore the weights all come from the apply_weight-related functions mentioned above, with no changes made to them.
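A minimal sketch of such a helper, assuming sd-webui's `networks.network_apply_weights` and the `DeployableModule` type imported earlier in this PR (the module traversal and attribute paths are assumptions for illustration):

```python
import torch
import networks  # sd-webui's extensions-builtin/Lora/networks.py
from onediff.infer_compiler.with_oneflow_compile import DeployableModule

def apply_lora_weights_to_compiled_unet(sd_model):
    # Hypothetical helper matching the activate sketch above.
    unet = sd_model.model.diffusion_model
    if not isinstance(unet, DeployableModule):
        return  # a plain torch unet is already handled by the official forward patching
    # For a OneDiff-compiled unet, reuse the unmodified apply_weight logic and
    # fuse the loaded LoRA weights into each eligible submodule manually.
    for sub_module in unet.modules():  # assumes DeployableModule proxies .modules()
        if isinstance(sub_module, (torch.nn.Linear, torch.nn.Conv2d)):
            networks.network_apply_weights(sub_module)
```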
On the time cost of switching LoRA
Inside activate, cost_cnt is used to measure the overhead of manually applying the weights. Since LoRA fusing is essentially just matrix multiplications and additions, the time cost depends on how many LoRAs are being loaded; only loading a single LoRA was tested here. The time is roughly 700 ms per LoRA (oddly, switching LoRA is very slow for the first few inference runs and then stabilizes at around 700 ms; this timing does not include reading the LoRA from disk).
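For reference, the same kind of measurement can be reproduced with a plain timer around the manual apply-weights call (a generic sketch; `cost_cnt` is onediff's own utility and its exact interface is not reproduced here):

```python
import time
from functools import wraps

def timed(fn):
    # Generic timing wrapper; prints wall-clock time per call in milliseconds.
    @wraps(fn)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        print(f"{fn.__name__} took {(time.perf_counter() - start) * 1000:.1f} ms")
        return result
    return wrapper

# Example: time the manual weight-fusion step described above.
# apply_lora_weights_to_compiled_unet = timed(apply_lora_weights_to_compiled_unet)
```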