diff --git a/onediff_comfy_nodes/README.md b/onediff_comfy_nodes/README.md index 80a532f5a..789a4322f 100644 --- a/onediff_comfy_nodes/README.md +++ b/onediff_comfy_nodes/README.md @@ -137,9 +137,12 @@ This example demonstrates how to utilize LoRAs. You have the flexibility to modi ### ControlNet -While there is an example demonstrating OpenPose ControlNet, it's important to note that OneDiff seamlessly supports a wide range of ControlNet types, including depth mapping, canny, and more. +> doc link: [ControlNet](https://github.com/siliconflow/onediff/tree/main/onediff_comfy_nodes/workflows/ControlNet) -[ControlNet Speedup](workflows/model-speedup-controlnet.png) + +While there is an example demonstrating OpenPose ControlNet, it's important to note that OneDiff seamlessly supports a wide range of ControlNet types, including depth mapping, canny, and more. + +[ControlNet Speedup](workflows/ControlNet/controlnet_onediff.png) ### SVD diff --git a/onediff_comfy_nodes/__init__.py b/onediff_comfy_nodes/__init__.py index 88ed2801b..ec6d5e2aa 100644 --- a/onediff_comfy_nodes/__init__.py +++ b/onediff_comfy_nodes/__init__.py @@ -7,11 +7,11 @@ VaeSpeedup, VaeGraphLoader, VaeGraphSaver, - ControlNetSpeedup, SVDSpeedup, ModuleDeepCacheSpeedup, OneDiffCheckpointLoaderSimple, OneDiffControlNetLoader, + OneDiffDeepCacheCheckpointLoaderSimple, ) from ._compare_node import CompareModel, ShowImageDiff @@ -25,11 +25,11 @@ "VaeSpeedup": VaeSpeedup, "VaeGraphSaver": VaeGraphSaver, "VaeGraphLoader": VaeGraphLoader, - "ControlNetSpeedup": ControlNetSpeedup, "SVDSpeedup": SVDSpeedup, "ModuleDeepCacheSpeedup": ModuleDeepCacheSpeedup, "OneDiffCheckpointLoaderSimple": OneDiffCheckpointLoaderSimple, "OneDiffControlNetLoader": OneDiffControlNetLoader, + "OneDiffDeepCacheCheckpointLoaderSimple": OneDiffDeepCacheCheckpointLoaderSimple, } NODE_DISPLAY_NAME_MAPPINGS = { @@ -41,11 +41,11 @@ "VaeSpeedup": "VAE Speedup", "VaeGraphLoader": "VAE Graph Loader", "VaeGraphSaver": "VAE Graph Saver", - "ControlNetSpeedup": "ControlNet Speedup", "SVDSpeedup": "SVD Speedup", "ModuleDeepCacheSpeedup": "Model DeepCache Speedup", "OneDiffCheckpointLoaderSimple": "Load Checkpoint - OneDiff", "OneDiffControlNetLoader": "Load ControlNet Model - OneDiff", + "OneDiffDeepCacheCheckpointLoaderSimple": "Load Checkpoint - OneDiff DeepCache", } diff --git a/onediff_comfy_nodes/_nodes.py b/onediff_comfy_nodes/_nodes.py index 12aa257e4..aa4cd727e 100644 --- a/onediff_comfy_nodes/_nodes.py +++ b/onediff_comfy_nodes/_nodes.py @@ -30,8 +30,11 @@ from .utils.loader_sample_tools import compoile_unet, quantize_unet from .utils.graph_path import generate_graph_path from .modules.hijack_model_management import model_management_hijacker +from .modules.hijack_nodes import nodes_hijacker +from .utils.deep_cache_speedup import deep_cache_speedup model_management_hijacker.hijack() # add flow.cuda.empty_cache() +nodes_hijacker.hijack() __all__ = [ @@ -290,44 +293,6 @@ def save_graph(self, images, vae, filename_prefix): return {} -class ControlNetSpeedup: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "control_net": ("CONTROL_NET",), - "static_mode": (["enable", "disable"],), - } - } - - RETURN_TYPES = ("CONTROL_NET",) - RETURN_NAMES = ("control_net",) - FUNCTION = "apply_controlnet" - - CATEGORY = "OneDiff" - - def apply_controlnet(self, control_net, static_mode): - if static_mode == "enable": - from comfy.controlnet import ControlNet, ControlLora - from .modules.onediff_controlnet import ( - OneDiffControlNet, - OneDiffControlLora, - ) - - if 
isinstance(control_net, ControlLora): - control_net = OneDiffControlLora.from_controllora(control_net) - return (control_net,) - elif isinstance(control_net, ControlNet): - control_net = OneDiffControlNet.from_controlnet(control_net) - return (control_net,) - else: - raise TypeError( - f"control_net must be ControlNet or ControlLora, got {type(control_net)}" - ) - else: - return (control_net,) - - class Quant8Model: @classmethod def INPUT_TYPES(s): @@ -423,112 +388,16 @@ def deep_cache_convert( start_step, end_step, ): - use_graph = static_mode == "enable" - - offload_device = model_management.unet_offload_device() - oneflow_model = OneFlowDeepCacheSpeedUpModelPatcher( - model.model, - load_device=model_management.get_torch_device(), - offload_device=offload_device, + return deep_cache_speedup( + model=model, + use_graph=(static_mode == "enable"), + cache_interval=cache_interval, cache_layer_id=cache_layer_id, cache_block_id=cache_block_id, - use_graph=use_graph, + start_step=start_step, + end_step=end_step, ) - current_t = -1 - current_step = -1 - cache_h = None - - def apply_model(model_function, kwargs): - nonlocal current_t, current_step, cache_h - - xa = kwargs["input"] - t = kwargs["timestep"] - c_concat = kwargs["c"].get("c_concat", None) - c_crossattn = kwargs["c"].get("c_crossattn", None) - y = kwargs["c"].get("y", None) - control = kwargs["c"].get("control", None) - transformer_options = kwargs["c"].get("transformer_options", None) - - # https://github.com/comfyanonymous/ComfyUI/blob/629e4c552cc30a75d2756cbff8095640af3af163/comfy/model_base.py#L51-L69 - sigma = t - xc = oneflow_model.model.model_sampling.calculate_input(sigma, xa) - if c_concat is not None: - xc = torch.cat([xc] + [c_concat], dim=1) - - context = c_crossattn - dtype = oneflow_model.model.get_dtype() - xc = xc.to(dtype) - t = oneflow_model.model.model_sampling.timestep(t).float() - context = context.to(dtype) - extra_conds = {} - for o in kwargs: - extra = kwargs[o] - if hasattr(extra, "to"): - extra = extra.to(dtype) - extra_conds[o] = extra - - x = xc - timesteps = t - y = None if y is None else y.to(dtype) - transformer_options["original_shape"] = list(x.shape) - transformer_options["current_index"] = 0 - transformer_patches = transformer_options.get("patches", {}) - """ - Apply the model to an input batch. - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :param context: conditioning plugged in via crossattn - :param y: an [N] Tensor of labels, if class-conditional. - :return: an [N x C x ...] Tensor of outputs. 
- """ - - # reference https://gist.github.com/laksjdjf/435c512bc19636e9c9af4ee7bea9eb86 - if t[0].item() > current_t: - current_step = -1 - - current_t = t[0].item() - apply = 1000 - end_step <= current_t <= 1000 - start_step # t is 999->0 - - if apply: - current_step += 1 - else: - current_step = -1 - current_t = t[0].item() - - is_slow_step = current_step % cache_interval == 0 and apply - - model_output = None - if is_slow_step: - cache_h = None - model_output, cache_h = oneflow_model.deep_cache_unet( - x, - timesteps, - context, - y, - control, - transformer_options, - **extra_conds, - ) - else: - model_output, cache_h = oneflow_model.fast_deep_cache_unet( - x, - cache_h, - timesteps, - context, - y, - control, - transformer_options, - **extra_conds, - ) - - return oneflow_model.model.model_sampling.calculate_denoised( - sigma, model_output, xa - ) - - oneflow_model.set_model_unet_function_wrapper(apply_model) - return (oneflow_model,) - from nodes import CheckpointLoaderSimple, ControlNetLoader from comfy.controlnet import ControlLora, ControlNet @@ -542,19 +411,34 @@ class OneDiffControlNetLoader(ControlNetLoader): def onediff_load_controlnet(self, control_net_name): controlnet = super().load_controlnet(control_net_name)[0] + load_device = model_management.get_torch_device() + + def gen_compile_options(model): + graph_file = generate_graph_path(control_net_name, model) + return { + "graph_file": graph_file, + "graph_file_device": load_device, + } + if isinstance(controlnet, ControlLora): - controlnet = OneDiffControlLora.from_controllora(controlnet) + controlnet = OneDiffControlLora.from_controllora( + controlnet, gen_compile_options=gen_compile_options + ) return (controlnet,) elif isinstance(controlnet, ControlNet): control_model = controlnet.control_model - control_model = control_model.to(model_management.get_torch_device()) - controlnet.control_model = oneflow_compile(control_model) + compile_options = gen_compile_options(control_model) + control_model = control_model.to(load_device) + controlnet.control_model = oneflow_compile( + control_model, options=compile_options + ) return (controlnet,) else: - print(f"Warning: {type(controlnet)=} is not ControlLora or ControlNet") + print("\033[1;31;40m Warning: {type(controlnet)=} is not ControlLora or ControlNet \033[0m") return (controlnet,) + class OneDiffCheckpointLoaderSimple(CheckpointLoaderSimple): @classmethod def INPUT_TYPES(s): @@ -596,6 +480,105 @@ def onediff_load_checkpoint( return modelpatcher, clip, vae +class OneDiffDeepCacheCheckpointLoaderSimple(CheckpointLoaderSimple): + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "ckpt_name": (folder_paths.get_filename_list("checkpoints"),), + "vae_speedup": (["disable", "enable"],), + "static_mode": (["enable", "disable"],), + "cache_interval": ( + "INT", + { + "default": 3, + "min": 1, + "max": 1000, + "step": 1, + "display": "number", + }, + ), + "cache_layer_id": ( + "INT", + {"default": 0, "min": 0, "max": 12, "step": 1, "display": "number"}, + ), + "cache_block_id": ( + "INT", + {"default": 1, "min": 0, "max": 12, "step": 1, "display": "number"}, + ), + "start_step": ( + "INT", + { + "default": 0, + "min": 0, + "max": 1000, + "step": 1, + "display": "number", + }, + ), + "end_step": ( + "INT", + { + "default": 1000, + "min": 0, + "max": 1000, + "step": 0.1, + }, + ), + } + } + + CATEGORY = "OneDiff/Loaders" + FUNCTION = "onediff_load_checkpoint" + + def onediff_load_checkpoint( + self, + ckpt_name, + vae_speedup, + output_vae=True, + output_clip=True, + 
static_mode="enable", + cache_interval=3, + cache_layer_id=0, + cache_block_id=1, + start_step=0, + end_step=1000, + ): + # CheckpointLoaderSimple.load_checkpoint + modelpatcher, clip, vae = self.load_checkpoint( + ckpt_name, output_vae, output_clip + ) + + def gen_compile_options(model): + # cache_key = f'{cache_interval}_{cache_layer_id}_{cache_block_id}_{start_step}_{end_step}' + graph_file = generate_graph_path(ckpt_name, model) + return { + "graph_file": graph_file, + "graph_file_device": model_management.get_torch_device(), + } + + if vae_speedup == "enable": + vae.first_stage_model = oneflow_compile( + vae.first_stage_model, + use_graph=True, + options=gen_compile_options(vae.first_stage_model), + ) + + modelpatcher = deep_cache_speedup( + model=modelpatcher, + use_graph=(static_mode == "enable"), + cache_interval=cache_interval, + cache_layer_id=cache_layer_id, + cache_block_id=cache_block_id, + start_step=start_step, + end_step=end_step, + gen_compile_options=gen_compile_options, + )[0] + # set inplace update + modelpatcher.weight_inplace_update = True + return modelpatcher, clip, vae + + class OneDiffQuantCheckpointLoaderSimple(CheckpointLoaderSimple): @classmethod def INPUT_TYPES(s): diff --git a/onediff_comfy_nodes/modules/hijack_nodes.py b/onediff_comfy_nodes/modules/hijack_nodes.py new file mode 100644 index 000000000..cb066da29 --- /dev/null +++ b/onediff_comfy_nodes/modules/hijack_nodes.py @@ -0,0 +1,100 @@ +from nodes import ControlNetApply, ControlNetApplyAdvanced +from .sd_hijack_utils import Hijacker +from .onediff_controlnet import OneDiffControlLora + + +def apply_controlnet_base(orig_func, self, conditioning, control_net, image, strength): + if strength == 0: + return (conditioning,) + + c = [] + control_hint = image.movedim(-1, 1) + for t in conditioning: + n = [t[0], t[1].copy()] + if not hasattr(self, "_c_net"): + self._c_net = control_net.copy() + + c_net = self._c_net.set_cond_hint(control_hint, strength) + + if "control" in t[1]: + c_net.set_previous_controlnet(t[1]["control"]) + n[1]["control"] = c_net + n[1]["control_apply_to_uncond"] = True + c.append(n) + return (c,) + + +def apply_controlnet_cond_func_base( + orig_func, self, conditioning, control_net, image, strength +): + return isinstance(control_net, OneDiffControlLora) + + +def apply_controlnet_advanced( + orig_func, + self, + positive, + negative, + control_net, + image, + strength, + start_percent, + end_percent, +): + if strength == 0: + return (positive, negative) + control_hint = image.movedim(-1, 1) + cnets = {} + + out = [] + for conditioning in [positive, negative]: + c = [] + for t in conditioning: + d = t[1].copy() + + prev_cnet = d.get("control", None) + if prev_cnet in cnets: + c_net = cnets[prev_cnet] + else: + if not hasattr(self, "_c_net"): + print("creating new cnet") + self._c_net = control_net.copy() + c_net = self._c_net.set_cond_hint( + control_hint, strength, (start_percent, end_percent) + ) + c_net.set_previous_controlnet(prev_cnet) + cnets[prev_cnet] = c_net + + d["control"] = c_net + d["control_apply_to_uncond"] = False + n = [t[0], d] + c.append(n) + out.append(c) + return (out[0], out[1]) + + +def apply_controlnet_cond_func_advanced( + orig_func, + self, + positive, + negative, + control_net, + image, + strength, + start_percent, + end_percent, +): + return isinstance(control_net, OneDiffControlLora) + + +nodes_hijacker = Hijacker() +nodes_hijacker.register( + orig_func=ControlNetApply.apply_controlnet, + sub_func=apply_controlnet_base, + 
cond_func=apply_controlnet_cond_func_base, +) +nodes_hijacker.register( + orig_func=ControlNetApplyAdvanced.apply_controlnet, + sub_func=apply_controlnet_advanced, + cond_func=apply_controlnet_cond_func_advanced, +) diff --git a/onediff_comfy_nodes/modules/onediff_controlnet.py b/onediff_comfy_nodes/modules/onediff_controlnet.py index c138efb8f..87e9e5125 100644 --- a/onediff_comfy_nodes/modules/onediff_controlnet.py +++ b/onediff_comfy_nodes/modules/onediff_controlnet.py @@ -34,7 +34,9 @@ def _set_attr_of(obj, attr, value): class OneDiffControlLora(ControlLora): @classmethod - def from_controllora(cls, controlnet: ControlLora): + def from_controllora( + cls, controlnet: ControlLora, *, gen_compile_options: callable = None + ): c = cls( controlnet.control_weights, global_average_pooling=controlnet.global_average_pooling, @@ -42,6 +44,7 @@ def from_controllora(cls, controlnet: ControlLora): ) controlnet.copy_to(c) c._oneflow_model = None + c.gen_compile_options = gen_compile_options return c def pre_run(self, model, percent_to_timestep_function): @@ -75,7 +78,15 @@ class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast): self.control_model.to(dtype) self.control_model.to(comfy.model_management.get_torch_device()) - self._oneflow_model = oneflow_compile(self.control_model) + compile_options = ( + self.gen_compile_options(self.control_model) + if self.gen_compile_options is not None + else {} + ) + + self._oneflow_model = oneflow_compile( + self.control_model, options=compile_options + ) self.control_model = self._oneflow_model @@ -108,4 +119,5 @@ def copy(self): ) self.copy_to(c) c._oneflow_model = self._oneflow_model + c.gen_compile_options = self.gen_compile_options return c diff --git a/onediff_comfy_nodes/utils/deep_cache_speedup.py b/onediff_comfy_nodes/utils/deep_cache_speedup.py new file mode 100644 index 000000000..30ea494e7 --- /dev/null +++ b/onediff_comfy_nodes/utils/deep_cache_speedup.py @@ -0,0 +1,122 @@ +import torch +from comfy import model_management + + +from .model_patcher import OneFlowDeepCacheSpeedUpModelPatcher + + +def deep_cache_speedup( + model, + use_graph, + cache_interval, + cache_layer_id, + cache_block_id, + start_step, + end_step, + *, + gen_compile_options=None, +): + offload_device = model_management.unet_offload_device() + model_patcher = OneFlowDeepCacheSpeedUpModelPatcher( + model.model, + load_device=model_management.get_torch_device(), + offload_device=offload_device, + cache_layer_id=cache_layer_id, + cache_block_id=cache_block_id, + use_graph=use_graph, + gen_compile_options=gen_compile_options, + ) + + current_t = -1 + current_step = -1 + cache_h = None + + def apply_model(model_function, kwargs): + nonlocal current_t, current_step, cache_h + + xa = kwargs["input"] + t = kwargs["timestep"] + c_concat = kwargs["c"].get("c_concat", None) + c_crossattn = kwargs["c"].get("c_crossattn", None) + y = kwargs["c"].get("y", None) + control = kwargs["c"].get("control", None) + transformer_options = kwargs["c"].get("transformer_options", None) + + # https://github.com/comfyanonymous/ComfyUI/blob/629e4c552cc30a75d2756cbff8095640af3af163/comfy/model_base.py#L51-L69 + sigma = t + xc = model_patcher.model.model_sampling.calculate_input(sigma, xa) + if c_concat is not None: + xc = torch.cat([xc] + [c_concat], dim=1) + + context = c_crossattn + dtype = model_patcher.model.get_dtype() + xc = xc.to(dtype) + t = model_patcher.model.model_sampling.timestep(t).float() + context = context.to(dtype) + extra_conds = {} + for o in kwargs: + extra = kwargs[o] + if 
hasattr(extra, "to"): + extra = extra.to(dtype) + extra_conds[o] = extra + + x = xc + timesteps = t + y = None if y is None else y.to(dtype) + transformer_options["original_shape"] = list(x.shape) + transformer_options["current_index"] = 0 + transformer_patches = transformer_options.get("patches", {}) + """ + Apply the model to an input batch. + :param x: an [N x C x ...] Tensor of inputs. + :param timesteps: a 1-D batch of timesteps. + :param context: conditioning plugged in via crossattn + :param y: an [N] Tensor of labels, if class-conditional. + :return: an [N x C x ...] Tensor of outputs. + """ + + # reference https://gist.github.com/laksjdjf/435c512bc19636e9c9af4ee7bea9eb86 + if t[0].item() > current_t: + current_step = -1 + + current_t = t[0].item() + apply = 1000 - end_step <= current_t <= 1000 - start_step # t is 999->0 + + if apply: + current_step += 1 + else: + current_step = -1 + current_t = t[0].item() + + is_slow_step = current_step % cache_interval == 0 and apply + + model_output = None + if is_slow_step: + cache_h = None + model_output, cache_h = model_patcher.deep_cache_unet( + x, + timesteps, + context, + y, + control, + transformer_options, + **extra_conds, + ) + else: + model_output, cache_h = model_patcher.fast_deep_cache_unet( + x, + cache_h, + timesteps, + context, + y, + control, + transformer_options, + **extra_conds, + ) + + return model_patcher.model.model_sampling.calculate_denoised( + sigma, model_output, xa + ) + + model_patcher.set_model_unet_function_wrapper(apply_model) + return (model_patcher,) diff --git a/onediff_comfy_nodes/utils/model_patcher.py b/onediff_comfy_nodes/utils/model_patcher.py index ca77d7f82..2ba6de289 100644 --- a/onediff_comfy_nodes/utils/model_patcher.py +++ b/onediff_comfy_nodes/utils/model_patcher.py @@ -496,6 +496,7 @@ def __init__( weight_inplace_update=False, *, use_graph=None, + gen_compile_options=None, ): from onediff.infer_compiler import oneflow_compile from onediff.infer_compiler.with_oneflow_compile import DeployableModule @@ -514,15 +515,19 @@ def __init__( self.fast_deep_cache_unet = FastDeepCacheUNet( self.model.diffusion_model, cache_layer_id, cache_block_id ) - if use_graph: + gen_compile_options = gen_compile_options or (lambda x: {}) + compile_options = gen_compile_options(self.deep_cache_unet) self.deep_cache_unet = oneflow_compile( self.deep_cache_unet, use_graph=use_graph, + options=compile_options, ) + compile_options = gen_compile_options(self.fast_deep_cache_unet) self.fast_deep_cache_unet = oneflow_compile( self.fast_deep_cache_unet, use_graph=use_graph, + options=compile_options, ) self.model._register_state_dict_hook(state_dict_hook) diff --git a/onediff_comfy_nodes/workflows/ControlNet/README.md b/onediff_comfy_nodes/workflows/ControlNet/README.md new file mode 100644 index 000000000..38ef0e7e3 --- /dev/null +++ b/onediff_comfy_nodes/workflows/ControlNet/README.md @@ -0,0 +1,62 @@ +##
ControlNet Documentation
+### Performance
+- [**Installation Guide**](https://github.com/siliconflow/onediff/blob/main/README_ENTERPRISE.md#install-onediff-enterprise)
+- **Checkpoint:** v1-5-pruned-emaonly.ckpt
+- **ControlNet:** control_openpose-fp16.safetensors
+- **End-to-end time reduction:** (4.48 - 2.35) / 4.48 ≈ 47.5%
+
+**End2End Time**, Image Size 512x512, Batch Size 4, steps 20
+
+![ControlNet Performance](./controlnet_performance.png)
+
+**Figure Notes**
+
+- [ControlNet Baseline Workflow](https://github.com/siliconflow/onediff/releases/download/0.12.0/controlnet_torch_00.png)
+- [ControlNet + OneDiff Enterprise Workflow](https://github.com/siliconflow/onediff/releases/download/0.12.0/controlnet_onediff_quant_02.png)
+
+
+### ControlNet Workflow
+Here are some simple examples of how to use ControlNets. You can load any of the images below in `ComfyUI` to recover the full workflow.
+
+**Please note**: when launching ComfyUI, add the `--gpu-only` flag, for example, `python main.py --gpu-only`.
+
+#### Basic Usage
+Replace `"Load ControlNet Model"` with `"Load ControlNet Model - OneDiff"` in ComfyUI, as follows (a short sketch of what this node does internally is given at the end of this document):
+![ControlNet](./controlnet_onediff.png)
+
+#### Quantization
+![ControlNet](./controlnet_onediff_quant.png)
+
+#### Mixing ControlNets
+![ControlNet](./mixing_controlnets.png)
+
+## FAQ
+- Q: RuntimeError: After graph built, the device of graph can't be modified, current device: cuda:0, target device: cpu
+  - Please use `--gpu-only` when launching ComfyUI, for example, `python main.py --gpu-only`.
+
+- Q: oneflow._oneflow_internal.exception.Exception: Check failed: (xxx == yyy)
+  - To start a fresh run, delete the files in the `ComfyUI/input/graphs/` directory and rerun the workflow.
+  - **Switching the strength parameter of the "Apply ControlNet" node between 0 and a value greater than 0 is not supported.** A strength of 0 means ControlNet is not used, while a strength greater than 0 activates it; switching between the two changes the graph structure and leads to errors.
+
+- Q: The acceleration from ControlNet is not very apparent.
+  - ControlNet is a very small model. In prior tests, the iteration-time ratio between the UNet and ControlNet was about 2:1.
+  - UNet compilation contributed roughly a 30% acceleration and ControlNet about 15%, for an overall acceleration of approximately 45%.
+  - In the Enterprise edition, the UNet speedup is larger, so the relative contribution from ControlNet is smaller.
+
+## Contact
+
+For users of OneDiff Community, please visit [GitHub Issues](https://github.com/siliconflow/onediff/issues) for bug reports and feature requests.
+
+For users of OneDiff Enterprise, you can reach out to contact@siliconflow.com for commercial support.
+
+Feel free to join our [Discord](https://discord.gg/RKJTjZMcPQ) community for discussions and to receive the latest updates.
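+
+## Appendix: What the OneDiff ControlNet loader does (sketch)
+
+The following is a minimal, hypothetical sketch of what `"Load ControlNet Model - OneDiff"` (the `OneDiffControlNetLoader` node added in this PR) does under the hood. The helper name `load_and_compile_controlnet` and the explicit `graph_file` argument are illustrative only; in the real node the graph path is generated internally via `generate_graph_path`.
+
+```python
+from comfy import model_management
+from comfy.controlnet import ControlLora, ControlNet
+from onediff.infer_compiler import oneflow_compile
+
+
+def load_and_compile_controlnet(controlnet, graph_file=None):
+    """Compile a ComfyUI ControlNet's control_model with OneDiff (illustrative sketch)."""
+    if isinstance(controlnet, ControlLora):
+        # ControlLora rebuilds its control model each run, so the OneDiff node
+        # defers compilation to OneDiffControlLora.pre_run() instead.
+        return controlnet
+    if isinstance(controlnet, ControlNet):
+        device = model_management.get_torch_device()
+        options = (
+            {"graph_file": graph_file, "graph_file_device": device}
+            if graph_file is not None
+            else {}
+        )
+        controlnet.control_model = oneflow_compile(
+            controlnet.control_model.to(device), options=options
+        )
+    return controlnet
+```
+
+The `graph_file` / `graph_file_device` options mirror the compile options built by the node, which appears to let compiled graphs be cached on disk and reused between runs.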
diff --git a/onediff_comfy_nodes/workflows/ControlNet/controlnet_onediff.png b/onediff_comfy_nodes/workflows/ControlNet/controlnet_onediff.png new file mode 100644 index 000000000..f15e13074 Binary files /dev/null and b/onediff_comfy_nodes/workflows/ControlNet/controlnet_onediff.png differ diff --git a/onediff_comfy_nodes/workflows/ControlNet/controlnet_onediff_quant.png b/onediff_comfy_nodes/workflows/ControlNet/controlnet_onediff_quant.png new file mode 100644 index 000000000..57a77e80a Binary files /dev/null and b/onediff_comfy_nodes/workflows/ControlNet/controlnet_onediff_quant.png differ diff --git a/onediff_comfy_nodes/workflows/ControlNet/controlnet_performance.png b/onediff_comfy_nodes/workflows/ControlNet/controlnet_performance.png new file mode 100644 index 000000000..18d605b85 Binary files /dev/null and b/onediff_comfy_nodes/workflows/ControlNet/controlnet_performance.png differ diff --git a/onediff_comfy_nodes/workflows/ControlNet/mixing_controlnets.png b/onediff_comfy_nodes/workflows/ControlNet/mixing_controlnets.png new file mode 100644 index 000000000..de087e58a Binary files /dev/null and b/onediff_comfy_nodes/workflows/ControlNet/mixing_controlnets.png differ diff --git a/onediff_comfy_nodes/workflows/model-speedup-controlnet.png b/onediff_comfy_nodes/workflows/model-speedup-controlnet.png deleted file mode 100644 index ea03cfc06..000000000 Binary files a/onediff_comfy_nodes/workflows/model-speedup-controlnet.png and /dev/null differ