diff --git a/onediff_sd_webui_extensions/README.md b/onediff_sd_webui_extensions/README.md
index e4a0e3f3a..0e7b14d14 100644
--- a/onediff_sd_webui_extensions/README.md
+++ b/onediff_sd_webui_extensions/README.md
@@ -4,8 +4,10 @@
 - [Installation Guide](#installation-guide)
 - [Extensions Usage](#extensions-usage)
   - [Fast Model Switching](#fast-model-switching)
+  - [Compiler cache saving and loading](#compiler-cache-saving-and-loading)
   - [LoRA](#lora)
   - [Quantization](#quantization)
+- [Use OneDiff by API](#use-onediff-by-api)
 - [Contact](#contact)
 
 ## Performance of Community Edition
diff --git a/onediff_sd_webui_extensions/compile/__init__.py b/onediff_sd_webui_extensions/compile/__init__.py
new file mode 100644
index 000000000..90afcaceb
--- /dev/null
+++ b/onediff_sd_webui_extensions/compile/__init__.py
@@ -0,0 +1,11 @@
+from .compile_ldm import SD21CompileCtx
+from .compile_utils import get_compiled_graph
+from .compile_vae import VaeCompileCtx
+from .onediff_compiled_graph import OneDiffCompiledGraph
+
+__all__ = [
+    "get_compiled_graph",
+    "SD21CompileCtx",
+    "VaeCompileCtx",
+    "OneDiffCompiledGraph",
+]
diff --git a/onediff_sd_webui_extensions/compile_ldm.py b/onediff_sd_webui_extensions/compile/compile_ldm.py
similarity index 98%
rename from onediff_sd_webui_extensions/compile_ldm.py
rename to onediff_sd_webui_extensions/compile/compile_ldm.py
index e87f7f696..7b04e16aa 100644
--- a/onediff_sd_webui_extensions/compile_ldm.py
+++ b/onediff_sd_webui_extensions/compile/compile_ldm.py
@@ -9,15 +9,16 @@
 from ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel
 from ldm.modules.diffusionmodules.util import GroupNorm32
 from modules import shared
-from sd_webui_onediff_utils import (
+
+from onediff.infer_compiler import oneflow_compile
+from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register
+
+from .sd_webui_onediff_utils import (
     CrossAttentionOflow,
     GroupNorm32Oflow,
     timestep_embedding,
 )
-from onediff.infer_compiler import oneflow_compile
-from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register
-
 
 __all__ = ["compile_ldm_unet"]
diff --git a/onediff_sd_webui_extensions/compile_sgm.py b/onediff_sd_webui_extensions/compile/compile_sgm.py
similarity index 98%
rename from onediff_sd_webui_extensions/compile_sgm.py
rename to onediff_sd_webui_extensions/compile/compile_sgm.py
index 154b3dc5c..09b86be59 100644
--- a/onediff_sd_webui_extensions/compile_sgm.py
+++ b/onediff_sd_webui_extensions/compile/compile_sgm.py
@@ -1,9 +1,4 @@
 import oneflow as flow
-from sd_webui_onediff_utils import (
-    CrossAttentionOflow,
-    GroupNorm32Oflow,
-    timestep_embedding,
-)
 from sgm.modules.attention import (
     BasicTransformerBlock,
     CrossAttention,
@@ -15,6 +10,12 @@
 from onediff.infer_compiler import oneflow_compile
 from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register
 
+from .sd_webui_onediff_utils import (
+    CrossAttentionOflow,
+    GroupNorm32Oflow,
+    timestep_embedding,
+)
+
 __all__ = ["compile_sgm_unet"]
diff --git a/onediff_sd_webui_extensions/compile/compile_utils.py b/onediff_sd_webui_extensions/compile/compile_utils.py
new file mode 100644
index 000000000..9d39fbc96
--- /dev/null
+++ b/onediff_sd_webui_extensions/compile/compile_utils.py
@@ -0,0 +1,67 @@
+import warnings
+from pathlib import Path
+from typing import Dict, Union
+
+from ldm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelLDM
+from modules.sd_models import select_checkpoint
+from sgm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelSGM
+
+from onediff.optimization.quant_optimizer import (
+    quantize_model,
+    varify_can_use_quantization,
+)
+from onediff.utils import logger
+
+from .compile_ldm import compile_ldm_unet
+from .compile_sgm import compile_sgm_unet
+from .onediff_compiled_graph import OneDiffCompiledGraph
+
+
+def compile_unet(
+    unet_model, quantization=False, *, options=None,
+):
+    if isinstance(unet_model, UNetModelLDM):
+        compiled_unet = compile_ldm_unet(unet_model, options=options)
+    elif isinstance(unet_model, UNetModelSGM):
+        compiled_unet = compile_sgm_unet(unet_model, options=options)
+    else:
+        warnings.warn(
+            f"Unsupported model type: {type(unet_model)} for compilation , skip",
+            RuntimeWarning,
+        )
+        compiled_unet = unet_model
+    # In OneDiff Community, quantization can be True when called by api
+    if quantization and varify_can_use_quantization():
+        calibrate_info = get_calibrate_info(
+            f"{Path(select_checkpoint().filename).stem}_sd_calibrate_info.txt"
+        )
+        compiled_unet = quantize_model(
+            compiled_unet, inplace=False, calibrate_info=calibrate_info
+        )
+    return compiled_unet
+
+
+def get_calibrate_info(filename: str) -> Union[None, Dict]:
+    calibration_path = Path(select_checkpoint().filename).parent / filename
+    if not calibration_path.exists():
+        return None
+
+    logger.info(f"Got calibrate info at {str(calibration_path)}")
+    calibrate_info = {}
+    with open(calibration_path, "r") as f:
+        for line in f.readlines():
+            line = line.strip()
+            items = line.split(" ")
+            calibrate_info[items[0]] = [
+                float(items[1]),
+                int(items[2]),
+                [float(x) for x in items[3].split(",")],
+            ]
+    return calibrate_info
+
+
+def get_compiled_graph(sd_model, quantization) -> OneDiffCompiledGraph:
+    compiled_unet = compile_unet(
+        sd_model.model.diffusion_model, quantization=quantization
+    )
+    return OneDiffCompiledGraph(sd_model, compiled_unet, quantization)
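# Usage sketch (illustrative, not applied by this patch): how the new
# compile.get_compiled_graph entry point is meant to be driven. Assumes a
# checkpoint is already loaded into webui's `shared.sd_model`; only the names
# imported from `compile` come from this PR, the rest mirrors stock webui.
from modules import shared

from compile import get_compiled_graph

graph = get_compiled_graph(shared.sd_model, False)  # quantization off
# The returned record keeps both modules, so the swap is reversible:
shared.sd_model.model.diffusion_model = graph.graph_module  # compiled UNet
shared.sd_model.model.diffusion_model = graph.eager_module  # restore original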
diff --git a/onediff_sd_webui_extensions/compile_vae.py b/onediff_sd_webui_extensions/compile/compile_vae.py
similarity index 100%
rename from onediff_sd_webui_extensions/compile_vae.py
rename to onediff_sd_webui_extensions/compile/compile_vae.py
diff --git a/onediff_sd_webui_extensions/compile/onediff_compiled_graph.py b/onediff_sd_webui_extensions/compile/onediff_compiled_graph.py
new file mode 100644
index 000000000..d6a09aca3
--- /dev/null
+++ b/onediff_sd_webui_extensions/compile/onediff_compiled_graph.py
@@ -0,0 +1,31 @@
+import dataclasses
+
+import torch
+from modules import sd_models_types
+
+from onediff.infer_compiler import DeployableModule
+
+
+@dataclasses.dataclass
+class OneDiffCompiledGraph:
+    name: str = None
+    filename: str = None
+    sha: str = None
+    eager_module: torch.nn.Module = None
+    graph_module: DeployableModule = None
+    quantized: bool = False
+
+    def __init__(
+        self,
+        sd_model: sd_models_types.WebuiSdModel = None,
+        graph_module: DeployableModule = None,
+        quantized=False,
+    ):
+        if sd_model is None:
+            return
+        self.name = sd_model.sd_checkpoint_info.name
+        self.filename = sd_model.sd_checkpoint_info.filename
+        self.sha = sd_model.sd_model_hash
+        self.eager_module = sd_model.model.diffusion_model
+        self.graph_module = graph_module
+        self.quantized = quantized
diff --git a/onediff_sd_webui_extensions/sd_webui_onediff_utils.py b/onediff_sd_webui_extensions/compile/sd_webui_onediff_utils.py
similarity index 100%
rename from onediff_sd_webui_extensions/sd_webui_onediff_utils.py
rename to onediff_sd_webui_extensions/compile/sd_webui_onediff_utils.py
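# Sketch of the reuse test OneDiffCompiledGraph enables (hypothetical helper;
# the real comparison happens inside onediff_hijack_load_model_weights below):
import onediff_shared


def graph_is_reusable(sd_model) -> bool:
    # The graph records the checkpoint hash at compile time; a matching hash
    # on the incoming model means the compiled UNet can be swapped in as-is.
    return onediff_shared.current_unet_graph.sha == sd_model.sd_model_hash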
diff --git a/onediff_sd_webui_extensions/onediff_hijack.py b/onediff_sd_webui_extensions/onediff_hijack.py
index c8da677c6..355180202 100644
--- a/onediff_sd_webui_extensions/onediff_hijack.py
+++ b/onediff_sd_webui_extensions/onediff_hijack.py
@@ -1,6 +1,11 @@
-import compile_ldm
-import compile_sgm
+from typing import Any, Mapping
+
 import oneflow
+import torch
+from compile import compile_ldm, compile_sgm
+from modules import sd_models
+from modules.sd_hijack_utils import CondFunc
+from onediff_shared import onediff_enabled
 
 # https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/1c0a0c4c26f78c32095ebc7f8af82f5c04fca8c0/modules/sd_hijack_unet.py#L8
@@ -95,3 +100,130 @@ def undo_hijack():
         name="send_model_to_cpu",
         new_name="__onediff_original_send_model_to_cpu",
     )
+
+
+def onediff_hijack_load_model_weights(
+    orig_func, model, checkpoint_info: sd_models.CheckpointInfo, state_dict: dict, timer
+):
+    # load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer)
+    sd_model_hash = checkpoint_info.calculate_shorthash()
+    import onediff_shared
+
+    if onediff_shared.current_unet_graph.sha == sd_model_hash:
+        model.model.diffusion_model = onediff_shared.current_unet_graph.graph_module
+        state_dict = {
+            k: v
+            for k, v in state_dict.items()
+            if not k.startswith("model.diffusion_model.")
+        }
+
+        # for stable-diffusion-webui/modules/sd_models.py:load_model_weights model.is_ssd check
+        state_dict[
+            "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight"
+        ] = model.get_parameter(
+            "model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight"
+        )
+    return orig_func(model, checkpoint_info, state_dict, timer)
+
+
+def onediff_hijack_load_state_dict(
+    orig_func,
+    self,
+    state_dict: Mapping[str, Any],
+    strict: bool = True,
+    assign: bool = False,
+):
+    if (
+        len(state_dict) > 0
+        and next(iter(state_dict.values())).is_cuda
+        and next(self.parameters()).is_meta
+    ):
+        return orig_func(self, state_dict, strict, assign=True)
+    else:
+        return orig_func(self, state_dict, strict, assign)
+
+
+# fmt: off
+def onediff_hijaced_LoadStateDictOnMeta___enter__(orig_func, self):
+    from modules import shared
+    if shared.cmd_opts.disable_model_loading_ram_optimization:
+        return
+
+    sd = self.state_dict
+    device = self.device
+
+    def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
+        used_param_keys = []
+
+        for name, param in module._parameters.items():
+            if param is None:
+                continue
+
+            key = prefix + name
+            sd_param = sd.pop(key, None)
+            if sd_param is not None:
+                state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
+                used_param_keys.append(key)
+
+            if param.is_meta:
+                dtype = sd_param.dtype if sd_param is not None else param.dtype
+                module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)
+
+        for name in module._buffers:
+            key = prefix + name
+
+            sd_param = sd.pop(key, None)
+            if sd_param is not None:
+                state_dict[key] = sd_param
+                used_param_keys.append(key)
+
+        original(module, state_dict, prefix, *args, **kwargs)
+
+        for key in used_param_keys:
+            state_dict.pop(key, None)
+
+    # def load_state_dict(original, module, state_dict, strict=True):
+    def load_state_dict(original, module, state_dict, strict=True):
+        """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
+        because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
+        all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.
+
+        In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).
+
+        The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads
+        the function and does not call the original) the state dict will just fail to load because weights
+        would be on the meta device.
+        """
+
+        if state_dict is sd:
+            state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}
+
+        # ------------------- DIFF HERE -------------------
+        # original(module, state_dict, strict=strict)
+        if len(state_dict) > 0 and next(iter(state_dict.values())).is_cuda and next(module.parameters()).is_meta:
+            assign = True
+        else:
+            assign = False
+        # orig_func(original, module, state_dict, strict=strict, assign=assign)
+        original(module, state_dict, strict=strict, assign=assign)
+
+    module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
+    module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
+    linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
+    conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
+    mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
+    layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
+    group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))
+# fmt: on
+
+
+CondFunc(
+    "modules.sd_disable_initialization.LoadStateDictOnMeta.__enter__",
+    onediff_hijaced_LoadStateDictOnMeta___enter__,
+    lambda _, *args, **kwargs: onediff_enabled,
+)
+CondFunc(
+    "modules.sd_models.load_model_weights",
+    onediff_hijack_load_model_weights,
+    lambda _, *args, **kwargs: onediff_enabled,
+)
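# Minimal CondFunc sketch (illustrative; the real helper is webui's
# modules.sd_hijack_utils.CondFunc): the wrapper routes a call through the
# substitute, with the original function prepended, only when the condition
# holds -- which is how the two registrations above stay inert unless the
# OneDiff script has enabled itself.
def cond_func(orig, sub, cond):
    def wrapper(*args, **kwargs):
        if cond(orig, *args, **kwargs):
            return sub(orig, *args, **kwargs)
        return orig(*args, **kwargs)

    return wrapper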
diff --git a/onediff_sd_webui_extensions/onediff_lora.py b/onediff_sd_webui_extensions/onediff_lora.py
index 0bee88e9d..a1f4da8da 100644
--- a/onediff_sd_webui_extensions/onediff_lora.py
+++ b/onediff_sd_webui_extensions/onediff_lora.py
@@ -53,7 +53,11 @@ def activate(self, p, params_list):
                 continue
             networks.network_apply_weights(sub_module)
             if isinstance(sub_module, torch.nn.Conv2d):
-                update_graph_related_tensor(sub_module)
+                # TODO(WangYi): refine here
+                try:
+                    update_graph_related_tensor(sub_module)
+                except:
+                    pass
 
     activate._onediff_hijacked = True
     return activate
diff --git a/onediff_sd_webui_extensions/onediff_shared.py b/onediff_sd_webui_extensions/onediff_shared.py
new file mode 100644
index 000000000..8d9e4cf15
--- /dev/null
+++ b/onediff_sd_webui_extensions/onediff_shared.py
@@ -0,0 +1,11 @@
+from compile.onediff_compiled_graph import OneDiffCompiledGraph
+
+current_unet_graph = OneDiffCompiledGraph()
+current_quantization = False
+current_unet_type = {
+    "is_sdxl": False,
+    "is_sd2": False,
+    "is_sd1": False,
+    "is_ssd": False,
+}
+onediff_enabled = False
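# Sketch of the structure check that current_unet_type above feeds
# (hypothetical; the real helper, ui_utils.check_structure_change_and_update,
# is imported by scripts/onediff.py below): any flag mismatch between the
# cached record and the freshly loaded model forces a UNet recompile.
def structure_changed(current_type: dict, model) -> bool:
    return any(getattr(model, key) != value for key, value in current_type.items())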
diff --git a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py
index 5e5766c04..0561469d8 100644
--- a/onediff_sd_webui_extensions/scripts/onediff.py
+++ b/onediff_sd_webui_extensions/scripts/onediff.py
@@ -1,98 +1,32 @@
-import os
-import warnings
-import zipfile
 from pathlib import Path
-from typing import Dict, Union
 
 import gradio as gr
 import modules.scripts as scripts
+import modules.sd_models as sd_models
 import modules.shared as shared
-from compile_ldm import SD21CompileCtx, compile_ldm_unet
-from compile_sgm import compile_sgm_unet
-from compile_vae import VaeCompileCtx
+import onediff_shared
+import oneflow as flow
+from compile import SD21CompileCtx, VaeCompileCtx, get_compiled_graph
 from modules import script_callbacks
+from modules.devices import torch_gc
 from modules.processing import process_images
-from modules.sd_models import select_checkpoint
 from modules.ui_common import create_refresh_button
 from onediff_hijack import do_hijack as onediff_do_hijack
 from onediff_lora import HijackLoraActivate
-from oneflow import __version__ as oneflow_version
 from ui_utils import (
-    all_compiler_caches_path,
+    check_structure_change_and_update,
     get_all_compiler_caches,
     hints_message,
+    load_graph,
+    onediff_enabled,
     refresh_all_compiler_caches,
+    save_graph,
 )
 
-from onediff import __version__ as onediff_version
-from onediff.optimization.quant_optimizer import (
-    quantize_model,
-    varify_can_use_quantization,
-)
+from onediff.optimization.quant_optimizer import varify_can_use_quantization
 from onediff.utils import logger, parse_boolean_from_env
 
 """oneflow_compiled UNetModel"""
-compiled_unet = None
-is_unet_quantized = False
-compiled_ckpt_name = None
-
-
-def generate_graph_path(ckpt_name: str, model_name: str) -> str:
-    base_output_dir = shared.opts.outdir_samples or shared.opts.outdir_txt2img_samples
-    save_ckpt_graphs_path = os.path.join(base_output_dir, "graphs", ckpt_name)
-    os.makedirs(save_ckpt_graphs_path, exist_ok=True)
-
-    file_name = f"{model_name}_graph_{onediff_version}_oneflow_{oneflow_version}"
-
-    graph_file_path = os.path.join(save_ckpt_graphs_path, file_name)
-
-    return graph_file_path
-
-
-def get_calibrate_info(filename: str) -> Union[None, Dict]:
-    calibration_path = Path(select_checkpoint().filename).parent / filename
-    if not calibration_path.exists():
-        return None
-
-    logger.info(f"Got calibrate info at {str(calibration_path)}")
-    calibrate_info = {}
-    with open(calibration_path, "r") as f:
-        for line in f.readlines():
-            line = line.strip()
-            items = line.split(" ")
-            calibrate_info[items[0]] = [
-                float(items[1]),
-                int(items[2]),
-                [float(x) for x in items[3].split(",")],
-            ]
-    return calibrate_info
-
-
-def compile_unet(
-    unet_model, quantization=False, *, options=None,
-):
-    from ldm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelLDM
-    from sgm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelSGM
-
-    if isinstance(unet_model, UNetModelLDM):
-        compiled_unet = compile_ldm_unet(unet_model, options=options)
-    elif isinstance(unet_model, UNetModelSGM):
-        compiled_unet = compile_sgm_unet(unet_model, options=options)
-    else:
-        warnings.warn(
-            f"Unsupported model type: {type(unet_model)} for compilation , skip",
-            RuntimeWarning,
-        )
-        compiled_unet = unet_model
-    # In OneDiff Community, quantization can be True when called by api
-    if quantization and varify_can_use_quantization():
-        calibrate_info = get_calibrate_info(
-            f"{Path(select_checkpoint().filename).stem}_sd_calibrate_info.txt"
-        )
-        compiled_unet = quantize_model(
-            compiled_unet, inplace=False, calibrate_info=calibrate_info
-        )
-    return compiled_unet
 
 
 class UnetCompileCtx(object):
@@ -103,8 +37,9 @@ class UnetCompileCtx(object):
 
     def __enter__(self):
         self._original_model = shared.sd_model.model.diffusion_model
-        global compiled_unet
-        shared.sd_model.model.diffusion_model = compiled_unet
+        shared.sd_model.model.diffusion_model = (
+            onediff_shared.current_unet_graph.graph_module
+        )
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         shared.sd_model.model.diffusion_model = self._original_model
f"{Path(select_checkpoint().filename).stem}_sd_calibrate_info.txt" - ) - compiled_unet = quantize_model( - compiled_unet, inplace=False, calibrate_info=calibrate_info - ) - return compiled_unet class UnetCompileCtx(object): @@ -103,8 +37,9 @@ class UnetCompileCtx(object): def __enter__(self): self._original_model = shared.sd_model.model.diffusion_model - global compiled_unet - shared.sd_model.model.diffusion_model = compiled_unet + shared.sd_model.model.diffusion_model = ( + onediff_shared.current_unet_graph.graph_module + ) def __exit__(self, exc_type, exc_val, exc_tb): shared.sd_model.model.diffusion_model = self._original_model @@ -112,16 +47,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): class Script(scripts.Script): - current_type = None - def title(self): return "onediff_diffusion_model" def ui(self, is_img2img): - """this function should create gradio UI elements. See https://gradio.app/docs/#components - The return value should be an array of all components that are used in processing. - Values of those returned components will be passed to run() and process() functions. - """ with gr.Row(): # TODO: set choices as Tuple[str, str] after the version of gradio specified webui upgrades compiler_cache = gr.Dropdown( @@ -142,7 +71,11 @@ def ui(self, is_img2img): label="always_recompile", visible=parse_boolean_from_env("ONEDIFF_DEBUG"), ) - gr.HTML(hints_message, elem_id="hintMessage", visible=not varify_can_use_quantization()) + gr.HTML( + hints_message, + elem_id="hintMessage", + visible=not varify_can_use_quantization(), + ) is_quantized = gr.components.Checkbox( label="Model Quantization(int8) Speed Up", visible=varify_can_use_quantization(), @@ -152,29 +85,6 @@ def ui(self, is_img2img): def show(self, is_img2img): return True - def check_model_change(self, model): - is_changed = False - - def get_model_type(model): - return { - "is_sdxl": model.is_sdxl, - "is_sd2": model.is_sd2, - "is_sd1": model.is_sd1, - "is_ssd": model.is_ssd, - } - - if self.current_type is None: - is_changed = True - else: - for key, v in self.current_type.items(): - if v != getattr(model, key): - is_changed = True - break - - if is_changed is True: - self.current_type = get_model_type(model) - return is_changed - def run( self, p, @@ -183,68 +93,52 @@ def run( saved_cache_name="", always_recompile=False, ): - - global compiled_unet, compiled_ckpt_name, is_unet_quantized - current_checkpoint = shared.opts.sd_model_checkpoint - original_diffusion_model = shared.sd_model.model.diffusion_model - - ckpt_changed = current_checkpoint != compiled_ckpt_name - model_changed = self.check_model_change(shared.sd_model) - quantization_changed = quantization != is_unet_quantized + # restore checkpoint_info from refiner to base model if necessary + if ( + sd_models.checkpoint_aliases.get( + p.override_settings.get("sd_model_checkpoint") + ) + is None + ): + p.override_settings.pop("sd_model_checkpoint", None) + sd_models.reload_model_weights() + torch_gc() + flow.cuda.empty_cache() + + current_checkpoint_name = shared.sd_model.sd_checkpoint_info.name + ckpt_changed = ( + shared.sd_model.sd_checkpoint_info.name + != onediff_shared.current_unet_graph.name + ) + structure_changed = check_structure_change_and_update( + onediff_shared.current_unet_type, shared.sd_model + ) + quantization_changed = ( + quantization != onediff_shared.current_unet_graph.quantized + ) need_recompile = ( ( quantization and ckpt_changed ) # always recompile when switching ckpt with 'int8 speed model' enabled - or model_changed # always recompile 
@@ -260,5 +154,10 @@ def on_ui_settings():
     )
 
 
+def cfg_denoisers_callback(params):
+    pass
+
+
 script_callbacks.on_ui_settings(on_ui_settings)
+# script_callbacks.on_cfg_denoiser(cfg_denoisers_callback)
 onediff_do_hijack()
diff --git a/onediff_sd_webui_extensions/ui_utils.py b/onediff_sd_webui_extensions/ui_utils.py
index 7e442be4a..bdb875a38 100644
--- a/onediff_sd_webui_extensions/ui_utils.py
+++ b/onediff_sd_webui_extensions/ui_utils.py
@@ -1,7 +1,15 @@
+import os
+from contextlib import contextmanager
 from pathlib import Path
 from textwrap import dedent
+from zipfile import BadZipFile
 
-hints_message = dedent("""\
+import onediff_shared
+
+from onediff.infer_compiler import DeployableModule
+
+hints_message = dedent(
+    """\