
fix ControlLora usage (#565)
- Fix ControlLora; changing the length parameter now triggers recompilation.
- Support automatic graph caching and loading in OneDiffControlNetLoader.
- doc: https://github.com/siliconflow/onediff/tree/fix_ControlLora_use/onediff_comfy_nodes/workflows/ControlNet#basic-usage
ccssu authored Jan 24, 2024
1 parent 25d2baf commit e777af6
Showing 13 changed files with 439 additions and 152 deletions.
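
For reference, a minimal sketch of driving the updated OneDiffControlNetLoader from Python. This is not part of the diff: it assumes ComfyUI and onediff_comfy_nodes are importable in the current environment, and the ControlNet file name is a placeholder for a model you have installed.

```python
# Minimal sketch (assumption-laden, not part of this commit): exercise the
# updated loader outside the ComfyUI graph editor.
from onediff_comfy_nodes import NODE_CLASS_MAPPINGS

loader = NODE_CLASS_MAPPINGS["OneDiffControlNetLoader"]()

# onediff_load_controlnet wraps ControlNetLoader.load_controlnet, compiles the
# control model with oneflow_compile, and (per this commit) saves/loads the
# compiled graph from the path produced by generate_graph_path, so later runs
# can skip recompilation. The file name below is a placeholder.
(control_net,) = loader.onediff_load_controlnet(
    "control_v11p_sd15_openpose.safetensors"
)
```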
7 changes: 5 additions & 2 deletions onediff_comfy_nodes/README.md
@@ -137,9 +137,12 @@ This example demonstrates how to utilize LoRAs. You have the flexibility to modi

### ControlNet

While there is an example demonstrating OpenPose ControlNet, it's important to note that OneDiff seamlessly supports a wide range of ControlNet types, including depth mapping, canny, and more.
> doc link: [ControlNet](https://github.com/siliconflow/onediff/tree/main/onediff_comfy_nodes/workflows/ControlNet)

![ControlNet Speedup](workflows/model-speedup-controlnet.png)

While there is an example demonstrating OpenPose ControlNet, it's important to note that OneDiff seamlessly supports a wide range of ControlNet types, including depth mapping, canny, and more.
![ControlNet Speedup](workflows/ControlNet/controlnet_onediff.png)
### SVD
6 changes: 3 additions & 3 deletions onediff_comfy_nodes/__init__.py
@@ -7,11 +7,11 @@
VaeSpeedup,
VaeGraphLoader,
VaeGraphSaver,
ControlNetSpeedup,
SVDSpeedup,
ModuleDeepCacheSpeedup,
OneDiffCheckpointLoaderSimple,
OneDiffControlNetLoader,
OneDiffDeepCacheCheckpointLoaderSimple,
)
from ._compare_node import CompareModel, ShowImageDiff

@@ -25,11 +25,11 @@
"VaeSpeedup": VaeSpeedup,
"VaeGraphSaver": VaeGraphSaver,
"VaeGraphLoader": VaeGraphLoader,
"ControlNetSpeedup": ControlNetSpeedup,
"SVDSpeedup": SVDSpeedup,
"ModuleDeepCacheSpeedup": ModuleDeepCacheSpeedup,
"OneDiffCheckpointLoaderSimple": OneDiffCheckpointLoaderSimple,
"OneDiffControlNetLoader": OneDiffControlNetLoader,
"OneDiffDeepCacheCheckpointLoaderSimple": OneDiffDeepCacheCheckpointLoaderSimple,
}

NODE_DISPLAY_NAME_MAPPINGS = {
@@ -41,11 +41,11 @@
"VaeSpeedup": "VAE Speedup",
"VaeGraphLoader": "VAE Graph Loader",
"VaeGraphSaver": "VAE Graph Saver",
"ControlNetSpeedup": "ControlNet Speedup",
"SVDSpeedup": "SVD Speedup",
"ModuleDeepCacheSpeedup": "Model DeepCache Speedup",
"OneDiffCheckpointLoaderSimple": "Load Checkpoint - OneDiff",
"OneDiffControlNetLoader": "Load ControlNet Model - OneDiff",
"OneDiffDeepCacheCheckpointLoaderSimple": "Load Checkpoint - OneDiff DeepCache",
}


271 changes: 127 additions & 144 deletions onediff_comfy_nodes/_nodes.py
@@ -30,8 +30,11 @@
from .utils.loader_sample_tools import compoile_unet, quantize_unet
from .utils.graph_path import generate_graph_path
from .modules.hijack_model_management import model_management_hijacker
from .modules.hijack_nodes import nodes_hijacker
from .utils.deep_cache_speedup import deep_cache_speedup

model_management_hijacker.hijack() # add flow.cuda.empty_cache()
nodes_hijacker.hijack()


__all__ = [
@@ -290,44 +293,6 @@ def save_graph(self, images, vae, filename_prefix):
return {}


class ControlNetSpeedup:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"control_net": ("CONTROL_NET",),
"static_mode": (["enable", "disable"],),
}
}

RETURN_TYPES = ("CONTROL_NET",)
RETURN_NAMES = ("control_net",)
FUNCTION = "apply_controlnet"

CATEGORY = "OneDiff"

def apply_controlnet(self, control_net, static_mode):
if static_mode == "enable":
from comfy.controlnet import ControlNet, ControlLora
from .modules.onediff_controlnet import (
OneDiffControlNet,
OneDiffControlLora,
)

if isinstance(control_net, ControlLora):
control_net = OneDiffControlLora.from_controllora(control_net)
return (control_net,)
elif isinstance(control_net, ControlNet):
control_net = OneDiffControlNet.from_controlnet(control_net)
return (control_net,)
else:
raise TypeError(
f"control_net must be ControlNet or ControlLora, got {type(control_net)}"
)
else:
return (control_net,)


class Quant8Model:
@classmethod
def INPUT_TYPES(s):
@@ -423,112 +388,16 @@ def deep_cache_convert(
start_step,
end_step,
):
use_graph = static_mode == "enable"

offload_device = model_management.unet_offload_device()
oneflow_model = OneFlowDeepCacheSpeedUpModelPatcher(
model.model,
load_device=model_management.get_torch_device(),
offload_device=offload_device,
return deep_cache_speedup(
model=model,
use_graph=(static_mode == "enable"),
cache_interval=cache_interval,
cache_layer_id=cache_layer_id,
cache_block_id=cache_block_id,
use_graph=use_graph,
start_step=start_step,
end_step=end_step,
)

current_t = -1
current_step = -1
cache_h = None

def apply_model(model_function, kwargs):
nonlocal current_t, current_step, cache_h

xa = kwargs["input"]
t = kwargs["timestep"]
c_concat = kwargs["c"].get("c_concat", None)
c_crossattn = kwargs["c"].get("c_crossattn", None)
y = kwargs["c"].get("y", None)
control = kwargs["c"].get("control", None)
transformer_options = kwargs["c"].get("transformer_options", None)

# https://github.com/comfyanonymous/ComfyUI/blob/629e4c552cc30a75d2756cbff8095640af3af163/comfy/model_base.py#L51-L69
sigma = t
xc = oneflow_model.model.model_sampling.calculate_input(sigma, xa)
if c_concat is not None:
xc = torch.cat([xc] + [c_concat], dim=1)

context = c_crossattn
dtype = oneflow_model.model.get_dtype()
xc = xc.to(dtype)
t = oneflow_model.model.model_sampling.timestep(t).float()
context = context.to(dtype)
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
if hasattr(extra, "to"):
extra = extra.to(dtype)
extra_conds[o] = extra

x = xc
timesteps = t
y = None if y is None else y.to(dtype)
transformer_options["original_shape"] = list(x.shape)
transformer_options["current_index"] = 0
transformer_patches = transformer_options.get("patches", {})
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param context: conditioning plugged in via crossattn
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""

# reference https://gist.github.com/laksjdjf/435c512bc19636e9c9af4ee7bea9eb86
if t[0].item() > current_t:
current_step = -1

current_t = t[0].item()
apply = 1000 - end_step <= current_t <= 1000 - start_step # t is 999->0

if apply:
current_step += 1
else:
current_step = -1
current_t = t[0].item()

is_slow_step = current_step % cache_interval == 0 and apply

model_output = None
if is_slow_step:
cache_h = None
model_output, cache_h = oneflow_model.deep_cache_unet(
x,
timesteps,
context,
y,
control,
transformer_options,
**extra_conds,
)
else:
model_output, cache_h = oneflow_model.fast_deep_cache_unet(
x,
cache_h,
timesteps,
context,
y,
control,
transformer_options,
**extra_conds,
)

return oneflow_model.model.model_sampling.calculate_denoised(
sigma, model_output, xa
)

oneflow_model.set_model_unet_function_wrapper(apply_model)
return (oneflow_model,)


from nodes import CheckpointLoaderSimple, ControlNetLoader
from comfy.controlnet import ControlLora, ControlNet
@@ -542,19 +411,34 @@ class OneDiffControlNetLoader(ControlNetLoader):

def onediff_load_controlnet(self, control_net_name):
controlnet = super().load_controlnet(control_net_name)[0]
load_device = model_management.get_torch_device()

def gen_compile_options(model):
graph_file = generate_graph_path(control_net_name, model)
return {
"graph_file": graph_file,
"graph_file_device": load_device,
}

if isinstance(controlnet, ControlLora):
controlnet = OneDiffControlLora.from_controllora(controlnet)
controlnet = OneDiffControlLora.from_controllora(
controlnet, gen_compile_options=gen_compile_options
)
return (controlnet,)
elif isinstance(controlnet, ControlNet):
control_model = controlnet.control_model
control_model = control_model.to(model_management.get_torch_device())
controlnet.control_model = oneflow_compile(control_model)
compile_options = gen_compile_options(control_model)
control_model = control_model.to(load_device)
controlnet.control_model = oneflow_compile(
control_model, options=compile_options
)
return (controlnet,)
else:
print(f"Warning: {type(controlnet)=} is not ControlLora or ControlNet")
print("\033[1;31;40m Warning: {type(controlnet)=} is not ControlLora or ControlNet \033[0m")
return (controlnet,)



class OneDiffCheckpointLoaderSimple(CheckpointLoaderSimple):
@classmethod
def INPUT_TYPES(s):
@@ -596,6 +480,105 @@ def onediff_load_checkpoint(
return modelpatcher, clip, vae


class OneDiffDeepCacheCheckpointLoaderSimple(CheckpointLoaderSimple):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
"vae_speedup": (["disable", "enable"],),
"static_mode": (["enable", "disable"],),
"cache_interval": (
"INT",
{
"default": 3,
"min": 1,
"max": 1000,
"step": 1,
"display": "number",
},
),
"cache_layer_id": (
"INT",
{"default": 0, "min": 0, "max": 12, "step": 1, "display": "number"},
),
"cache_block_id": (
"INT",
{"default": 1, "min": 0, "max": 12, "step": 1, "display": "number"},
),
"start_step": (
"INT",
{
"default": 0,
"min": 0,
"max": 1000,
"step": 1,
"display": "number",
},
),
"end_step": (
"INT",
{
"default": 1000,
"min": 0,
"max": 1000,
"step": 0.1,
},
),
}
}

CATEGORY = "OneDiff/Loaders"
FUNCTION = "onediff_load_checkpoint"

def onediff_load_checkpoint(
self,
ckpt_name,
vae_speedup,
output_vae=True,
output_clip=True,
static_mode="enable",
cache_interval=3,
cache_layer_id=0,
cache_block_id=1,
start_step=0,
end_step=1000,
):
# CheckpointLoaderSimple.load_checkpoint
modelpatcher, clip, vae = self.load_checkpoint(
ckpt_name, output_vae, output_clip
)

def gen_compile_options(model):
# cache_key = f'{cache_interval}_{cache_layer_id}_{cache_block_id}_{start_step}_{end_step}'
graph_file = generate_graph_path(ckpt_name, model)
return {
"graph_file": graph_file,
"graph_file_device": model_management.get_torch_device(),
}

if vae_speedup == "enable":
vae.first_stage_model = oneflow_compile(
vae.first_stage_model,
use_graph=True,
options=gen_compile_options(vae.first_stage_model),
)

modelpatcher = deep_cache_speedup(
model=modelpatcher,
use_graph=(static_mode == "enable"),
cache_interval=cache_interval,
cache_layer_id=cache_layer_id,
cache_block_id=cache_block_id,
start_step=start_step,
end_step=end_step,
gen_compile_options=gen_compile_options,
)[0]
# set inplace update
modelpatcher.weight_inplace_update = True
return modelpatcher, clip, vae


class OneDiffQuantCheckpointLoaderSimple(CheckpointLoaderSimple):
@classmethod
def INPUT_TYPES(s):