reuse graph with constant folding for sd webui #782

Merged (5 commits) on Apr 7, 2024
13 changes: 4 additions & 9 deletions onediff_sd_webui_extensions/README.md
@@ -65,28 +65,23 @@ Select `onediff_diffusion_model` from the Script menu, enter a prompt in the tex

When switching models, if the new model has the same structure as the old model, OneDiff will reuse the previously compiled graph, which means you don't need to compile the new model again, which significantly reduces the time it takes you to switch models.

> Note: Please make sure that your PyTorch version is at least 2.1.0, and set the environment variable `ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION` to 0 when starting the sd-webui service. And the feature is not supported for quantized model.
> Note: Please make sure that your PyTorch version is at least 2.1.0. This feature is not supported for quantized models.

### LoRA

OneDiff supports the complete functionality related to LoRA. You can use OneDiff-based LoRA just like the native LoRA in sd-webui.

FAQ:


1. Does OneDiff support model types other than LoRA, such as LyCORIS?

If your LoRA model only contains the weights of the Linear module, you can directly use OneDiff without any modifications. But if your LoRA model includes the weights of the Conv module (such as LyCORIS), you need to disable constant folding optimization by setting the env var `ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION` to 0 (which may cause a performance drop of around 4.4%); otherwise, the weights of the Conv module may not be loaded into the model (see the sketch after this FAQ).

2. After switching LoRA, should I recompile the model?
1. After switching LoRA, should I recompile the model?

OneDiff supports dynamically switching LoRA without recompiling the model, because the model with LoRA and the one without LoRA share the same parameter pointers, which have already been captured by the static graph.

3. What's the time cost of LoRA fusing?
2. What's the time cost of LoRA fusing?

The first few LoRA fusions may take a bit of time (1~2s), but once stabilized, the time cost is ~700ms.

4. Will LoRA fusing affect the inference efficiency of the model?
3. Will LoRA fusing affect the inference efficiency of the model?

No, the model's inference efficiency remains the same after fusing LoRA as it was before fusing LoRA.
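A minimal sketch of the setting described in FAQ 1, assuming it is applied in the process that starts the sd-webui service before OneFlow is initialized (the equivalent shell form is exporting the variable before launch); this wrapper is an illustration, not something this PR ships:

```python
import os

# Disable OneFlow constant folding so Conv weights from LyCORIS-style LoRA
# models are loaded into the compiled graph (see FAQ 1 above); the FAQ notes
# this may cost roughly 4.4% performance.
os.environ["ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION"] = "0"
```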

19 changes: 15 additions & 4 deletions onediff_sd_webui_extensions/onediff_lora.py
@@ -1,10 +1,12 @@
import torch
import oneflow as flow
from onediff.infer_compiler.with_oneflow_compile import DeployableModule


class HijackLoraActivate:
def __init__(self):
def __init__(self, conv_dict=None):
from modules import extra_networks
self.conv_dict = conv_dict

if "lora" in extra_networks.extra_network_registry:
cls_extra_network_lora = type(extra_networks.extra_network_registry["lora"])
@@ -16,7 +18,7 @@ def __enter__(self):
if self.lora_class is None:
return
self.orig_func = self.lora_class.activate
self.lora_class.activate = hijacked_activate(self.lora_class.activate)
self.lora_class.activate = hijacked_activate(self.lora_class.activate, conv_dict=self.conv_dict)

def __exit__(self, exc_type, exc_val, exc_tb):
if self.lora_class is None:
@@ -26,7 +28,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.orig_func = None


def hijacked_activate(activate_func):
def hijacked_activate(activate_func, *, conv_dict=None):
import networks

if hasattr(activate_func, "_onediff_hijacked"):
@@ -36,7 +38,7 @@ def activate(self, p, params_list):
activate_func(self, p, params_list)
if isinstance(p.sd_model.model.diffusion_model, DeployableModule):
onediff_sd_model: DeployableModule = p.sd_model.model.diffusion_model
for sub_module in onediff_sd_model.modules():
for name, sub_module in onediff_sd_model.named_modules():
if not isinstance(
sub_module,
(
@@ -49,5 +51,14 @@ def activate(self, p, params_list):
continue
networks.network_apply_weights(sub_module)

# for LyCORIS cases
if conv_dict is not None and isinstance(sub_module, torch.nn.Conv2d):
target_tensor = conv_dict.get(name + ".weight", None)
if target_tensor is None:
continue
target_tensor.copy_(
flow.utils.tensor.from_torch(sub_module.weight.permute(0, 2, 3, 1))
)

activate._onediff_hijacked = True
return activate
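For context, the `conv_dict` branch added above copies a torch Conv2d weight (NCHW layout) into the channels-last variable that the compiled graph holds for it. A standalone sketch of that layout conversion with made-up shapes (the real tensors come from the compiled UNet's runtime variables):

```python
import torch
import oneflow as flow

# A torch Conv2d weight is laid out as (out_channels, in_channels, kH, kW).
torch_weight = torch.randn(8, 4, 3, 3)

# The compiled graph stores the same weight channels-last: (out, kH, kW, in).
nhwc_var = flow.zeros(8, 3, 3, 4)

# Permute to channels-last before copying into the graph-side variable,
# mirroring the copy_ call in the hijacked activate above.
nhwc_var.copy_(flow.utils.tensor.from_torch(torch_weight.permute(0, 2, 3, 1)))
print(nhwc_var.shape)  # (8, 3, 3, 4)
```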
73 changes: 52 additions & 21 deletions onediff_sd_webui_extensions/scripts/onediff.py
@@ -1,8 +1,11 @@
import os
import re
import warnings
import gradio as gr
from pathlib import Path
from typing import Union, Dict
from collections import defaultdict
import oneflow as flow
import modules.scripts as scripts
import modules.shared as shared
from modules.sd_models import select_checkpoint
@@ -108,6 +111,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):

class Script(scripts.Script):
current_type = None
convname_dict = None

def title(self):
return "onediff_diffusion_model"
@@ -148,8 +152,8 @@ def ui(self, is_img2img):
def show(self, is_img2img):
return True

def need_compile(self, model):
recompile = False
def check_model_structure_change(self, model):
is_changed = False

def get_model_type(model):
return {
@@ -160,21 +164,16 @@ def get_model_type(model):
}

if self.current_type == None:
recompile = True
is_changed = True
else:
for key, v in self.current_type.items():
if v != getattr(model, key):
recompile = True
is_changed = True
break

if recompile == True:
if is_changed == True:
self.current_type = get_model_type(model)
elif parse_boolean_from_env("ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION", "1"):
warnings.warn(
f"If you want to reuse the compiled graph, please set environ var `ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION` as '0', or the compiled graph will work incorrectly.",
RuntimeWarning,
)
return recompile
return is_changed

def run(self, p, quantization=False):
# For OneDiff Community, the input param `quantization` is a HTML string
@@ -189,11 +188,26 @@
current_checkpoint + "_quantized" if quantization else current_checkpoint
)

if (
quantization
and ckpt_name != compiled_ckpt_name
or self.need_compile(shared.sd_model)
):
model_changed = ckpt_name != compiled_ckpt_name
model_structure_changed = self.check_model_structure_change(shared.sd_model)
need_recompile = (quantization and model_changed) or model_structure_changed
if not need_recompile:
logger.info(
f"Model {current_checkpoint} has same sd type of graph type {self.current_type}, skip compile"
)
if model_changed:
# need to transpose conv weights
for k in self.convname_dict:
orig_tensor = original_diffusion_model.get_parameter(k)
target_tensor = self.convname_dict[k]
if target_tensor is None:
need_recompile = True
break
target_tensor.copy_(
flow.utils.tensor.from_torch(orig_tensor.permute(0, 2, 3, 1))
)

if need_recompile:
compile_options = {}

compiled_unet = compile_unet(
@@ -202,13 +216,30 @@
options=compile_options,
)
compiled_ckpt_name = ckpt_name
else:
logger.info(
f"Model {current_checkpoint} has same sd type of graph type {self.current_type}, skip compile"
)
self.convname_dict = None

with UnetCompileCtx(), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate():
with UnetCompileCtx(), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate(
self.convname_dict
):
proc = process_images(p)

# AutoNHWC will transpose conv weights, which generates new tensors in the graph.
# This part finds the correspondence between the tensors before/after the transpose.
def convert_var_name(s: str, prefix="variable_transpose_"):
s = re.sub(r"_[0-9]+$", "", s.removeprefix(prefix)).removeprefix("model.")
return s

if not quantization and self.convname_dict is None:
self.convname_dict = {}
run_state = (
compiled_unet._deployable_module_dpl_graph._c_nn_graph.get_runtime_var_states()
)
self.convname_dict = {
convert_var_name(k): v
for k, v in zip(run_state[0], run_state[1])
if k.startswith("variable_")
}
return proc


onediff_do_hijack()
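To illustrate the name mapping built at the end of `run`: judging from the prefix and regex used in `convert_var_name`, the transposed conv weights appear in the graph under names of the form `variable_transpose_<torch parameter name>_<index>`, and the helper recovers the original parameter name. The sample name below is hypothetical; real names come from `get_runtime_var_states()` of the compiled graph:

```python
import re

def convert_var_name(s: str, prefix="variable_transpose_"):
    # Strip the AutoNHWC prefix, the trailing "_<index>", and the leading "model.".
    s = re.sub(r"_[0-9]+$", "", s.removeprefix(prefix)).removeprefix("model.")
    return s

# Hypothetical graph variable name produced by the transpose pass.
graph_name = "variable_transpose_model.diffusion_model.out.2.weight_11"
print(convert_var_name(graph_name))  # -> diffusion_model.out.2.weight
```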