Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

webui support save/load graph, refine UI #825

Merged
merged 17 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions onediff_sd_webui_extensions/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,29 @@ When switching models, if the new model has the same structure as the old model,

> Note: The feature is not supported for quantized model.


### Graph saving and loading

OneDiff supports saving compiled graph to disk and loading graph from disk. In scenarios where recompiling is required after switching model, you can skip the compilation process by loading the compiled graph from the disk, to saving time of model switching.


#### Graph saving

After selecting onediff, a text box named `Saved graph name` will appear at the bottom right. You can input the file name of the compiled graph you want to save here. After generating the image, the compiled graph will be saved in the `stable-diffusion-webui/extensions/onediff_sd_webui_extensions/models/your-compiled-graph-name` path.

![Graph saving](./images/saved_graph_name.jpg)

> Note: When the text box is empty or the file with the specified name already exists, the static graph will not be saved.


#### Graph loading

After selecting onediff, a dropdown menu named `Graph Checkpoints` will appear at the bottom left. Here, you can select the static graph you want to load. This dropdown menu will display all files located in the path `stable-diffusion-webui/extensions/onediff_sd_webui_extensions/models/`.
strint marked this conversation as resolved.
Show resolved Hide resolved

![Graph loading](./images/graph_checkpoints.jpg)
strint marked this conversation as resolved.
Show resolved Hide resolved

> Note: To properly use this feature, please ensure that you have added the `--disable-safe-unpickle` parameter when launching sd-webui.

### LoRA

OneDiff supports the complete functionality related to LoRA. You can use OneDiff-based LoRA just like the native LoRA in sd-webui.
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
58 changes: 27 additions & 31 deletions onediff_sd_webui_extensions/scripts/onediff.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import zipfile
import warnings
import gradio as gr
from pathlib import Path
Expand All @@ -7,15 +8,16 @@
import modules.shared as shared
from modules.sd_models import select_checkpoint
from modules.processing import process_images
from modules.ui_common import create_refresh_button

from ui_utils import hints_message, get_graph_checkpoints, refresh_graph_checkpoints, graph_checkpoints_path
from compile_ldm import compile_ldm_unet, SD21CompileCtx
from compile_sgm import compile_sgm_unet
from compile_vae import VaeCompileCtx
from onediff_lora import HijackLoraActivate
from onediff_hijack import do_hijack as onediff_do_hijack

