Skip to content

Commit

Permalink
Integrate deep cache (#426)
Browse files Browse the repository at this point in the history
## ComfyUI Node and Example
Node name: ModuleDeepCacheSpeedup

Example workflow:
https://github.com/Oneflow-Inc/onediff/blob/integrate_deep_cache/onediff_comfy_nodes/workflows/deep-cache.png

## Depends on the new OneFlow community edition

cuda 11.8: `python3 -m pip install --pre oneflow -f
https://oneflow-pro.oss-cn-beijing.aliyuncs.com/branch/community/cu118`

cuda 12.1: `python3 -m pip install --pre oneflow -f
https://oneflow-pro.oss-cn-beijing.aliyuncs.com/branch/community/cu121`

cuda 12.2: `python3 -m pip install --pre oneflow -f
https://oneflow-pro.oss-cn-beijing.aliyuncs.com/branch/community/cu122`

## SDXL  UNet iter speed
### a100
- 1024 * 1024

|SDXL(Pytorch)| DeepCache | SDXL(with OneDiff) |DeepCache with OneDiff|
| :-: | :-: | :-: | :-: |
|7.79it/s(1x)| 16.92it/s(2.17x) | 10.11it/s(1.30x) | 23.58it/s(3.03x) |

-  512 * 512

|SDXL(Pytorch)| DeepCache | SDXL(with OneDiff) |DeepCache with OneDiff|
| :-: | :-: | :-: | :-: |
|22.47it/s(1x)| 48.73it/s(2.17x) | 28.67it/s(1.26x) | 54.79it/s(2.44x) |

### 3090 
- 1024 * 1024

|SDXL(Pytorch)| DeepCache | SDXL(with OneDiff) |DeepCache with OneDiff|
| :-: | :-: | :-: | :-: |
|4.01it/s(1x)|  8.79it/s(2.19x) | 6.41it/s(1.6x)  |  14.11it/s(3.52x) |

-  512 * 512

|SDXL(Pytorch)| DeepCache | SDXL(with OneDiff) |DeepCache with OneDiff|
| :-: | :-: | :-: | :-: |
|13.88it/s(1x)| 29.81it/s(2.15x) | 21.49it/s(1.55x) | 43.56it/s(3.14x) |

---------

Co-authored-by: FengWen <ccsuwen@gmail.com>
Co-authored-by: FengWen <109639975+ccssu@users.noreply.github.com>
Co-authored-by: Xiaoyu Xu <xuxiaoyu2048@foxmail.com>
  • Loading branch information
4 people authored Dec 19, 2023
1 parent 2cf4a2f commit 8a35a9e
Show file tree
Hide file tree
Showing 8 changed files with 493 additions and 3 deletions.
3 changes: 3 additions & 0 deletions onediff_comfy_nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
ControlNetGraphLoader,
ControlNetGraphSaver,
SVDSpeedup,
ModuleDeepCacheSpeedup,
)
from ._compare_node import CompareModel, ShowImageDiff

Expand All @@ -28,6 +29,7 @@
"ControlNetGraphLoader": ControlNetGraphLoader,
"ControlNetGraphSaver": ControlNetGraphSaver,
"SVDSpeedup": SVDSpeedup,
"ModuleDeepCacheSpeedup": ModuleDeepCacheSpeedup,
}

NODE_DISPLAY_NAME_MAPPINGS = {
Expand All @@ -43,6 +45,7 @@
"ControlNetGraphLoader": "ControlNet Graph Loader",
"ControlNetGraphSaver": "ControlNet Graph Saver",
"SVDSpeedup": "SVD Speedup",
"ModuleDeepCacheSpeedup": "Model DeepCache Speedup",
}

if _USE_UNET_INT8:
Expand Down
177 changes: 176 additions & 1 deletion onediff_comfy_nodes/_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
from comfy import model_management
from comfy.cli_args import args

from .utils import OneFlowSpeedUpModelPatcher, save_graph, load_graph, OUTPUT_FOLDER
from .utils import (
OneFlowSpeedUpModelPatcher,
OneFlowDeepCacheSpeedUpModelPatcher,
save_graph,
load_graph,
OUTPUT_FOLDER,
)

from .modules.hijack_model_management import model_management_hijacker

Expand All @@ -31,6 +37,7 @@
"VaeGraphLoader",
"VaeGraphSaver",
"SVDSpeedup",
"ModuleDeepCacheSpeedup",
]

if not args.dont_upcast_attention:
Expand Down Expand Up @@ -402,3 +409,171 @@ def quantize_model(self, model, output_dir, conv, linear):
verbose=False,
)
return {}


