diff --git a/onediff_sd_webui_extensions/README.md b/onediff_sd_webui_extensions/README.md index 8af4250d3..e4a0e3f3a 100644 --- a/onediff_sd_webui_extensions/README.md +++ b/onediff_sd_webui_extensions/README.md @@ -67,6 +67,33 @@ When switching models, if the new model has the same structure as the old model, > Note: The feature is not supported for quantized model. + +### Compiler cache saving and loading + +OneDiff supports saving compiler cache to disk and loading cache from disk. In scenarios where recompiling is required after switching model, you can skip the compilation process by loading the compiler cache from the disk, to saving time of model switching. + +The compiler cache will be saved at `/path/to/your/stable-diffusion-webui/extensions/onediff_sd_webui_extensions/compiler_caches/` by default. If you want to specify the path, you can modify it in webui settings. + +![Path to save compiler cache in Settings](./images/setting_dir_of_compiler_cache.png) + +#### Compiler cache saving + +After selecting onediff, a text box named `Saved cache name` will appear at the bottom right. You can input the file name of the compiler cache you want to save here. After generating the image, the compiler cache will be saved in the `stable-diffusion-webui/extensions/onediff_sd_webui_extensions/compiler_caches/your-compiler-cache-name` path. + +![Compiler caches](./images/saved_cache_name.png) + + +> Note: When the text box is empty or the file with the specified name already exists, the compiler cache will not be saved. + + +#### Compiler cache loading + +After selecting onediff, a dropdown menu named `Compile cache` will appear at the bottom left. Here, you can select the compiler cache you want to load. This dropdown menu will display all files located in the path `stable-diffusion-webui/extensions/onediff_sd_webui_extensions/compiler_caches/`. And click the button on the right side to refresh the `Compile cache` list. + +![Compiler cache loading](./images/compiler_caches.png) + +> Note: To properly use this feature, please ensure that you have added the `--disable-safe-unpickle` parameter when launching sd-webui. + ### LoRA OneDiff supports the complete functionality related to LoRA. You can use OneDiff-based LoRA just like the native LoRA in sd-webui. diff --git a/onediff_sd_webui_extensions/images/compiler_caches.png b/onediff_sd_webui_extensions/images/compiler_caches.png new file mode 100644 index 000000000..955c6a677 Binary files /dev/null and b/onediff_sd_webui_extensions/images/compiler_caches.png differ diff --git a/onediff_sd_webui_extensions/images/saved_cache_name.png b/onediff_sd_webui_extensions/images/saved_cache_name.png new file mode 100644 index 000000000..6b7ea315e Binary files /dev/null and b/onediff_sd_webui_extensions/images/saved_cache_name.png differ diff --git a/onediff_sd_webui_extensions/images/setting_dir_of_compiler_cache.png b/onediff_sd_webui_extensions/images/setting_dir_of_compiler_cache.png new file mode 100644 index 000000000..8fd199cb8 Binary files /dev/null and b/onediff_sd_webui_extensions/images/setting_dir_of_compiler_cache.png differ diff --git a/onediff_sd_webui_extensions/scripts/onediff.py b/onediff_sd_webui_extensions/scripts/onediff.py index e090c23a6..3c7e887cd 100644 --- a/onediff_sd_webui_extensions/scripts/onediff.py +++ b/onediff_sd_webui_extensions/scripts/onediff.py @@ -1,4 +1,5 @@ import os +import zipfile import warnings import gradio as gr from pathlib import Path @@ -7,7 +8,10 @@ import modules.shared as shared from modules.sd_models import select_checkpoint from modules.processing import process_images +from modules.ui_common import create_refresh_button +from modules import script_callbacks +from ui_utils import hints_message, get_all_compiler_caches, refresh_all_compiler_caches, all_compiler_caches_path from compile_ldm import compile_ldm_unet, SD21CompileCtx from compile_sgm import compile_sgm_unet from compile_vae import VaeCompileCtx @@ -15,11 +19,11 @@ from onediff_hijack import do_hijack as onediff_do_hijack from onediff.infer_compiler.utils.log_utils import logger -from onediff.infer_compiler.utils.param_utils import get_constant_folding_info from onediff.optimization.quant_optimizer import ( quantize_model, varify_can_use_quantization, ) +from onediff.infer_compiler.utils.env_var import parse_boolean_from_env from onediff import __version__ as onediff_version from oneflow import __version__ as oneflow_version @@ -113,33 +117,17 @@ def ui(self, is_img2img): The return value should be an array of all components that are used in processing. Values of those returned components will be passed to run() and process() functions. """ + with gr.Row(): + # TODO: set choices as Tuple[str, str] after the version of gradio specified webui upgrades + compiler_cache = gr.Dropdown(label="Compiler caches (Beta)", choices=["None"] + get_all_compiler_caches(), value="None", elem_id="onediff_compiler_cache") + refresh_button = create_refresh_button(compiler_cache, refresh_all_compiler_caches, lambda: {"choices": ["None"] + get_all_compiler_caches()}, "onediff_refresh_compiler_caches") + save_cache_name = gr.Textbox(label="Saved cache name (Beta)") + with gr.Row(): + always_recompile = gr.components.Checkbox(label="always_recompile", visible=parse_boolean_from_env("ONEDIFF_DEBUG")) if not varify_can_use_quantization(): - ret = gr.HTML( - """ -
-
- Hints Message -
-
- Hints: Enterprise function is not supported on your system. -
-

- If you need Enterprise Level Support for your system or business, please send an email to - business@siliconflow.com. -
- Tell us about your use case, deployment scale, and requirements. -

-

- GitHub Issue: - https://github.com/siliconflow/onediff/issues -

