Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix unet graph file cache #887

Merged
merged 20 commits into from
May 30, 2024
5 changes: 4 additions & 1 deletion onediff_comfy_nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""OneDiff ComfyUI Speedup Module"""
from ._config import is_disable_oneflow_backend
from ._nodes import (
ControlnetSpeedup,
ModelSpeedup,
OneDiffApplyModelBooster,
OneDiffCheckpointLoaderSimple,
Expand All @@ -12,6 +13,7 @@
NODE_CLASS_MAPPINGS = {
"ModelSpeedup": ModelSpeedup,
"VaeSpeedup": VaeSpeedup,
"ControlnetSpeedup": ControlnetSpeedup,
"OneDiffModelBooster": OneDiffApplyModelBooster,
"OneDiffCheckpointLoaderSimple": OneDiffCheckpointLoaderSimple,
"OneDiffControlNetLoader": OneDiffControlNetLoader,
Expand All @@ -20,6 +22,7 @@
# Human-readable display names shown in the ComfyUI node picker,
# keyed by the registered node class name (see NODE_CLASS_MAPPINGS).
NODE_DISPLAY_NAME_MAPPINGS = {
    "ModelSpeedup": "Model Speedup",
    "VaeSpeedup": "VAE Speedup",
    "ControlnetSpeedup": "ControlNet Speedup",
    # Fixed typo: "OneDff" -> "OneDiff" to match the product name used everywhere else.
    "OneDiffModelBooster": "Apply Model Booster - OneDiff",
    "OneDiffCheckpointLoaderSimple": "Load Checkpoint - OneDiff",
}
Expand All @@ -37,7 +40,7 @@ def lazy_load_extra_nodes():
update_node_mappings(nodes_torch_compile_booster)

if is_oneflow_available() and not is_disable_oneflow_backend():
from .extras_nodes import nodes_oneflow_booster, nodes_compare
from .extras_nodes import nodes_compare, nodes_oneflow_booster

update_node_mappings(nodes_oneflow_booster)
update_node_mappings(nodes_compare)
Expand Down
27 changes: 16 additions & 11 deletions onediff_comfy_nodes/extras_nodes/nodes_oneflow_booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,33 @@
from comfy import model_management
from comfy.cli_args import args

from onediff.infer_compiler.backends.oneflow.utils.version_util import is_community_version
from onediff.infer_compiler.backends.oneflow.utils.version_util import (
is_community_version,
)

from ..modules import BoosterScheduler
from ..modules.oneflow import (
BasicOneFlowBoosterExecutor,
DeepcacheBoosterExecutor,
PatchBoosterExecutor,
)
from ..modules.oneflow.config import ONEDIFF_QUANTIZED_OPTIMIZED_MODELS
from ..modules.oneflow.hijack_animatediff import animatediff_hijacker
from ..modules.oneflow.hijack_comfyui_instantid import comfyui_instantid_hijacker
from ..modules.oneflow.hijack_ipadapter_plus import ipadapter_plus_hijacker
from ..modules.oneflow.hijack_model_management import model_management_hijacker
from ..modules.oneflow.hijack_model_patcher import model_patch_hijacker
from ..modules.oneflow.hijack_nodes import nodes_hijacker
from ..modules.oneflow.hijack_samplers import samplers_hijack
from ..modules.oneflow.hijack_comfyui_instantid import comfyui_instantid_hijacker
from ..modules.oneflow.hijack_model_patcher import model_patch_hijacker
from ..modules.oneflow.hijack_utils import comfy_utils_hijack
from ..modules.oneflow import BasicOneFlowBoosterExecutor
from ..modules.oneflow import DeepcacheBoosterExecutor
from ..modules.oneflow import PatchBoosterExecutor

from ..modules.oneflow.utils import OUTPUT_FOLDER, load_graph, save_graph
from ..modules import BoosterScheduler
from ..utils.import_utils import is_onediff_quant_available


if is_onediff_quant_available() and not is_community_version():
from ..modules.oneflow.booster_quantization import OnelineQuantizationBoosterExecutor # type: ignore
from ..modules.oneflow.booster_quantization import (
OnelineQuantizationBoosterExecutor,
) # type: ignore

model_management_hijacker.hijack() # add flow.cuda.empty_cache()
nodes_hijacker.hijack()
Expand All @@ -41,9 +47,8 @@
import comfy_extras.nodes_video_model
from nodes import CheckpointLoaderSimple


# https://github.com/comfyanonymous/ComfyUI/commit/bb4940d837f0cfd338ff64776b084303be066c67#diff-fab3fbd81daf87571b12fb3e4d80fc7d6bbbcf0f3dafed1dbc55d81998d82539L54
if hasattr(args, "dont_upcast_attention") and not args.dont_upcast_attention:
if hasattr(args, "dont_upcast_attention") and not args.dont_upcast_attention:
os.environ["ONEFLOW_ATTENTION_ALLOW_HALF_PRECISION_SCORE_ACCUMULATION_MAX_M"] = "0"


Expand Down
9 changes: 7 additions & 2 deletions onediff_comfy_nodes/modules/oneflow/booster_patch.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import os
from functools import singledispatchmethod

from comfy.sd import VAE
from comfy.model_patcher import ModelPatcher
from onediff.infer_compiler.backends.oneflow import OneflowDeployableModule as DeployableModule
from comfy.controlnet import ControlLora, ControlNet
from onediff.infer_compiler.backends.oneflow import (
OneflowDeployableModule as DeployableModule,
)


from ..booster_interface import BoosterExecutor

Expand All @@ -18,7 +23,7 @@ def _set_batch_size_patch(self, diff_model: DeployableModule, latent_image):
file_path = diff_model.get_graph_file()
if file_path is None:
return diff_model

file_dir = os.path.dirname(file_path)
file_name = os.path.basename(file_path)
names = file_name.split("_")
Expand Down
17 changes: 12 additions & 5 deletions onediff_comfy_nodes/modules/oneflow/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,38 @@
import sys
from pathlib import Path

from onediff.infer_compiler.backends.oneflow.utils.version_util import is_community_version
from onediff.infer_compiler.backends.oneflow.transform import transform_mgr
from onediff.infer_compiler.backends.oneflow.utils.version_util import (
is_community_version,
)

# disable patch for loading diffusers
transform_mgr.set_load_diffusers_patch(False)

# Set up paths
ONEDIFF_QUANTIZED_OPTIMIZED_MODELS = "onediff_quant"
COMFYUI_ROOT = os.getenv("COMFYUI_ROOT")

custom_nodes_path = os.path.join(COMFYUI_ROOT, "custom_nodes")
infer_compiler_registry_path = os.path.join(os.path.dirname(__file__), "infer_compiler_registry")
infer_compiler_registry_path = os.path.join(
os.path.dirname(__file__), "infer_compiler_registry"
)

# Add paths to sys.path if not already there
if custom_nodes_path not in sys.path:
sys.path.append(custom_nodes_path)

if infer_compiler_registry_path not in sys.path:
sys.path.append(infer_compiler_registry_path)

# infer_compiler_registry/register_comfy
import register_comfy # load plugins

_USE_UNET_INT8 = not is_community_version()
if _USE_UNET_INT8:
# infer_compiler_registry/register_onediff_quant
import register_onediff_quant # load plugins
from folder_paths import (folder_names_and_paths, models_dir,
supported_pt_extensions)
from folder_paths import folder_names_and_paths, models_dir, supported_pt_extensions

unet_int8_model_dir = Path(models_dir) / "unet_int8"
unet_int8_model_dir.mkdir(parents=True, exist_ok=True)
Expand Down
1 change: 0 additions & 1 deletion src/onediff/infer_compiler/backends/oneflow/env_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ class OneflowCompileOptions:
max_cached_graph_size: int = 9
graph_file: str = None
graph_file_device: torch.device = None

# Optimization related environment variables
run_graph_by_vm: bool = None
graph_delay_variable_op_execution: bool = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,38 @@ def calculate_model_hash(model):

@cost_time(debug=transform_mgr.debug_mode, message="generate graph file name")
def generate_graph_file_name(file_path, deployable_module, args, kwargs):
    """Build a cache file name for a compiled graph.

    The name combines the caller-supplied base path with two short keys:
    one derived from the structure of the call arguments and one derived
    from the model's textual representation, so that different call
    signatures or different models never collide on the same cache file.

    Args:
        file_path: Base path for the graph file (``str`` or ``pathlib.Path``);
            a trailing ``.graph`` extension, if present, is stripped first.
        deployable_module: Module wrapper exposing
            ``_deployable_module_model.oneflow_module`` (the compiled model).
        args: Positional call arguments used to key the input structure.
        kwargs: Keyword call arguments used to key the input structure.

    Returns:
        str: ``"{base}_{input_key6}_{model_key8}.graph"``.
    """

    def _prepare_file_path(file_path):
        # Normalize to str and drop a pre-existing ".graph" suffix so we
        # never emit "name.graph_....graph".
        if isinstance(file_path, Path):
            file_path = str(file_path)
        if file_path.endswith(".graph"):
            file_path = file_path[:-len(".graph")]
        return file_path

    def _generate_input_structure_key(args, kwargs):
        # Walk the (args, kwargs) tree and record a stable textual
        # fingerprint of each node: scalar values verbatim, dict key sets,
        # and the type name for everything else (e.g. tensors).
        args_tree = ArgsTree((args, kwargs), gen_name=False, tensor_type=torch.Tensor)
        out_lst = []
        for v in args_tree.iter_nodes():
            if isinstance(v, (int, float, str)):
                out_lst.append(str(v))
            elif isinstance(v, dict):
                out_lst.append("_".join([str(k) for k in v.keys()]))
            else:
                out_lst.append(type(v).__name__)

        return hashlib.sha256("_".join(out_lst).encode("utf-8")).hexdigest()

    def _generate_model_structure_key(deployable_module):
        # Hash the model's repr; assumes repr is stable for an unchanged
        # architecture — NOTE(review): confirm repr covers all distinguishing state.
        model = deployable_module._deployable_module_model.oneflow_module
        model_hash = hashlib.sha256(f"{model}".encode("utf-8")).hexdigest()
        return model_hash

    # Convert Path object to string if necessary and remove the .graph extension
    file_path = _prepare_file_path(file_path)
    # Truncated digests keep file names short while remaining collision-resistant
    # enough for a local cache directory.
    input_structure_key = _generate_input_structure_key(args, kwargs)[:6]
    model_structure_key = _generate_model_structure_key(deployable_module)[:8]
    # Combine cache keys
    cache_key = f"{input_structure_key}_{model_structure_key}"
    return f"{file_path}_{cache_key}.graph"
ccssu marked this conversation as resolved.
Show resolved Hide resolved


def graph_file_management(func):
Expand All @@ -50,19 +69,10 @@ def wrapper(self, *args, **kwargs):
)

if is_first_load:
setattr(self, "_load_graph_first_run", False)
graph_file = generate_graph_file_name(
graph_file, self, args=args, kwargs=kwargs
)
setattr(self, "_load_graph_first_run", False)
# Avoid graph file conflicts
if importlib.util.find_spec("register_comfy"):
from register_comfy import CrossAttntionStateDictPatch as state_patch

attn2_patch_sum = state_patch.attn2_patch_sum(input_kwargs=kwargs)
if attn2_patch_sum > 0:
graph_file = graph_file.replace(
".graph", f"_attn2_{attn2_patch_sum}.graph"
)

def process_state_dict_before_saving(state_dict: Dict):
nonlocal self, args, kwargs, graph_file
Expand Down Expand Up @@ -97,14 +107,15 @@ def handle_graph_saving():
nonlocal graph_file, compile_options, is_first_load
if not is_first_load:
return
try:
parent_dir = os.path.dirname(graph_file)
if parent_dir != "":
os.makedirs(parent_dir, exist_ok=True)

# Avoid graph file conflicts
if os.path.exists(graph_file):
raise FileExistsError(f"File {graph_file} exists!")
parent_dir = os.path.dirname(graph_file)
if parent_dir != "":
os.makedirs(parent_dir, exist_ok=True)

# Avoid graph file conflicts
if os.path.exists(graph_file):
raise FileExistsError(f"File {graph_file} exists!")
try:

self.save_graph(
graph_file, process_state_dict=process_state_dict_before_saving
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ def _setup_logger(self):
def get_mocked_packages(self):
return self.mocker.mocked_packages

def set_load_diffusers_patch(self, load_diffusers: bool = False):
    """Control whether the diffusers registration plugins are loaded.

    When *load_diffusers* is False (the default), the diffusers register
    packages are added to the mocker's mocked-package set so they are
    skipped instead of being imported for real.
    """
    if load_diffusers:
        return
    for pkg_name in ("register_diffusers", "register_diffusers_enterprise_lite"):
        self.mocker.mocked_packages.add(pkg_name)
ccssu marked this conversation as resolved.
Show resolved Hide resolved

def load_class_proxies_from_packages(self, package_names: List[Union[Path, str]]):
self.logger.debug(f"Loading modules: {package_names}")
for package_name in package_names:
Expand Down