diff --git a/onediff_sd_webui_extensions/README.md b/onediff_sd_webui_extensions/README.md
index 45f6decfc..69456b78a 100644
--- a/onediff_sd_webui_extensions/README.md
+++ b/onediff_sd_webui_extensions/README.md
@@ -65,7 +65,7 @@ Select `onediff_diffusion_model` from the Script menu, enter a prompt in the tex
 
 When switching models, if the new model has the same structure as the old model, OneDiff will reuse the previously compiled graph, which means you don't need to compile the new model again, which significantly reduces the time it takes you to switch models.
 
-> Note: Please make sure that your PyTorch version is at least 2.1.0, and set the environment variable `ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION` to 0 when starting the sd-webui service. And the feature is not supported for quantized model.
+> Note: Please make sure that your PyTorch version is at least 2.1.0. And the feature is not supported for quantized model.
 
 ### LoRA
 
@@ -73,20 +73,15 @@ OneDiff supports the complete functionality related to LoRA. You can use OneDiff
 
 FAQ:
 
-
-1. Does OneDiff support model types other than LoRA, such as LyCORIS?
-
-   If your LoRA model only contains the weights of the Linear module, you can directly use OneDiff without any modifications. But if your LoRA model includes the weights of the Conv module (such as LyCORIS), you need to disable constant folding optimization by setting the env var `ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION` to 0 (which may cause a performance drop of around 4.4%), otherwise the weights of the Conv module may not be loaded into the model.
-
-2. After switching LoRA, should I recompile the model?
+1. After switching LoRA, should I recompile the model?
 
   OneDiff supports dynamically switching LoRA without recompiling the model, because the model with LoRA and the one without LoRA share the same parameter pointer, which have already been captured by the static graph.
 
-3. What's the time cost of LoRA fusing?
+2. What's the time cost of LoRA fusing?
 
  The initial few times of LoRA fusing may take a bit of time (1~2s), but when stabilized, the time cost is ~700ms.
 
-4. Will LoRA fusing affect the inference efficiency of the model?
+3. Will LoRA fusing affect the inference efficiency of the model?
 
  No, the model's inference efficiency remains the same after fusing LoRA as it was before fusing LoRA.
 
diff --git a/onediff_sd_webui_extensions/onediff_lora.py b/onediff_sd_webui_extensions/onediff_lora.py
index 4dcfbad02..f4289b096 100644
--- a/onediff_sd_webui_extensions/onediff_lora.py
+++ b/onediff_sd_webui_extensions/onediff_lora.py
@@ -1,10 +1,12 @@
 import torch
+import oneflow as flow
 from onediff.infer_compiler.with_oneflow_compile import DeployableModule
 
 
 class HijackLoraActivate:
-    def __init__(self):
+    def __init__(self, conv_dict=None):
         from modules import extra_networks
 
+        self.conv_dict = conv_dict
         if "lora" in extra_networks.extra_network_registry:
             cls_extra_network_lora = type(extra_networks.extra_network_registry["lora"])
@@ -16,7 +18,7 @@ def __enter__(self):
         if self.lora_class is None:
             return
         self.orig_func = self.lora_class.activate
-        self.lora_class.activate = hijacked_activate(self.lora_class.activate)
+        self.lora_class.activate = hijacked_activate(self.lora_class.activate, conv_dict=self.conv_dict)
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         if self.lora_class is None:
@@ -26,7 +28,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         self.orig_func = None
 
 
-def hijacked_activate(activate_func):
+def hijacked_activate(activate_func, *, conv_dict=None):
     import networks
 
     if hasattr(activate_func, "_onediff_hijacked"):
@@ -36,7 +38,7 @@ def activate(self, p, params_list):
         activate_func(self, p, params_list)
         if isinstance(p.sd_model.model.diffusion_model, DeployableModule):
             onediff_sd_model: DeployableModule = p.sd_model.model.diffusion_model
-            for sub_module in onediff_sd_model.modules():
+            for name, sub_module in onediff_sd_model.named_modules():
                 if not isinstance(
                     sub_module,
                     (
@@ -49,5 +51,14 @@ def activate(self, p, params_list):
                     continue
                 networks.network_apply_weights(sub_module)
 
+                # for LyCORIS cases
+                if conv_dict is not None and isinstance(sub_module, torch.nn.Conv2d):
+                    target_tensor = conv_dict.get(name + ".weight", None)
+                    if target_tensor is None:
+                        continue
+                    target_tensor.copy_(
+                        flow.utils.tensor.from_torch(sub_module.weight.permute(0, 2, 3, 1))
+                    )
+
     activate._onediff_hijacked = True
     return activate
diff --git a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py
index b34927f0e..ff6f073b8 100644
--- a/onediff_sd_webui_extensions/scripts/onediff.py
+++ b/onediff_sd_webui_extensions/scripts/onediff.py
@@ -1,8 +1,11 @@
 import os
+import re
 import warnings
 import gradio as gr
 from pathlib import Path
 from typing import Union, Dict
+from collections import defaultdict
+import oneflow as flow
 import modules.scripts as scripts
 import modules.shared as shared
 from modules.sd_models import select_checkpoint
@@ -108,6 +111,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
 
 class Script(scripts.Script):
     current_type = None
+    convname_dict = None
 
     def title(self):
         return "onediff_diffusion_model"
@@ -148,8 +152,8 @@ def ui(self, is_img2img):
     def show(self, is_img2img):
         return True
 
-    def need_compile(self, model):
-        recompile = False
+    def check_model_structure_change(self, model):
+        is_changed = False
 
         def get_model_type(model):
             return {
@@ -160,21 +164,16 @@ def get_model_type(model):
             }
 
         if self.current_type == None:
-            recompile = True
+            is_changed = True
         else:
             for key, v in self.current_type.items():
                 if v != getattr(model, key):
-                    recompile = True
+                    is_changed = True
                     break
 
-        if recompile == True:
+        if is_changed == True:
             self.current_type = get_model_type(model)
-        elif parse_boolean_from_env("ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION", "1"):
-            warnings.warn(
-                f"If you want to reuse the compiled graph, please set environ var `ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION` as '0', or the compiled graph will work incorrectly.",
-                RuntimeWarning,
-            )
-        return recompile
+        return is_changed
 
     def run(self, p, quantization=False):
         # For OneDiff Community, the input param `quantization` is a HTML string
@@ -189,11 +188,26 @@ def run(self, p, quantization=False):
             current_checkpoint + "_quantized" if quantization else current_checkpoint
         )
 
-        if (
-            quantization
-            and ckpt_name != compiled_ckpt_name
-            or self.need_compile(shared.sd_model)
-        ):
+        model_changed = ckpt_name != compiled_ckpt_name
+        model_structure_changed = self.check_model_structure_change(shared.sd_model)
+        need_recompile = (quantization and model_changed) or model_structure_changed
+        if not need_recompile:
+            logger.info(
+                f"Model {current_checkpoint} has same sd type of graph type {self.current_type}, skip compile"
+            )
+            if model_changed:
+                # need to transpose conv weights
+                for k in self.convname_dict:
+                    orig_tensor = original_diffusion_model.get_parameter(k)
+                    target_tensor = self.convname_dict[k]
+                    if target_tensor is None:
+                        need_recompile = True
+                        break
+                    target_tensor.copy_(
+                        flow.utils.tensor.from_torch(orig_tensor.permute(0, 2, 3, 1))
+                    )
+
+        if need_recompile:
             compile_options = {}
 
             compiled_unet = compile_unet(
@@ -202,13 +216,30 @@ def run(self, p, quantization=False):
                 options=compile_options,
             )
             compiled_ckpt_name = ckpt_name
-        else:
-            logger.info(
-                f"Model {current_checkpoint} has same sd type of graph type {self.current_type}, skip compile"
-            )
+            self.convname_dict = None
 
-        with UnetCompileCtx(), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate():
+        with UnetCompileCtx(), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate(
+            self.convname_dict
+        ):
             proc = process_images(p)
+
+        # AutoNHWC transposes conv weights, which creates new tensors in the graph.
+        # Build a mapping from each original conv weight name to its transposed graph tensor.
+        def convert_var_name(s: str, prefix="variable_transpose_"):
+            s = re.sub(r"_[0-9]+$", "", s.removeprefix(prefix)).removeprefix("model.")
+            return s
+
+        if not quantization and self.convname_dict is None:
+            self.convname_dict = {}
+            run_state = (
+                compiled_unet._deployable_module_dpl_graph._c_nn_graph.get_runtime_var_states()
+            )
+            self.convname_dict = {
+                convert_var_name(k): v
+                for k, v in zip(run_state[0], run_state[1])
+                if k.startswith("variable_")
+            }
         return proc
+
 
 onediff_do_hijack()
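Note (illustration only, not part of the patch): the sketch below shows the two conventions the new code relies on, under the assumption that OneFlow's AutoNHWC pass names transposed conv-weight variables like `variable_transpose_model.<parameter path>_<N>`. The graph variable name used here is hypothetical; real names come from `get_runtime_var_states()` as in the patch.

# Illustrative sketch only; the graph variable name below is hypothetical.
import re

import torch


def convert_var_name(s: str, prefix: str = "variable_transpose_") -> str:
    # Same string transformation as in scripts/onediff.py: strip the transpose
    # prefix, the trailing "_<number>" suffix, and the leading "model." scope.
    return re.sub(r"_[0-9]+$", "", s.removeprefix(prefix)).removeprefix("model.")


graph_var = "variable_transpose_model.diffusion_model.input_blocks.1.0.in_layers.2.weight_23"
print(convert_var_name(graph_var))
# diffusion_model.input_blocks.1.0.in_layers.2.weight  <- key used in convname_dict

# A torch Conv2d weight is laid out as (out_ch, in_ch, kH, kW); the patch copies it
# into the transposed graph variable channels-last, hence permute(0, 2, 3, 1).
w = torch.randn(320, 4, 3, 3)
print(w.permute(0, 2, 3, 1).shape)  # torch.Size([320, 3, 3, 4])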