add vae compile to sd-webui #473

Merged
merged 13 commits on Jan 3, 2024
5 changes: 5 additions & 0 deletions onediff_sd_webui_extensions/README.md
@@ -11,6 +11,11 @@ Updated on DEC 26, 2023. Device: RTX 3090. Resolution: 1024x1024
| --------------- | --------------- | ------------------ | ---------------------- |
| 2.99it/s | 6.40it/s | 6.71it/s | 224.41% |

Time to enerate a 1024x1024 image with sdxl (30 steps) on 3090
Suggested change
Time to enerate a 1024x1024 image with sdxl (30 steps) on 3090
End2end time(seconds) to generate a 1024x1024 image with SDXL (30 steps) on NVIDIA RTX 3090:


| torch(Baseline) | TensorRT-v9.0.1 | onediff(Optimized) | Percentage improvement |
| --------------- | --------------- | ------------------ | ---------------------- |
| 11.03s | 5.55s | 5.29s | 208.51% |
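
(Reviewer note: the "Percentage improvement" column appears to be the baseline-to-optimized ratio expressed as a percentage, e.g. 11.03s / 5.29s ≈ 208.51% and 6.71it/s / 2.99it/s ≈ 224.41%.)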

## Installation Guide

It is recommended to create a Python virtual environment in advance. For example `conda create -n sd-webui python=3.10`.
19 changes: 18 additions & 1 deletion onediff_sd_webui_extensions/compile_ldm.py
@@ -1,10 +1,11 @@
import math
import os
import oneflow as flow
from onediff.infer_compiler import oneflow_compile, register

from ldm.modules.attention import BasicTransformerBlock, CrossAttention
from ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel
from ldm.modules.diffusionmodules.util import GroupNorm32
from modules import shared
from sd_webui_onediff_utils import (
CrossAttentionOflow,
GroupNorm32Oflow,
@@ -59,3 +60,19 @@ def compile_ldm_unet(unet_model, *, use_graph=True, options={}):
if isinstance(module, ResBlock):
module.use_checkpoint = False
return oneflow_compile(unet_model, use_graph=use_graph, options=options)


class SD21CompileCtx(object):
"""to avoid results for NaN when the model is v2-1_768-ema-pruned"""

_var_name = "ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_ACCUMULATION"

def __enter__(self):
self._original = os.getenv(self._var_name)
if shared.opts.sd_model_checkpoint.startswith("v2-1"):
os.environ[self._var_name] = "0"

def __exit__(self, exc_type, exc_val, exc_tb):
if self._original is not None:
os.environ[self._var_name] = self._original
return False
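
The save/override/restore pattern used by `SD21CompileCtx` above is generic. A minimal, self-contained sketch of the same idea (illustrative names only, not part of this PR; `ckpt_name` below is a hypothetical checkpoint name):

```python
import os


class EnvVarOverride:
    """Temporarily override an environment variable and restore it on exit.

    Illustrative sketch of the pattern SD21CompileCtx applies to
    ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_ACCUMULATION.
    """

    def __init__(self, name, value, should_apply=True):
        self._name = name
        self._value = value
        self._should_apply = should_apply
        self._original = None

    def __enter__(self):
        # Remember the current value (None if unset) before overriding.
        self._original = os.getenv(self._name)
        if self._should_apply:
            os.environ[self._name] = self._value
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore the previous value; like SD21CompileCtx, a variable that was
        # unset before entry is left with the override still applied.
        if self._original is not None:
            os.environ[self._name] = self._original
        return False  # do not suppress exceptions


# Usage sketch (ckpt_name is hypothetical):
# with EnvVarOverride("ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_ACCUMULATION", "0",
#                     should_apply=ckpt_name.startswith("v2-1")):
#     ...  # run inference with half-precision accumulation disabled
```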
38 changes: 38 additions & 0 deletions onediff_sd_webui_extensions/compile_vae.py
@@ -0,0 +1,38 @@
from modules import shared
from modules.sd_vae_approx import model as get_vae_model, sd_vae_approx_models
from onediff.infer_compiler import oneflow_compile

__all__ = ["VaeCompileCtx"]


compiled_models = {}


class VaeCompileCtx(object):
def __init__(self, use_graph=True, options={}):
self._use_graph = use_graph
self._options = options
# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/75336dfc84cae280036bc52a6805eb10d9ae30ba/modules/sd_vae_approx.py#L43
self._model_name = (
"vaeapprox-sdxl.pt"
if getattr(shared.sd_model, "is_sdxl", False)
else "model.pt"
)
self._original_model = get_vae_model()

def __enter__(self):
if self._original_model is None:
return
global compiled_models
model = compiled_models.get(self._model_name)
if model is None:
model = oneflow_compile(
self._original_model, use_graph=self._use_graph, options=self._options
)
compiled_models[self._model_name] = model
sd_vae_approx_models[self._model_name] = model

def __exit__(self, exc_type, exc_val, exc_tb):
if self._original_model is not None:
sd_vae_approx_models[self._model_name] = self._original_model
return False
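
A usage sketch of the new context manager (import paths as introduced by this PR; `p` is the usual webui `StableDiffusionProcessing` object and is shown here only for illustration):

```python
from compile_vae import VaeCompileCtx
from modules.processing import process_images

# First entry compiles the VAE-approx model with oneflow_compile and caches it
# in compiled_models under "vaeapprox-sdxl.pt" (SDXL) or "model.pt"; later
# entries reuse the cached compiled model. On exit the original, uncompiled
# model is restored into sd_vae_approx_models.
with VaeCompileCtx(use_graph=True):
    proc = process_images(p)  # p: a StableDiffusionProcessing instance (assumed)
```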
46 changes: 24 additions & 22 deletions onediff_sd_webui_extensions/scripts/onediff.py
@@ -5,8 +5,9 @@
import modules.shared as shared
from modules.processing import process_images

from compile_ldm import compile_ldm_unet, SD21CompileCtx
from compile_sgm import compile_sgm_unet
from compile_ldm import compile_ldm_unet
from compile_vae import VaeCompileCtx

from onediff.optimization.quant_optimizer import (
quantize_model,
@@ -33,11 +34,7 @@ def is_compiled(ckpt_name):


def compile_unet(
unet_model,
quantization=False,
*,
use_graph=True,
options={},
unet_model, quantization=False, *, use_graph=True, options={},
):
from ldm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelLDM
from sgm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelSGM
@@ -57,6 +54,22 @@ def compile_unet(
return unet_model


class UnetCompileCtx(object):
"""The unet model is stored in a global variable.
The global variables need to be replaced with compiled_unet before process_images is run,
and then the original model restored so that subsequent reasoning with onediff disabled meets expectations.
"""

def __enter__(self):
self._original_model = shared.sd_model.model.diffusion_model
global compiled_unet
shared.sd_model.model.diffusion_model = compiled_unet

def __exit__(self, exc_type, exc_val, exc_tb):
shared.sd_model.model.diffusion_model = self._original_model
return False


class Script(scripts.Script):
def title(self):
return "onediff_diffusion_model"
@@ -68,7 +81,7 @@ def ui(self, is_img2img):
"""
if not varify_can_use_quantization():
ret = gr.HTML(
"""
"""
<div style="padding: 20px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #f9f9f9;">
<div style="font-size: 18px; font-weight: bold; margin-bottom: 15px; color: #31708f;">
Hints Message
@@ -88,7 +101,7 @@
</p>
</div>
"""
)
)

else:
ret = gr.components.Checkbox(label="Model Quantization(int8) Speed Up")
@@ -107,22 +120,11 @@ def run(self, p, quantization=False):
)

if not is_compiled(ckpt_name):
graph_file = generate_graph_path(
ckpt_name, original_diffusion_model.__class__.__name__
)
graph_file_device = shared.device
compile_options = {
"graph_file_device": graph_file_device,
"graph_file": graph_file,
}
compiled_unet = compile_unet(
original_diffusion_model,
quantization=quantization,
options=compile_options,
original_diffusion_model, quantization=quantization,
)
compiled_ckpt_name = ckpt_name

shared.sd_model.model.diffusion_model = compiled_unet
proc = process_images(p)
shared.sd_model.model.diffusion_model = original_diffusion_model
with UnetCompileCtx(), VaeCompileCtx(), SD21CompileCtx():
proc = process_images(p)
return proc
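
For clarity, the combined `with` statement in `run()` enters the three managers left to right and exits them in reverse, so it is roughly equivalent to this nesting (sketch, same objects as above):

```python
with UnetCompileCtx():           # swap shared.sd_model.model.diffusion_model for compiled_unet
    with VaeCompileCtx():        # swap the VAE-approx model for its compiled version
        with SD21CompileCtx():   # for v2-1* checkpoints, disable half-precision accumulation
            proc = process_images(p)
# Exits run in reverse order, each manager restoring what it replaced.
```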
2 changes: 1 addition & 1 deletion src/onediff/optimization/quant_optimizer.py
@@ -16,7 +16,7 @@
def varify_can_use_quantization():
if not is_quantization_enabled():
message = get_support_message()
logger.error(message)
logger.warn(message)
return False
return True
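
(Reviewer note: the switch from `logger.error` to `logger.warn` presumably reflects that an unavailable quantization backend is an expected, recoverable condition rather than a failure; the function still returns `False`. As an aside, `logger.warning` is the non-deprecated spelling in the standard `logging` module.)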