from onediff.infer_compiler.utils.log_utils import logger
from onediff.infer_compiler.utils.param_utils import get_constant_folding_info
from onediff.optimization.quant_optimizer import (
quantize_model,
varify_can_use_quantization,
Expand Down Expand Up @@ -111,33 +113,15 @@ def ui(self, is_img2img):
The return value should be an array of all components that are used in processing.
Values of those returned components will be passed to run() and process() functions.
"""
with gr.Row():
# TODO: set choices as Tuple[str, str] after the version of gradio specified webui upgrades
graph_checkpoint = gr.Dropdown(label="Graph checkpoints", choices=["None"] + get_graph_checkpoints(), value="None", elem_id="onediff_graph_checkpoint")
refresh_button = create_refresh_button(graph_checkpoint, refresh_graph_checkpoints, lambda: {"choices": ["None"] + get_graph_checkpoints()}, "onediff_refresh_graph")
save_graph_name = gr.Textbox(label="Saved graph name")
if not varify_can_use_quantization():
ret = gr.HTML(
"""
<div style="padding: 20px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
<div style="font-size: 18px; font-weight: bold; margin-bottom: 15px; color: #31708f;">
Hints Message
</div>
<div style="padding: 10px; border: 1px solid #31708f; border-radius: 5px; background-color: #f9f9f9;">
Hints: Enterprise function is not supported on your system.
</div>
<p style="margin-top: 15px;">
If you need Enterprise Level Support for your system or business, please send an email to
<a href="mailto:business@siliconflow.com" style="color: #31708f; text-decoration: none;">business@siliconflow.com</a>.
<br>
Tell us about your use case, deployment scale, and requirements.
</p>
<p>
<strong>GitHub Issue:</strong>
<a href="https://github.com/siliconflow/onediff/issues" style="color: #31708f; text-decoration: none;">https://github.com/siliconflow/onediff/issues</a>
</p>
</div>
"""
)

else:
ret = gr.components.Checkbox(label="Model Quantization(int8) Speed Up")
return [ret]
gr.HTML(hints_message)
is_quantized = gr.components.Checkbox(label="Model Quantization(int8) Speed Up", visible=varify_can_use_quantization())
return [is_quantized, graph_checkpoint, save_graph_name]

def show(self, is_img2img):
return True
Expand Down Expand Up @@ -165,10 +149,7 @@ def get_model_type(model):
self.current_type = get_model_type(model)
return is_changed

def run(self, p, quantization=False):
# For OneDiff Community, the input param `quantization` is a HTML string
if isinstance(quantization, str):
quantization = False
def run(self, p, quantization=False, graph_checkpoint=None, saved_graph_name=""):

global compiled_unet, compiled_ckpt_name
current_checkpoint = shared.opts.sd_model_checkpoint
Expand All @@ -181,19 +162,34 @@ def run(self, p, quantization=False):
model_changed = ckpt_name != compiled_ckpt_name
model_structure_changed = self.check_model_structure_change(shared.sd_model)
need_recompile = (quantization and model_changed) or model_structure_changed
import ipdb; ipdb.set_trace()
marigoold marked this conversation as resolved.
Show resolved Hide resolved

if need_recompile:
compiled_unet = compile_unet(
original_diffusion_model, quantization=quantization
)
compiled_ckpt_name = ckpt_name

if graph_checkpoint != "None":
try:
compiled_unet.load_graph(graph_checkpoints_path() + f"/{graph_checkpoint}", run_warmup=True)
except zipfile.BadZipFile as e:
print("Load graph failed. Please make sure that the --disable-safe-unpickle parameter is added when starting the webui")
except Exception as e:
print("Load graph failed. Please make sure graph checkpoint has the same sd version (or unet architure) with current checkpoint")
else:
logger.info(
f"Model {current_checkpoint} has same sd type of graph type {self.current_type}, skip compile"
)

with UnetCompileCtx(), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate():
proc = process_images(p)

if saved_graph_name != "":
saved_graph_name = graph_checkpoints_path() + f"/{saved_graph_name}"
if not Path(saved_graph_name).exists():
compiled_unet.save_graph(graph_checkpoints_path() + f"/{saved_graph_name}")

return proc


Expand Down
38 changes: 38 additions & 0 deletions onediff_sd_webui_extensions/ui_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from pathlib import Path

hints_message = """
<div style="padding: 20px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
<div style="font-size: 18px; font-weight: bold; margin-bottom: 15px; color: #31708f;">
Hints Message
</div>
<div style="padding: 10px; border: 1px solid #31708f; border-radius: 5px; background-color: #f9f9f9;">
Hints: Enterprise function is not supported on your system.
</div>
<p style="margin-top: 15px;">
If you need Enterprise Level Support for your system or business, please send an email to
<a href="mailto:business@siliconflow.com" style="color: #31708f; text-decoration: none;">business@siliconflow.com</a>.
<br>
Tell us about your use case, deployment scale, and requirements.
</p>
<p>
<strong>GitHub Issue:</strong>
<a href="https://github.com/siliconflow/onediff/issues" style="color: #31708f; text-decoration: none;">https://github.com/siliconflow/onediff/issues</a>
</p>
</div>
"""

graph_checkpoints = []

def graph_checkpoints_path():
return str(Path(__file__).parent / "models")
strint marked this conversation as resolved.
Show resolved Hide resolved

def get_graph_checkpoints():
global graph_checkpoints
if len(graph_checkpoints) == 0:
refresh_graph_checkpoints()
return graph_checkpoints

def refresh_graph_checkpoints(path: Path = None):
global graph_checkpoints
path = path or graph_checkpoints_path()
graph_checkpoints = [f.stem for f in Path(path).iterdir() if f.is_file()]