sd-webui refactor, and support refiner model #930

Merged
Merged 19 commits on Jun 17, 2024
Changes from 7 commits
11 changes: 11 additions & 0 deletions onediff_sd_webui_extensions/compile/__init__.py
@@ -0,0 +1,11 @@
from .compile_ldm import SD21CompileCtx
from .compile_utils import get_compiled_graph
from .compile_vae import VaeCompileCtx
from .onediff_compiled_graph import OneDiffCompiledGraph

__all__ = [
"get_compiled_graph",
"SD21CompileCtx",
"VaeCompileCtx",
"OneDiffCompiledGraph",
]
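
For context, a minimal usage sketch of the new compile package API exported above (the call site is hypothetical and not part of this diff; it assumes webui's modules.shared.sd_model holds the loaded checkpoint):

from modules import shared

from compile import OneDiffCompiledGraph, get_compiled_graph

# Compile the current checkpoint's UNet and keep the result as the active graph.
compiled_graph: OneDiffCompiledGraph = get_compiled_graph(shared.sd_model, quantization=False)
print(compiled_graph.sha, compiled_graph.quantized)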
onediff_sd_webui_extensions/compile/compile_ldm.py
@@ -9,15 +9,16 @@
from ldm.modules.diffusionmodules.openaimodel import ResBlock, UNetModel
from ldm.modules.diffusionmodules.util import GroupNorm32
from modules import shared
from sd_webui_onediff_utils import (

from onediff.infer_compiler import oneflow_compile
from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register

from .sd_webui_onediff_utils import (
CrossAttentionOflow,
GroupNorm32Oflow,
timestep_embedding,
)

from onediff.infer_compiler import oneflow_compile
from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register

__all__ = ["compile_ldm_unet"]


onediff_sd_webui_extensions/compile/compile_sgm.py
@@ -1,9 +1,4 @@
import oneflow as flow
from sd_webui_onediff_utils import (
CrossAttentionOflow,
GroupNorm32Oflow,
timestep_embedding,
)
from sgm.modules.attention import (
BasicTransformerBlock,
CrossAttention,
@@ -15,6 +10,12 @@
from onediff.infer_compiler import oneflow_compile
from onediff.infer_compiler.backends.oneflow.transform import proxy_class, register

from .sd_webui_onediff_utils import (
CrossAttentionOflow,
GroupNorm32Oflow,
timestep_embedding,
)

__all__ = ["compile_sgm_unet"]


68 changes: 68 additions & 0 deletions onediff_sd_webui_extensions/compile/compile_utils.py
@@ -0,0 +1,68 @@
import warnings
from pathlib import Path
from typing import Dict, Union

from ldm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelLDM
from modules.sd_models import select_checkpoint
from sgm.modules.diffusionmodules.openaimodel import UNetModel as UNetModelSGM
from ui_utils import check_structure_change_and_update

from onediff.optimization.quant_optimizer import (
quantize_model,
varify_can_use_quantization,
)
from onediff.utils import logger

from .compile_ldm import compile_ldm_unet
from .compile_sgm import compile_sgm_unet
from .onediff_compiled_graph import OneDiffCompiledGraph


def compile_unet(
unet_model, quantization=False, *, options=None,
):
if isinstance(unet_model, UNetModelLDM):
compiled_unet = compile_ldm_unet(unet_model, options=options)
elif isinstance(unet_model, UNetModelSGM):
compiled_unet = compile_sgm_unet(unet_model, options=options)
else:
warnings.warn(
f"Unsupported model type: {type(unet_model)} for compilation , skip",
RuntimeWarning,
)
compiled_unet = unet_model
# In OneDiff Community, quantization can be True when called by api
if quantization and varify_can_use_quantization():
calibrate_info = get_calibrate_info(
f"{Path(select_checkpoint().filename).stem}_sd_calibrate_info.txt"
)
compiled_unet = quantize_model(
compiled_unet, inplace=False, calibrate_info=calibrate_info
)
return compiled_unet


def get_calibrate_info(filename: str) -> Union[None, Dict]:
calibration_path = Path(select_checkpoint().filename).parent / filename
if not calibration_path.exists():
return None

logger.info(f"Got calibrate info at {str(calibration_path)}")
calibrate_info = {}
with open(calibration_path, "r") as f:
for line in f.readlines():
line = line.strip()
items = line.split(" ")
calibrate_info[items[0]] = [
float(items[1]),
int(items[2]),
[float(x) for x in items[3].split(",")],
]
return calibrate_info


def get_compiled_graph(sd_model, quantization) -> OneDiffCompiledGraph:
compiled_unet = compile_unet(
sd_model.model.diffusion_model, quantization=quantization
)
return OneDiffCompiledGraph(sd_model, compiled_unet, quantization)
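
For reference, a sketch of the calibrate-info file layout that get_calibrate_info above parses (the values below are made up and the field semantics are internal to onediff's quantization): each line is whitespace-separated into a name, a float, an int, and a comma-separated list of floats.

# Hypothetical line from "<checkpoint>_sd_calibrate_info.txt", parsed the same way
# as get_calibrate_info does above.
items = "model.diffusion_model.input_blocks.1.1.proj_in 0.0123 8 0.5,0.25,0.125".split(" ")
info = [float(items[1]), int(items[2]), [float(x) for x in items[3].split(",")]]
# info == [0.0123, 8, [0.5, 0.25, 0.125]]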
31 changes: 31 additions & 0 deletions onediff_sd_webui_extensions/compile/onediff_compiled_graph.py
@@ -0,0 +1,31 @@
import dataclasses

import torch
from modules import sd_models_types

from onediff.infer_compiler import DeployableModule


@dataclasses.dataclass
class OneDiffCompiledGraph:
name: str = None
filename: str = None
sha: str = None
eager_module: torch.nn.Module = None
graph_module: DeployableModule = None
quantized: bool = False

def __init__(
self,
sd_model: sd_models_types.WebuiSdModel = None,
graph_module: DeployableModule = None,
quantized=False,
):
if sd_model is None:
return
self.name = sd_model.sd_checkpoint_info.name
self.filename = sd_model.sd_checkpoint_info.filename
self.sha = sd_model.sd_model_hash
self.eager_module = sd_model.model.diffusion_model
self.graph_module = graph_module
self.quantized = quantized
3 changes: 1 addition & 2 deletions onediff_sd_webui_extensions/onediff_hijack.py
@@ -1,6 +1,5 @@
import compile_ldm
import compile_sgm
import oneflow
from compile import compile_ldm, compile_sgm


# https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/1c0a0c4c26f78c32095ebc7f8af82f5c04fca8c0/modules/sd_hijack_unet.py#L8
138 changes: 137 additions & 1 deletion onediff_sd_webui_extensions/onediff_lora.py
@@ -1,4 +1,9 @@
from typing import Any, Mapping

import torch
from modules import sd_models
from modules.sd_hijack_utils import CondFunc
from onediff_shared import onediff_enabled

from onediff.infer_compiler import DeployableModule
from onediff.infer_compiler.backends.oneflow.param_utils import (
@@ -53,7 +58,138 @@ def activate(self, p, params_list):
continue
networks.network_apply_weights(sub_module)
if isinstance(sub_module, torch.nn.Conv2d):
update_graph_related_tensor(sub_module)
# TODO(WangYi): refine here
try:
update_graph_related_tensor(sub_module)
except:
pass

activate._onediff_hijacked = True
return activate


def onediff_hijack_load_model_weights(
orig_func, model, checkpoint_info: sd_models.CheckpointInfo, state_dict: dict, timer
):
# load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer)
sd_model_hash = checkpoint_info.calculate_shorthash()
import onediff_shared

if onediff_shared.current_unet_graph.sha == sd_model_hash:
model.model.diffusion_model = onediff_shared.current_unet_graph.graph_module
state_dict = {
k: v
for k, v in state_dict.items()
if not k.startswith("model.diffusion_model.")
}

# for stable-diffusion-webui/modules/sd_models.py:load_model_weights model.is_ssd check
state_dict[
"model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight"
] = model.get_parameter(
"model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight"
)
return orig_func(model, checkpoint_info, state_dict, timer)


def onediff_hijack_load_state_dict(
orig_func,
self,
state_dict: Mapping[str, Any],
strict: bool = True,
assign: bool = False,
):
if (
len(state_dict) > 0
and next(iter(state_dict.values())).is_cuda
and next(self.parameters()).is_meta
):
return orig_func(self, state_dict, strict, assign=True)
else:
return orig_func(self, state_dict, strict, assign)


# fmt: off
def onediff_hijaced_LoadStateDictOnMeta___enter__(orig_func, self):
from modules import shared
if shared.cmd_opts.disable_model_loading_ram_optimization:
return

sd = self.state_dict
device = self.device

def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs):
used_param_keys = []

for name, param in module._parameters.items():
if param is None:
continue

key = prefix + name
sd_param = sd.pop(key, None)
if sd_param is not None:
state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key))
used_param_keys.append(key)

if param.is_meta:
dtype = sd_param.dtype if sd_param is not None else param.dtype
module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad)

for name in module._buffers:
key = prefix + name

sd_param = sd.pop(key, None)
if sd_param is not None:
state_dict[key] = sd_param
used_param_keys.append(key)

original(module, state_dict, prefix, *args, **kwargs)

for key in used_param_keys:
state_dict.pop(key, None)

# def load_state_dict(original, module, state_dict, strict=True):
def load_state_dict(original, module, state_dict, strict=True):
"""torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help
because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with
all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes.

In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd).

The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads
the function and does not call the original) the state dict will just fail to load because weights
would be on the meta device.
"""

if state_dict is sd:
state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()}

# ------------------- DIFF HERE -------------------
# original(module, state_dict, strict=strict)
if len(state_dict) > 0 and next(iter(state_dict.values())).is_cuda and next(module.parameters()).is_meta:
assign = True
else:
assign = False
# orig_func(original, module, state_dict, strict=strict, assign=assign)
original(module, state_dict, strict=strict, assign=assign)

module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs))
module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs))
linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs))
group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs))
# fmt: on


CondFunc(
"modules.sd_disable_initialization.LoadStateDictOnMeta.__enter__",
onediff_hijaced_LoadStateDictOnMeta___enter__,
lambda _, *args, **kwargs: onediff_enabled,
)
CondFunc(
"modules.sd_models.load_model_weights",
onediff_hijack_load_model_weights,
lambda _, *args, **kwargs: onediff_enabled,
)
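
The two CondFunc registrations above install the hijacked functions only while onediff is enabled. A minimal stand-in for that wrap-and-gate pattern (illustrative only, not webui's actual CondFunc implementation):

def cond_func(orig, replacement, predicate):
    # The replacement receives the original callable as its first argument,
    # mirroring how orig_func is passed to the hijack functions above.
    def wrapped(*args, **kwargs):
        if predicate(orig, *args, **kwargs):
            return replacement(orig, *args, **kwargs)
        return orig(*args, **kwargs)
    return wrapped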
14 changes: 14 additions & 0 deletions onediff_sd_webui_extensions/onediff_shared.py
@@ -0,0 +1,14 @@
from typing import Dict

from compile.onediff_compiled_graph import OneDiffCompiledGraph

current_unet_graph = OneDiffCompiledGraph()
current_quantization = False
refiner_dict: Dict[str, str] = dict()
current_unet_type = {
"is_sdxl": False,
"is_sd2": False,
"is_sd1": False,
"is_ssd": False,
}
onediff_enabled = False