class ModuleDeepCacheSpeedup:
    """ComfyUI node that wraps a model in a OneFlow DeepCache model patcher.

    The node builds an ``OneFlowDeepCacheSpeedUpModelPatcher`` around the
    incoming model and installs a UNet function wrapper that alternates
    between a full UNet pass (which refreshes the cached features) and a
    fast pass (which reuses them).
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input socket/widget specification for this node."""
        return {
            "required": {
                "model": ("MODEL",),
                "static_mode": (["enable", "disable"],),
                "cache_interval": (
                    "INT",
                    {
                        "default": 3,
                        "min": 1,
                        "max": 1000,
                        "step": 1,
                        "display": "number",
                    },
                ),
                "cache_layer_id": (
                    "INT",
                    {"default": 0, "min": 0, "max": 12, "step": 1, "display": "number"},
                ),
                "cache_block_id": (
                    "INT",
                    {"default": 1, "min": 0, "max": 12, "step": 1, "display": "number"},
                ),
                "start_step": (
                    "INT",
                    {
                        "default": 0,
                        "min": 0,
                        "max": 1000,
                        "step": 1,
                        "display": "number",
                    },
                ),
                "end_step": (
                    "INT",
                    {
                        "default": 1000,
                        "min": 0,
                        "max": 1000,
                        # An INT widget needs an integer step; the fractional
                        # step used previously did not match the declared type.
                        "step": 1,
                        "display": "number",
                    },
                ),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "deep_cache_convert"
    CATEGORY = "OneDiff"

    def deep_cache_convert(
        self,
        model,
        static_mode,
        cache_interval,
        cache_layer_id,
        cache_block_id,
        start_step,
        end_step,
    ):
        """Build a DeepCache-enabled model patcher for ``model``.

        :param model: object whose ``.model`` attribute is handed to the patcher.
        :param static_mode: ``"enable"`` turns on ``use_graph`` in the patcher;
            any other value turns it off.
        :param cache_interval: a full UNet pass runs whenever the step counter
            is a multiple of this value; other steps use the fast pass.
        :param cache_layer_id: forwarded unchanged to the patcher.
        :param cache_block_id: forwarded unchanged to the patcher.
        :param start_step: lower bound of the caching window; compared against
            the timestep as ``1000 - start_step``.
        :param end_step: upper bound of the caching window; compared against
            the timestep as ``1000 - end_step``.
        :return: one-element tuple holding the patched model.
        """
        use_graph = static_mode == "enable"

        offload_device = model_management.unet_offload_device()
        oneflow_model = OneFlowDeepCacheSpeedUpModelPatcher(
            model.model,
            load_device=model_management.get_torch_device(),
            offload_device=offload_device,
            cache_layer_id=cache_layer_id,
            cache_block_id=cache_block_id,
            use_graph=use_graph,
        )

        # State shared across successive apply_model calls.
        current_t = -1
        current_step = -1
        cache_h = None

        def apply_model(model_function, kwargs):
            """UNet function wrapper that runs either the full or the fast pass."""
            nonlocal current_t, current_step, cache_h

            xa = kwargs["input"]
            t = kwargs["timestep"]
            cond = kwargs["c"]
            c_concat = cond.get("c_concat", None)
            c_crossattn = cond.get("c_crossattn", None)
            y = cond.get("y", None)
            control = cond.get("control", None)
            transformer_options = cond.get("transformer_options", None)
            if transformer_options is None:
                # Keys are written into this dict below, so a missing entry
                # must become an empty dict instead of None.
                transformer_options = {}

            # https://github.com/comfyanonymous/ComfyUI/blob/629e4c552cc30a75d2756cbff8095640af3af163/comfy/model_base.py#L51-L69
            sigma = t
            xc = oneflow_model.model.model_sampling.calculate_input(sigma, xa)
            if c_concat is not None:
                xc = torch.cat([xc] + [c_concat], dim=1)

            context = c_crossattn
            dtype = oneflow_model.model.get_dtype()
            xc = xc.to(dtype)
            t = oneflow_model.model.model_sampling.timestep(t).float()
            context = context.to(dtype)
            # NOTE(review): this iterates the wrapper's own kwargs (including
            # "input", "timestep" and "c"), not the conditioning dict --
            # confirm that forwarding all of them to the UNet is intended.
            extra_conds = {
                name: value.to(dtype) if hasattr(value, "to") else value
                for name, value in kwargs.items()
            }

            # Inputs handed to the UNet:
            #   x         -- the (possibly concatenated) input batch
            #   timesteps -- the converted timesteps
            #   context   -- conditioning plugged in via cross-attention
            #   y         -- optional label tensor, cast to the model dtype
            x = xc
            timesteps = t
            y = None if y is None else y.to(dtype)
            transformer_options["original_shape"] = list(x.shape)
            transformer_options["current_index"] = 0

            # reference https://gist.github.com/laksjdjf/435c512bc19636e9c9af4ee7bea9eb86
            # Read the scalar once; every .item() call copies it to the host.
            t_value = t[0].item()
            if t_value > current_t:
                # The timestep went up again: restart the step counter.
                current_step = -1
            current_t = t_value

            apply = 1000 - end_step <= current_t <= 1000 - start_step  # t is 999->0
            if apply:
                current_step += 1
            else:
                current_step = -1

            # Run the full UNet when the interval says so, and also whenever
            # caching is not applicable (outside the window) or there is no
            # cache yet -- the fast pass cannot work from a missing or stale
            # cache.
            is_slow_step = (
                not apply
                or cache_h is None
                or current_step % cache_interval == 0
            )

            if is_slow_step:
                # Drop the reference to the old cache before recomputing it.
                cache_h = None
                model_output, cache_h = oneflow_model.deep_cache_unet(
                    x,
                    timesteps,
                    context,
                    y,
                    control,
                    transformer_options,
                    **extra_conds,
                )
            else:
                model_output, cache_h = oneflow_model.fast_deep_cache_unet(
                    x,
                    cache_h,
                    timesteps,
                    context,
                    y,
                    control,
                    transformer_options,
                    **extra_conds,
                )

            return oneflow_model.model.model_sampling.calculate_denoised(
                sigma, model_output, xa
            )

        oneflow_model.set_model_unet_function_wrapper(apply_model)
        return (oneflow_model,)
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from .attention import CrossAttention as CrossAttention1f
from .attention import SpatialTransformer as SpatialTransformer1f
from .linear import Linear as Linear1f
from .deep_cache_unet import DeepCacheUNet
from .deep_cache_unet import FastDeepCacheUNet

if hasattr(comfy.ops, "disable_weight_init"):
comfy_ops_Linear = comfy.ops.disable_weight_init.Linear
Expand Down
Loading

0 comments on commit 8a35a9e

Please sign in to comment.