add mock class for specific workflow (#416)
- Support dynamic batch size and fix a dynamic-shape issue in a specific workflow
- Dynamic batch size is not supported in the community version; a note is added
doombeaker authored Dec 13, 2023
1 parent 8c458a3 commit 3daab10
Showing 8 changed files with 182 additions and 15 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -43,7 +43,7 @@ Updated on Nov 6, 2023.

## Business inquiry on OneDiff Enterprise Edition

If you need **unrestricted multiple resolution**, **quantization** support or any other more advanced features, please send an email to caishenghang@oneflow.org . Tell us about your **use case, deployment scale and requirements**!
If you need **unrestricted multiple resolution**, **quantization**, **dynamic batchsize** support or any other more advanced features, please send an email to caishenghang@oneflow.org . Tell us about your **use case, deployment scale and requirements**!

|   | OneDiff Community | OneDiff Enterprise|
| -------------------- | ------------------- | ----------- |
2 changes: 1 addition & 1 deletion onediff_comfy_nodes/README.md
@@ -64,7 +64,7 @@ cp -r onediff_comfy_nodes path/to/ComfyUI/custom_nodes/

6. (Optional) Advanced features

If you need **unrestricted multiple resolution**, **quantization** support or any other more advanced features, please send an email to caishenghang@oneflow.org . Tell us about your **use case, deployment scale and requirements**!
If you need **unrestricted multiple resolution**, **quantization**, **dynamic batchsize** support or any other more advanced features, please send an email to caishenghang@oneflow.org . Tell us about your **use case, deployment scale and requirements**!



3 changes: 2 additions & 1 deletion onediff_comfy_nodes/_config.py
@@ -13,10 +13,11 @@
sys.path.insert(0, str(COMFYUI_SPEEDUP_ROOT))
sys.path.insert(0, str(INFER_COMPILER_REGISTRY))
import register_comfy # load plugins
from .utils.comfyui_speedup_utils import is_community_version
from onediff.infer_compiler.utils import is_community_version, get_support_message

if is_community_version():
_USE_UNET_INT8 = False
print(get_support_message())

if _USE_UNET_INT8:
import register_diffusers_quant # load plugins
@@ -1,4 +1,5 @@
from onediff.infer_compiler import register
from onediff.infer_compiler.utils import is_community_version
from nodes import * # must imported before import comfy
from pathlib import Path

@@ -23,5 +24,18 @@
comfy_ops_Linear: Linear1f,
}

if not is_community_version():
from .openaimodel import Upsample as Upsample1f
from .openaimodel import UNetModel as UNetModel1f

torch2of_class_map.update(
{
comfy.ldm.modules.diffusionmodules.openaimodel.Upsample: Upsample1f,
comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel: UNetModel1f,
}
)
else:
print("Dynamic batchsize is not supported in community version.")


register(torch2oflow_class_map=torch2of_class_map)
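The hunk above wires the new mock classes into onediff's class map when the enterprise build is detected. A minimal sketch of that registration pattern, using only the hooks that appear in this diff (the subclass name and its body are illustrative placeholders, not code from the repository):

```python
import comfy
from onediff.infer_compiler import register
from onediff.infer_compiler.transform import proxy_class

# Subclass the OneFlow proxy of the original ComfyUI class and override only
# the methods whose control flow has to change for graph compilation.
class MockUpsample(proxy_class(comfy.ldm.modules.diffusionmodules.openaimodel.Upsample)):
    def forward(self, x, output_shape=None):
        # A real mock would reimplement this (see openaimodel.py below);
        # here we simply defer to the proxied base class.
        return super().forward(x, output_shape)

# Map the original torch class to the mock so the compiler swaps it in
# whenever the workflow instantiates an Upsample.
register(
    torch2oflow_class_map={
        comfy.ldm.modules.diffusionmodules.openaimodel.Upsample: MockUpsample,
    }
)
```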
@@ -113,7 +113,8 @@ def forward(self, x, context=None, transformer_options={}):
x = self.proj_in(x)
# NOTE: rearrange in ComfyUI is replaced with reshape and use -1 to enable for
# dynamic shape inference (multi resolution compilation)
x = x.reshape(b, c, -1).permute(0, 2, 1)
x = x.flatten(2, 3).permute(0, 2, 1)
# x = x.reshape(b, c, -1).permute(0, 2, 1)
# x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
if self.use_linear:
x = self.proj_in(x)
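The change above replaces `x = x.reshape(b, c, -1)` with `x = x.flatten(2, 3)`. A small eager-mode PyTorch sketch, independent of onediff, showing the two forms are equivalent while `flatten` never touches the Python ints `b` and `c`, which is what keeps a compiled graph valid when batch size or resolution changes:

```python
import torch

def merge_spatial(x: torch.Tensor) -> torch.Tensor:
    # flatten(2, 3) merges H and W without reading the batch or channel
    # dims as Python ints, so nothing batch- or resolution-specific gets
    # baked into a traced/compiled graph.
    return x.flatten(2, 3).permute(0, 2, 1)

for b, h, w in [(1, 64, 64), (4, 96, 56)]:
    x = torch.randn(b, 320, h, w)
    ref = x.reshape(b, 320, -1).permute(0, 2, 1)  # the old formulation
    assert torch.equal(merge_spatial(x), ref)
    assert merge_spatial(x).shape == (b, h * w, 320)
```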
@@ -0,0 +1,153 @@
import comfy
import oneflow as th  # 'th' is how ComfyUI names torch
import oneflow.nn.functional as F
from onediff.infer_compiler.transform import proxy_class
from onediff.infer_compiler.transform import transform_mgr

onediff_comfy = transform_mgr.transform_package("comfy")


class Upsample(proxy_class(comfy.ldm.modules.diffusionmodules.openaimodel.Upsample)):
# https://github.com/comfyanonymous/ComfyUI/blob/b0aab1e4ea3dfefe09c4f07de0e5237558097e22/comfy/ldm/modules/diffusionmodules/openaimodel.py#L82
def forward(self, x, output_shape=None):
assert x.shape[1] == self.channels
if output_shape is not None and isinstance(output_shape, th.Tensor):
if self.dims == 3:
raise ValueError("output_shape shoud not be Tensor for dims == 3")
else:
x = F.interpolate_like(x, like=output_shape, mode="nearest")
else:
if self.dims == 3:
shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2]
if output_shape is not None:
shape[1] = output_shape[3]
shape[2] = output_shape[4]
else:
shape = [x.shape[2] * 2, x.shape[3] * 2]
if output_shape is not None:
shape[0] = output_shape[2]
shape[1] = output_shape[3]

x = F.interpolate(x, size=shape, mode="nearest")

if self.use_conv:
x = self.conv(x)
return x


class UNetModel(proxy_class(comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel)):
# https://github.com/comfyanonymous/ComfyUI/blob/b0aab1e4ea3dfefe09c4f07de0e5237558097e22/comfy/ldm/modules/diffusionmodules/openaimodel.py#L823
def forward(
self,
x,
timesteps=None,
context=None,
y=None,
control=None,
transformer_options={},
**kwargs
):
timestep_embedding = (
onediff_comfy.ldm.modules.diffusionmodules.util.timestep_embedding
)
forward_timestep_embed = (
onediff_comfy.ldm.modules.diffusionmodules.openaimodel.forward_timestep_embed
)
apply_control = (
onediff_comfy.ldm.modules.diffusionmodules.openaimodel.apply_control
)

transformer_options["original_shape"] = list(x.shape)
transformer_options["transformer_index"] = 0
transformer_patches = transformer_options.get("patches", {})

num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
image_only_indicator = kwargs.get(
"image_only_indicator", self.default_image_only_indicator
)
time_context = kwargs.get("time_context", None)

assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(
timesteps, self.model_channels, repeat_only=False
).to(x.dtype)
emb = self.time_embed(t_emb)

if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)

h = x
for id, module in enumerate(self.input_blocks):
transformer_options["block"] = ("input", id)
h = forward_timestep_embed(
module,
h,
emb,
context,
transformer_options,
time_context=time_context,
num_video_frames=num_video_frames,
image_only_indicator=image_only_indicator,
)
h = apply_control(h, control, "input")
if "input_block_patch" in transformer_patches:
patch = transformer_patches["input_block_patch"]
for p in patch:
h = p(h, transformer_options)

hs.append(h)
if "input_block_patch_after_skip" in transformer_patches:
patch = transformer_patches["input_block_patch_after_skip"]
for p in patch:
h = p(h, transformer_options)

transformer_options["block"] = ("middle", 0)
h = forward_timestep_embed(
self.middle_block,
h,
emb,
context,
transformer_options,
time_context=time_context,
num_video_frames=num_video_frames,
image_only_indicator=image_only_indicator,
)
h = apply_control(h, control, "middle")

for id, module in enumerate(self.output_blocks):
transformer_options["block"] = ("output", id)
hsp = hs.pop()
hsp = apply_control(hsp, control, "output")

if "output_block_patch" in transformer_patches:
patch = transformer_patches["output_block_patch"]
for p in patch:
h, hsp = p(h, hsp, transformer_options)

h = th.cat([h, hsp], dim=1)
del hsp
if len(hs) > 0:
# output_shape = hs[-1].shape
output_shape = hs[-1]
else:
output_shape = None
h = forward_timestep_embed(
module,
h,
emb,
context,
transformer_options,
output_shape,
time_context=time_context,
num_video_frames=num_video_frames,
image_only_indicator=image_only_indicator,
)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)
else:
return self.out(h)
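Note how the mock UNetModel passes the skip tensor itself (`output_shape = hs[-1]`) instead of its `.shape`, and the mock Upsample consumes it with oneflow's `interpolate_like`, so the target size is read from the tensor at run time rather than frozen into the graph. An eager-mode PyTorch sketch of the same idea (plain torch has no `interpolate_like`, so the reference shape is sliced explicitly; the function name below is illustrative):

```python
import torch
import torch.nn.functional as F

def upsample_to_match(x: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
    # Follow the skip connection's spatial size instead of a precomputed
    # [H, W] list, so odd sizes from earlier downsamples are matched exactly.
    return F.interpolate(x, size=ref.shape[2:], mode="nearest")

x = torch.randn(1, 320, 32, 32)
skip = torch.randn(1, 320, 63, 63)  # odd size from a strided downsample
assert upsample_to_match(x, skip).shape[2:] == skip.shape[2:]
```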
5 changes: 5 additions & 0 deletions src/onediff/infer_compiler/utils/__init__.py
@@ -1,3 +1,8 @@
from .oneflow_exec_mode import oneflow_exec_mode, oneflow_exec_mode_enabled
from .env_var import parse_boolean_from_env, set_boolean_env_var
from .model_inplace_assign import TensorInplaceAssign
from .version_util import (
get_support_message,
is_quantization_enabled,
is_community_version,
)
@@ -1,19 +1,18 @@
import oneflow
from importlib_metadata import version


def get_support_message():
recipient_email = "caishenghang@oneflow.org"

message = f"""\033[91m Advanced features cannot be used !!! \033[0m
If you need unrestricted multiple resolution, quantization support or any other more advanced
features, please send an email to {recipient_email} and tell us about
your **use case, deployment scale and requirements**.
If you need unrestricted multiple resolution, quantization support or any other more advanced features, please send an email to \033[91m{recipient_email}\033[0m and tell us about your use case, deployment scale and requirements.
"""
return message


def is_quantization_enabled():
import oneflow

if version("oneflow") < "0.9.1":
        raise RuntimeError(
"onediff_comfy_nodes requires oneflow>=0.9.1 to run.", get_support_message()
@@ -26,12 +25,6 @@ def is_quantization_enabled():
return hasattr(oneflow._C, "dynamic_quantization")


def is_community_version(stop_if_not=False):
def is_community_version():
is_community = not is_quantization_enabled()
if is_community:
message = get_support_message()
if stop_if_not:
input(message + "\nPress any key to continue...")
else:
print(message)
return is_community