-
- """ - ) - - else: - ret = gr.components.Checkbox(label="Model Quantization(int8) Speed Up") - return [ret] + gr.HTML(hints_message) + is_quantized = gr.components.Checkbox(label="Model Quantization(int8) Speed Up", visible=varify_can_use_quantization()) + return [is_quantized, compiler_cache, save_cache_name, always_recompile] def show(self, is_img2img): return True @@ -167,10 +155,7 @@ def get_model_type(model): self.current_type = get_model_type(model) return is_changed - def run(self, p, quantization=False): - # For OneDiff Community, the input param `quantization` is a HTML string - if isinstance(quantization, str): - quantization = False + def run(self, p, quantization=False, compiler_cache=None, saved_cache_name="", always_recompile=False): global compiled_unet, compiled_ckpt_name, is_unet_quantized current_checkpoint = shared.opts.sd_model_checkpoint @@ -183,6 +168,7 @@ def run(self, p, quantization=False): (quantization and ckpt_changed) # always recompile when switching ckpt with 'int8 speed model' enabled or model_changed # always recompile when switching model to another structure or quantization_changed # always recompile when switching model from non-quantized to quantized (and vice versa) + or always_recompile ) is_unet_quantized = quantization @@ -191,6 +177,18 @@ def run(self, p, quantization=False): compiled_unet = compile_unet( original_diffusion_model, quantization=quantization ) + + if compiler_cache != "None": + compiler_cache_path = all_compiler_caches_path() + f"/{compiler_cache}" + if not Path(compiler_cache_path).exists(): + raise FileNotFoundError(f"Cannot find cache {compiler_cache_path}, please make sure it exists") + try: + compiled_unet.load_graph(compiler_cache_path, run_warmup=True) + except zipfile.BadZipFile as e: + raise RuntimeError("Load cache failed. Please make sure that the --disable-safe-unpickle parameter is added when starting the webui") + except Exception as e: + raise RuntimeError("Load cache failed. Please make sure cache has the same sd version (or unet architure) with current checkpoint") + else: logger.info( f"Model {current_checkpoint} has same sd type of graph type {self.current_type}, skip compile" @@ -198,7 +196,23 @@ def run(self, p, quantization=False): with UnetCompileCtx(), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate(): proc = process_images(p) + + if saved_cache_name != "": + if not os.access(str(all_compiler_caches_path()), os.W_OK): + raise PermissionError(f"The directory {all_compiler_caches_path()} does not have write permissions, and compiler cache cannot be written to this directory. \ + Please change it in the settings to a directory with write permissions") + if not Path(all_compiler_caches_path()).exists(): + Path(all_compiler_caches_path()).mkdir() + saved_cache_name = all_compiler_caches_path() + f"/{saved_cache_name}" + if not Path(saved_cache_name).exists(): + compiled_unet.save_graph(saved_cache_name) + return proc +def on_ui_settings(): + section = ('onediff', "OneDiff") + shared.opts.add_option("onediff_compiler_caches_path", shared.OptionInfo( + str(Path(__file__).parent.parent / "compiler_caches"), "Directory for onediff compiler caches", section=section)) +script_callbacks.on_ui_settings(on_ui_settings) onediff_do_hijack() diff --git a/onediff_sd_webui_extensions/ui_utils.py b/onediff_sd_webui_extensions/ui_utils.py new file mode 100644 index 000000000..7feea4eaa --- /dev/null +++ b/onediff_sd_webui_extensions/ui_utils.py @@ -0,0 +1,42 @@ +from pathlib import Path + +hints_message = """ +
+
+ Hints Message +
+
+ Hints: Enterprise function is not supported on your system. +
+

+ If you need Enterprise Level Support for your system or business, please send an email to + business@siliconflow.com. +
+ Tell us about your use case, deployment scale, and requirements. +

+

+ GitHub Issue: + https://github.com/siliconflow/onediff/issues +

+
+ """ + +all_compiler_caches = [] + +def all_compiler_caches_path(): + import modules.shared as shared + caches_path = Path(shared.opts.onediff_compiler_caches_path) + if not caches_path.exists(): + caches_path.mkdir(parents=True) + return shared.opts.onediff_compiler_caches_path + +def get_all_compiler_caches(): + global all_compiler_caches + if len(all_compiler_caches) == 0: + refresh_all_compiler_caches() + return all_compiler_caches + +def refresh_all_compiler_caches(path: Path = None): + global all_compiler_caches + path = path or all_compiler_caches_path() + all_compiler_caches = [f.stem for f in Path(path).iterdir() if f.is_file()] \ No newline at end of file diff --git a/src/onediff/infer_compiler/oneflow/deployable_module.py b/src/onediff/infer_compiler/oneflow/deployable_module.py index 8f752ad29..98c005578 100644 --- a/src/onediff/infer_compiler/oneflow/deployable_module.py +++ b/src/onediff/infer_compiler/oneflow/deployable_module.py @@ -6,7 +6,7 @@ from ..utils.oneflow_exec_mode import oneflow_exec_mode, oneflow_exec_mode_enabled from ..utils.args_tree_util import input_output_processor from ..utils.log_utils import logger -from ..utils.param_utils import parse_device, check_device +from ..utils.param_utils import parse_device, check_device, generate_constant_folding_info from ..utils.graph_management_utils import graph_file_management from ..utils.online_quantization_utils import quantize_and_deploy_wrapper from ..utils.options import OneflowCompileOptions @@ -149,6 +149,7 @@ def load_graph(self, file_path, device=None, run_warmup=True, *, state_dict=None self.get_graph().load_graph( file_path, device, run_warmup, state_dict=state_dict ) + generate_constant_folding_info(self) def save_graph(self, file_path, *, process_state_dict=lambda x: x): self.get_graph().save_graph(file_path, process_state_dict=process_state_dict)