diff --git a/README.md b/README.md index 3e788e66c..b72d61dca 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ Updated on Nov 6, 2023. ## Business inquiry on OneDiff Enterprise Edition -If you need **unrestricted multiple resolution**, **quantization** support or any other more advanced features, please send an email to caishenghang@oneflow.org . Tell us about your **use case, deployment scale and requirements**! +If you need **unrestricted multiple resolution**, **quantization**, **dynamic batchsize** support or any other more advanced features, please send an email to caishenghang@oneflow.org . Tell us about your **use case, deployment scale and requirements**! |   | OneDiff Community | OneDiff Enterprise| | -------------------- | ------------------- | ----------- | diff --git a/onediff_comfy_nodes/README.md b/onediff_comfy_nodes/README.md index 6e3fa811a..c976cd170 100644 --- a/onediff_comfy_nodes/README.md +++ b/onediff_comfy_nodes/README.md @@ -64,7 +64,7 @@ cp -r onediff_comfy_nodes path/to/ComfyUI/custom_nodes/ 6. (Optional) Advanced features -If you need **unrestricted multiple resolution**, **quantization** support or any other more advanced features, please send an email to caishenghang@oneflow.org . Tell us about your **use case, deployment scale and requirements**! +If you need **unrestricted multiple resolution**, **quantization**, **dynamic batchsize** support or any other more advanced features, please send an email to caishenghang@oneflow.org . Tell us about your **use case, deployment scale and requirements**! diff --git a/onediff_comfy_nodes/_config.py b/onediff_comfy_nodes/_config.py index 405896985..78c46142f 100644 --- a/onediff_comfy_nodes/_config.py +++ b/onediff_comfy_nodes/_config.py @@ -13,10 +13,11 @@ sys.path.insert(0, str(COMFYUI_SPEEDUP_ROOT)) sys.path.insert(0, str(INFER_COMPILER_REGISTRY)) import register_comfy # load plugins -from .utils.comfyui_speedup_utils import is_community_version +from onediff.infer_compiler.utils import is_community_version, get_support_message if is_community_version(): _USE_UNET_INT8 = False + print(get_support_message()) if _USE_UNET_INT8: import register_diffusers_quant # load plugins diff --git a/onediff_comfy_nodes/infer_compiler_registry/register_comfy/__init__.py b/onediff_comfy_nodes/infer_compiler_registry/register_comfy/__init__.py index d4cd61de6..f3a7d0397 100644 --- a/onediff_comfy_nodes/infer_compiler_registry/register_comfy/__init__.py +++ b/onediff_comfy_nodes/infer_compiler_registry/register_comfy/__init__.py @@ -1,4 +1,5 @@ from onediff.infer_compiler import register +from onediff.infer_compiler.utils import is_community_version from nodes import * # must imported before import comfy from pathlib import Path @@ -23,5 +24,18 @@ comfy_ops_Linear: Linear1f, } +if not is_community_version(): + from .openaimodel import Upsample as Upsample1f + from .openaimodel import UNetModel as UNetModel1f + + torch2of_class_map.update( + { + comfy.ldm.modules.diffusionmodules.openaimodel.Upsample: Upsample1f, + comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel: UNetModel1f, + } + ) +else: + print("Dynamic batchsize is not supported in community version.") + register(torch2oflow_class_map=torch2of_class_map) diff --git a/onediff_comfy_nodes/infer_compiler_registry/register_comfy/attention.py b/onediff_comfy_nodes/infer_compiler_registry/register_comfy/attention.py index cbba59bb5..c2a5e0d4d 100644 --- a/onediff_comfy_nodes/infer_compiler_registry/register_comfy/attention.py +++ b/onediff_comfy_nodes/infer_compiler_registry/register_comfy/attention.py @@ -113,7 +113,8 @@ def forward(self, x, context=None, transformer_options={}): x = self.proj_in(x) # NOTE: rearrange in ComfyUI is replaced with reshape and use -1 to enable for # dynamic shape inference (multi resolution compilation) - x = x.reshape(b, c, -1).permute(0, 2, 1) + x = x.flatten(2, 3).permute(0, 2, 1) + # x = x.reshape(b, c, -1).permute(0, 2, 1) # x = rearrange(x, 'b c h w -> b (h w) c').contiguous() if self.use_linear: x = self.proj_in(x) diff --git a/onediff_comfy_nodes/infer_compiler_registry/register_comfy/openaimodel.py b/onediff_comfy_nodes/infer_compiler_registry/register_comfy/openaimodel.py new file mode 100644 index 000000000..2dfe4e0f8 --- /dev/null +++ b/onediff_comfy_nodes/infer_compiler_registry/register_comfy/openaimodel.py @@ -0,0 +1,153 @@ +import comfy +import oneflow as th # 'th' is the way ComfyUI name the torch +import oneflow.nn.functional as F +from onediff.infer_compiler.transform import proxy_class +from onediff.infer_compiler.transform import transform_mgr + +onediff_comfy = transform_mgr.transform_package("comfy") + + +class Upsample(proxy_class(comfy.ldm.modules.diffusionmodules.openaimodel.Upsample)): + # https://github.com/comfyanonymous/ComfyUI/blob/b0aab1e4ea3dfefe09c4f07de0e5237558097e22/comfy/ldm/modules/diffusionmodules/openaimodel.py#L82 + def forward(self, x, output_shape=None): + assert x.shape[1] == self.channels + if output_shape is not None and isinstance(output_shape, th.Tensor): + if self.dims == 3: + raise ValueError("output_shape shoud not be Tensor for dims == 3") + else: + x = F.interpolate_like(x, like=output_shape, mode="nearest") + else: + if self.dims == 3: + shape = [x.shape[2], x.shape[3] * 2, x.shape[4] * 2] + if output_shape is not None: + shape[1] = output_shape[3] + shape[2] = output_shape[4] + else: + shape = [x.shape[2] * 2, x.shape[3] * 2] + if output_shape is not None: + shape[0] = output_shape[2] + shape[1] = output_shape[3] + + x = F.interpolate(x, size=shape, mode="nearest") + + if self.use_conv: + x = self.conv(x) + return x + + +class UNetModel(proxy_class(comfy.ldm.modules.diffusionmodules.openaimodel.UNetModel)): + # https://github.com/comfyanonymous/ComfyUI/blob/b0aab1e4ea3dfefe09c4f07de0e5237558097e22/comfy/ldm/modules/diffusionmodules/openaimodel.py#L823 + def forward( + self, + x, + timesteps=None, + context=None, + y=None, + control=None, + transformer_options={}, + **kwargs + ): + timestep_embedding = ( + onediff_comfy.ldm.modules.diffusionmodules.util.timestep_embedding + ) + forward_timestep_embed = ( + onediff_comfy.ldm.modules.diffusionmodules.openaimodel.forward_timestep_embed + ) + apply_control = ( + onediff_comfy.ldm.modules.diffusionmodules.openaimodel.apply_control + ) + + transformer_options["original_shape"] = list(x.shape) + transformer_options["transformer_index"] = 0 + transformer_patches = transformer_options.get("patches", {}) + + num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames) + image_only_indicator = kwargs.get( + "image_only_indicator", self.default_image_only_indicator + ) + time_context = kwargs.get("time_context", None) + + assert (y is not None) == ( + self.num_classes is not None + ), "must specify y if and only if the model is class-conditional" + hs = [] + t_emb = timestep_embedding( + timesteps, self.model_channels, repeat_only=False + ).to(x.dtype) + emb = self.time_embed(t_emb) + + if self.num_classes is not None: + assert y.shape[0] == x.shape[0] + emb = emb + self.label_emb(y) + + h = x + for id, module in enumerate(self.input_blocks): + transformer_options["block"] = ("input", id) + h = forward_timestep_embed( + module, + h, + emb, + context, + transformer_options, + time_context=time_context, + num_video_frames=num_video_frames, + image_only_indicator=image_only_indicator, + ) + h = apply_control(h, control, "input") + if "input_block_patch" in transformer_patches: + patch = transformer_patches["input_block_patch"] + for p in patch: + h = p(h, transformer_options) + + hs.append(h) + if "input_block_patch_after_skip" in transformer_patches: + patch = transformer_patches["input_block_patch_after_skip"] + for p in patch: + h = p(h, transformer_options) + + transformer_options["block"] = ("middle", 0) + h = forward_timestep_embed( + self.middle_block, + h, + emb, + context, + transformer_options, + time_context=time_context, + num_video_frames=num_video_frames, + image_only_indicator=image_only_indicator, + ) + h = apply_control(h, control, "middle") + + for id, module in enumerate(self.output_blocks): + transformer_options["block"] = ("output", id) + hsp = hs.pop() + hsp = apply_control(hsp, control, "output") + + if "output_block_patch" in transformer_patches: + patch = transformer_patches["output_block_patch"] + for p in patch: + h, hsp = p(h, hsp, transformer_options) + + h = th.cat([h, hsp], dim=1) + del hsp + if len(hs) > 0: + # output_shape = hs[-1].shape + output_shape = hs[-1] + else: + output_shape = None + h = forward_timestep_embed( + module, + h, + emb, + context, + transformer_options, + output_shape, + time_context=time_context, + num_video_frames=num_video_frames, + image_only_indicator=image_only_indicator, + ) + h = h.type(x.dtype) + if self.predict_codebook_ids: + return self.id_predictor(h) + else: + return self.out(h) diff --git a/src/onediff/infer_compiler/utils/__init__.py b/src/onediff/infer_compiler/utils/__init__.py index cff867c47..91bc61b38 100644 --- a/src/onediff/infer_compiler/utils/__init__.py +++ b/src/onediff/infer_compiler/utils/__init__.py @@ -1,3 +1,8 @@ from .oneflow_exec_mode import oneflow_exec_mode, oneflow_exec_mode_enabled from .env_var import parse_boolean_from_env, set_boolean_env_var from .model_inplace_assign import TensorInplaceAssign +from .version_util import ( + get_support_message, + is_quantization_enabled, + is_community_version, +) diff --git a/onediff_comfy_nodes/utils/comfyui_speedup_utils.py b/src/onediff/infer_compiler/utils/version_util.py similarity index 56% rename from onediff_comfy_nodes/utils/comfyui_speedup_utils.py rename to src/onediff/infer_compiler/utils/version_util.py index 719c7927f..5b2bf888b 100644 --- a/onediff_comfy_nodes/utils/comfyui_speedup_utils.py +++ b/src/onediff/infer_compiler/utils/version_util.py @@ -1,4 +1,3 @@ -import oneflow from importlib_metadata import version @@ -6,14 +5,14 @@ def get_support_message(): recipient_email = "caishenghang@oneflow.org" message = f"""\033[91m Advanced features cannot be used !!! \033[0m - If you need unrestricted multiple resolution, quantization support or any other more advanced - features, please send an email to {recipient_email} and tell us about - your **use case, deployment scale and requirements**. +If you need unrestricted multiple resolution, quantization support or any other more advanced features, please send an email to \033[91m{recipient_email}\033[0m and tell us about your use case, deployment scale and requirements. """ return message def is_quantization_enabled(): + import oneflow + if version("oneflow") < "0.9.1": RuntimeError( "onediff_comfy_nodes requires oneflow>=0.9.1 to run.", get_support_message() @@ -26,12 +25,6 @@ def is_quantization_enabled(): return hasattr(oneflow._C, "dynamic_quantization") -def is_community_version(stop_if_not=False): +def is_community_version(): is_community = not is_quantization_enabled() - if is_community: - message = get_support_message() - if stop_if_not: - input(message + "\nPress any key to continue...") - else: - print(message) return is_community