sd-webui supports reusing compiled graphs (#742)
marigoold authored Mar 20, 2024
1 parent 5f068ea commit 82635c3
Showing 2 changed files with 49 additions and 17 deletions.
6 changes: 6 additions & 0 deletions onediff_sd_webui_extensions/README.md
@@ -3,6 +3,7 @@
- [Performance of Community Edition](#performance-of-community-edition)
- [Installation Guide](#installation-guide)
- [Extensions Usage](#extensions-usage)
- [Fast Model Switching](#fast-model-switching)
- [LoRA](#lora)
- [Quantization](#quantization)
- [Contact](#contact)
@@ -60,6 +61,11 @@ Select `onediff_diffusion_model` from the Script menu, enter a prompt in the tex

![onediff_script](images/onediff_script.jpg)

### Fast Model Switching

When switching models, if the new model has the same structure as the previous one, OneDiff reuses the previously compiled graph instead of compiling the new model from scratch, which significantly reduces the time it takes to switch models.

> Note: Please make sure that your PyTorch version is at least 2.1.0, and set the environment variable `ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION` to `0` when starting the sd-webui service. This feature does not support quantized models.
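
Conceptually, the reuse check compares a small set of architecture flags between the incoming model and the model the current graph was compiled for; the `need_compile` method in the diff below implements it. Here is a minimal, self-contained sketch of the same idea — `needs_recompile`, `_cached_type`, and `STRUCTURE_FLAGS` are illustrative names, not part of the extension's API:

```python
import os
import warnings

# Architecture flags that must all match for a compiled graph to be reusable.
STRUCTURE_FLAGS = ("is_sdxl", "is_sd2", "is_sd1", "is_ssd")

_cached_type = None  # flags of the model the current graph was compiled for


def needs_recompile(model) -> bool:
    """Return True when no graph exists yet or the model structure changed."""
    global _cached_type
    new_type = {flag: getattr(model, flag) for flag in STRUCTURE_FLAGS}
    if _cached_type != new_type:
        _cached_type = new_type
        return True  # first compilation, or a structurally different model
    # Reuse is only safe when this OneFlow optimization is disabled.
    if os.getenv("ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION", "1") != "0":
        warnings.warn(
            "Set ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION=0 to reuse "
            "the compiled graph safely.",
            RuntimeWarning,
        )
    return False
```

With a check like this, switching between two SDXL checkpoints skips compilation entirely, while switching from an SD1.x to an SDXL checkpoint still triggers a fresh compile.
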
### LoRA

60 changes: 43 additions & 17 deletions onediff_sd_webui_extensions/scripts/onediff.py
@@ -14,6 +14,7 @@
from onediff_lora import HijackLoraActivate

from onediff.infer_compiler.utils.log_utils import logger
from onediff.infer_compiler.utils.env_var import parse_boolean_from_env
from onediff.optimization.quant_optimizer import (
    quantize_model,
    varify_can_use_quantization,
@@ -38,12 +39,6 @@ def generate_graph_path(ckpt_name: str, model_name: str) -> str:
    return graph_file_path


def is_compiled(ckpt_name):
    global compiled_unet, compiled_ckpt_name

    return compiled_unet is not None and compiled_ckpt_name == ckpt_name


def get_calibrate_info(filename: str) -> Union[None, Dict]:
    calibration_path = Path(select_checkpoint().filename).parent / filename
    if not calibration_path.exists():
@@ -84,7 +79,9 @@ def compile_unet(
        )
        compiled_unet = unet_model
    if quantization:
        calibrate_info = get_calibrate_info(f"{Path(select_checkpoint().filename).stem}_sd_calibrate_info.txt")
        calibrate_info = get_calibrate_info(
            f"{Path(select_checkpoint().filename).stem}_sd_calibrate_info.txt"
        )
        compiled_unet = quantize_model(
            compiled_unet, inplace=False, calibrate_info=calibrate_info
        )
@@ -108,6 +105,8 @@ def __exit__(self, exc_type, exc_val, exc_tb):


class Script(scripts.Script):
    current_type = None

    def title(self):
        return "onediff_diffusion_model"

@@ -147,6 +146,34 @@ def ui(self, is_img2img):
    def show(self, is_img2img):
        return True

    def need_compile(self, model):
        recompile = False

        def get_model_type(model):
            return {
                "is_sdxl": model.is_sdxl,
                "is_sd2": model.is_sd2,
                "is_sd1": model.is_sd1,
                "is_ssd": model.is_ssd,
            }

        # Recompile when nothing has been compiled yet, or when the model's
        # architecture flags differ from those of the compiled graph.
        if self.current_type is None:
            recompile = True
        else:
            for key, v in self.current_type.items():
                if v != getattr(model, key):
                    recompile = True
                    break

        if recompile:
            self.current_type = get_model_type(model)
        elif parse_boolean_from_env("ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION", "1"):
            warnings.warn(
                "If you want to reuse the compiled graph, please set the environment variable `ONEFLOW_MLIR_ENABLE_INFERENCE_OPTIMIZATION` to '0', otherwise the compiled graph will work incorrectly.",
                RuntimeWarning,
            )
        return recompile

    def run(self, p, quantization=False):
        global compiled_unet, compiled_ckpt_name
        current_checkpoint = shared.opts.sd_model_checkpoint
@@ -156,16 +183,11 @@ def run(self, p, quantization=False):
            current_checkpoint + "_quantized" if quantization else current_checkpoint
        )

        if not is_compiled(ckpt_name):
            # graph_file = generate_graph_path(
            #     ckpt_name, original_diffusion_model.__class__.__name__
            # )
            # graph_file_device = shared.device
            # compile_options = {
            #     "graph_file_device": graph_file_device,
            #     "graph_file": graph_file,
            # }
            # TODO: fix compile_options
        if (
            (quantization and ckpt_name != compiled_ckpt_name)
            or self.need_compile(shared.sd_model)
        ):
            compile_options = {}

            compiled_unet = compile_unet(
@@ -174,6 +196,10 @@
                options=compile_options,
            )
            compiled_ckpt_name = ckpt_name
        else:
            logger.info(
                f"Model {current_checkpoint} has the same structure type {self.current_type} as the compiled graph; skipping compilation"
            )

        with UnetCompileCtx(), VaeCompileCtx(), SD21CompileCtx(), HijackLoraActivate():
            proc = process_images(p)
