From 5d2714a56d6a1e1a4a3dc5ced79a6dbedfeb5259 Mon Sep 17 00:00:00 2001 From: FengWen Date: Wed, 22 May 2024 17:48:32 +0800 Subject: [PATCH 1/5] Modify after Cancelling Quantitative Model --- .../modules/oneflow/booster_basic.py | 2 +- .../modules/oneflow/patch_management/__init__.py | 1 + .../oneflow/patch_management/patch_for_oneflow.py | 12 ++++++++++++ .../oneflow/patch_management/patch_for_torch.py | 13 +++++++++++++ 4 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_torch.py diff --git a/onediff_comfy_nodes/modules/oneflow/booster_basic.py b/onediff_comfy_nodes/modules/oneflow/booster_basic.py index cf7773993..f35d4f27d 100644 --- a/onediff_comfy_nodes/modules/oneflow/booster_basic.py +++ b/onediff_comfy_nodes/modules/oneflow/booster_basic.py @@ -47,7 +47,7 @@ def _(self, model: ModelPatcher, ckpt_name: Optional[str] = None, **kwargs): ) set_compiled_options(compiled_model, graph_file) - model.weight_inplace_update = True + return model @execute.register(ControlNet) diff --git a/onediff_comfy_nodes/modules/oneflow/patch_management/__init__.py b/onediff_comfy_nodes/modules/oneflow/patch_management/__init__.py index 5665c4b74..f3d65d4c2 100644 --- a/onediff_comfy_nodes/modules/oneflow/patch_management/__init__.py +++ b/onediff_comfy_nodes/modules/oneflow/patch_management/__init__.py @@ -1,2 +1,3 @@ from .patch_for_oneflow import * +from .patch_for_torch import * from .patch_factory import create_patch_executor, PatchType diff --git a/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_oneflow.py b/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_oneflow.py index 304d905f7..381db0aa0 100644 --- a/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_oneflow.py +++ b/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_oneflow.py @@ -14,3 +14,15 @@ def __init__(self, *args, **kwargs): flow.framework.args_tree.NamedArg = PatchNamedArg + + + +original_copy_ = flow.Tensor.copy_ + +def new_copy_(self, *args, **kwargs): + # print(f'{__file__}.new_copy_ {self.dtype=}') + if self.dtype != flow.int8: + return original_copy_(self, *args, **kwargs) + +# Replace the original copy_ method with the new one +flow.Tensor.copy_ = new_copy_ diff --git a/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_torch.py b/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_torch.py new file mode 100644 index 000000000..d5e66d0d1 --- /dev/null +++ b/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_torch.py @@ -0,0 +1,13 @@ +import torch + +original_copy_ = torch.Tensor.copy_ + +def new_copy_(self, *args, **kwargs): + # print(f'{__file__}.new_copy_ {self.dtype=}') + if self.dtype != torch.int8: + return original_copy_(self, *args, **kwargs) + +# Replace the original copy_ method with the new one +torch.Tensor.copy_ = new_copy_ + + From 0913e9035969c242349ca7b58fc97d0563eb53c4 Mon Sep 17 00:00:00 2001 From: FengWen Date: Wed, 22 May 2024 18:59:54 +0800 Subject: [PATCH 2/5] refine --- .../modules/oneflow/patch_management/patch_for_oneflow.py | 7 ++++--- .../modules/oneflow/patch_management/patch_for_torch.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_oneflow.py b/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_oneflow.py index 381db0aa0..a3245aa36 100644 --- a/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_oneflow.py +++ b/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_oneflow.py @@ -19,10 +19,11 @@ def __init__(self, *args, **kwargs): original_copy_ = flow.Tensor.copy_ -def new_copy_(self, *args, **kwargs): +def new_copy_(self, src, *args, **kwargs): # print(f'{__file__}.new_copy_ {self.dtype=}') - if self.dtype != flow.int8: - return original_copy_(self, *args, **kwargs) + if self.dtype == flow.int8 and src.dtype != flow.int8: + return + return original_copy_(self, *args, **kwargs) # Replace the original copy_ method with the new one flow.Tensor.copy_ = new_copy_ diff --git a/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_torch.py b/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_torch.py index d5e66d0d1..a5cb7f3b9 100644 --- a/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_torch.py +++ b/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_torch.py @@ -2,10 +2,11 @@ original_copy_ = torch.Tensor.copy_ -def new_copy_(self, *args, **kwargs): +def new_copy_(self, src, *args, **kwargs): # print(f'{__file__}.new_copy_ {self.dtype=}') - if self.dtype != torch.int8: - return original_copy_(self, *args, **kwargs) + if self.dtype == torch.int8 and src.dtype != torch.int8: + return + return original_copy_(self, src, *args, **kwargs) # Replace the original copy_ method with the new one torch.Tensor.copy_ = new_copy_ From 2adff3124dd7bd282a8f9e9c00e53081cd28bfd9 Mon Sep 17 00:00:00 2001 From: FengWen Date: Thu, 23 May 2024 10:04:42 +0800 Subject: [PATCH 3/5] refine --- .../modules/oneflow/patch_management/patch_for_oneflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_oneflow.py b/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_oneflow.py index a3245aa36..8e110f4e7 100644 --- a/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_oneflow.py +++ b/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_oneflow.py @@ -23,7 +23,7 @@ def new_copy_(self, src, *args, **kwargs): # print(f'{__file__}.new_copy_ {self.dtype=}') if self.dtype == flow.int8 and src.dtype != flow.int8: return - return original_copy_(self, *args, **kwargs) + return original_copy_(self, other=src, *args, **kwargs) # Replace the original copy_ method with the new one flow.Tensor.copy_ = new_copy_ From 6fa057d11c3963d17c11330b44c22b105166bfde Mon Sep 17 00:00:00 2001 From: FengWen Date: Thu, 23 May 2024 14:12:22 +0800 Subject: [PATCH 4/5] refine --- .../extras_nodes/nodes_oneflow_booster.py | 2 ++ .../modules/oneflow/hijack_utils.py | 28 +++++++++++++++++++ .../oneflow/patch_management/__init__.py | 1 - .../patch_management/patch_for_oneflow.py | 13 --------- .../patch_management/patch_for_torch.py | 14 ---------- .../backends/oneflow/dual_module.py | 10 ++++++- .../oneflow/online_quantization_utils.py | 2 ++ 7 files changed, 41 insertions(+), 29 deletions(-) create mode 100644 onediff_comfy_nodes/modules/oneflow/hijack_utils.py delete mode 100644 onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_torch.py diff --git a/onediff_comfy_nodes/extras_nodes/nodes_oneflow_booster.py b/onediff_comfy_nodes/extras_nodes/nodes_oneflow_booster.py index c8cfb6103..9daa567e6 100644 --- a/onediff_comfy_nodes/extras_nodes/nodes_oneflow_booster.py +++ b/onediff_comfy_nodes/extras_nodes/nodes_oneflow_booster.py @@ -17,6 +17,7 @@ from ..modules.oneflow.hijack_samplers import samplers_hijack from ..modules.oneflow.hijack_comfyui_instantid import comfyui_instantid_hijacker from ..modules.oneflow.hijack_model_patcher import model_patch_hijacker +from ..modules.oneflow.hijack_utils import comfy_utils_hijack from ..modules.oneflow import BasicOneFlowBoosterExecutor from ..modules.oneflow import DeepcacheBoosterExecutor from ..modules.oneflow import PatchBoosterExecutor @@ -35,6 +36,7 @@ ipadapter_plus_hijacker.hijack() comfyui_instantid_hijacker.hijack() model_patch_hijacker.hijack() +comfy_utils_hijack.hijack() import comfy_extras.nodes_video_model from nodes import CheckpointLoaderSimple diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_utils.py b/onediff_comfy_nodes/modules/oneflow/hijack_utils.py new file mode 100644 index 000000000..4a4f25c5a --- /dev/null +++ b/onediff_comfy_nodes/modules/oneflow/hijack_utils.py @@ -0,0 +1,28 @@ +"""hijack ComfyUI/comfy/utils.py""" +import torch +from comfy.utils import copy_to_param +from ..sd_hijack_utils import Hijacker + + +def copy_to_param_of(org_fn, obj, attr, value): + # inplace update tensor instead of replacing it + attrs = attr.split(".") + for name in attrs[:-1]: + obj = getattr(obj, name) + prev = getattr(obj, attrs[-1]) + + if prev.data.dtype == torch.int8 and prev.data.dtype != value.dtype: + return + + prev.data.copy_(value) + + +def cond_func(orig_func, *args, **kwargs): + return True + + +comfy_utils_hijack = Hijacker() + +comfy_utils_hijack.register( + orig_func=copy_to_param, sub_func=copy_to_param_of, cond_func=cond_func +) diff --git a/onediff_comfy_nodes/modules/oneflow/patch_management/__init__.py b/onediff_comfy_nodes/modules/oneflow/patch_management/__init__.py index f3d65d4c2..5665c4b74 100644 --- a/onediff_comfy_nodes/modules/oneflow/patch_management/__init__.py +++ b/onediff_comfy_nodes/modules/oneflow/patch_management/__init__.py @@ -1,3 +1,2 @@ from .patch_for_oneflow import * -from .patch_for_torch import * from .patch_factory import create_patch_executor, PatchType diff --git a/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_oneflow.py b/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_oneflow.py index 8e110f4e7..304d905f7 100644 --- a/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_oneflow.py +++ b/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_oneflow.py @@ -14,16 +14,3 @@ def __init__(self, *args, **kwargs): flow.framework.args_tree.NamedArg = PatchNamedArg - - - -original_copy_ = flow.Tensor.copy_ - -def new_copy_(self, src, *args, **kwargs): - # print(f'{__file__}.new_copy_ {self.dtype=}') - if self.dtype == flow.int8 and src.dtype != flow.int8: - return - return original_copy_(self, other=src, *args, **kwargs) - -# Replace the original copy_ method with the new one -flow.Tensor.copy_ = new_copy_ diff --git a/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_torch.py b/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_torch.py deleted file mode 100644 index a5cb7f3b9..000000000 --- a/onediff_comfy_nodes/modules/oneflow/patch_management/patch_for_torch.py +++ /dev/null @@ -1,14 +0,0 @@ -import torch - -original_copy_ = torch.Tensor.copy_ - -def new_copy_(self, src, *args, **kwargs): - # print(f'{__file__}.new_copy_ {self.dtype=}') - if self.dtype == torch.int8 and src.dtype != torch.int8: - return - return original_copy_(self, src, *args, **kwargs) - -# Replace the original copy_ method with the new one -torch.Tensor.copy_ = new_copy_ - - diff --git a/src/onediff/infer_compiler/backends/oneflow/dual_module.py b/src/onediff/infer_compiler/backends/oneflow/dual_module.py index 3e6bfb979..3ce4c360e 100644 --- a/src/onediff/infer_compiler/backends/oneflow/dual_module.py +++ b/src/onediff/infer_compiler/backends/oneflow/dual_module.py @@ -29,6 +29,7 @@ def oneflow_module(self): logger.debug(f"Convert {type(self._torch_module)} ...") self._oneflow_module = torch2oflow(self._torch_module) logger.debug(f"Convert {type(self._torch_module)} done!") + return self._oneflow_module @oneflow_module.deleter @@ -99,6 +100,13 @@ def __setattr__(self, name: str, value: Any) -> None: if name in ["_torch_module", "_oneflow_module"]: super().__setattr__(name, value) else: # TODO: aviod memory up when set attr + _torch_module: torch.nn.Module = self._torch_module + if ( + hasattr(_torch_module, "disable_param_update") + and _torch_module.disable_param_update + ): + return + if self._oneflow_module is not None: v = torch2oflow(value) if isinstance(v, flow.Tensor): @@ -106,7 +114,7 @@ def __setattr__(self, name: str, value: Any) -> None: obj.copy_(v) else: setattr(self._oneflow_module, name, v) - setattr(self._torch_module, name, value) + setattr(_torch_module, name, value) def extra_repr(self) -> str: return self._torch_module.extra_repr() diff --git a/src/onediff/infer_compiler/backends/oneflow/online_quantization_utils.py b/src/onediff/infer_compiler/backends/oneflow/online_quantization_utils.py index 472c3d280..20afbb2b6 100644 --- a/src/onediff/infer_compiler/backends/oneflow/online_quantization_utils.py +++ b/src/onediff/infer_compiler/backends/oneflow/online_quantization_utils.py @@ -38,6 +38,8 @@ def online_quantize_model( in_args, in_kwargs = patch_input_adapter(input_args, input_kwargs) quantized_model, info = module.quantize_with_calibration(*in_args, **in_kwargs) status = module.collect_quantization_status(model, info) + for _, layer in quantized_model.named_modules(): + layer.disable_param_update = True return quantized_model, status From 1ef300bb0a19910eb456f363a73155194c812ae8 Mon Sep 17 00:00:00 2001 From: FengWen Date: Thu, 23 May 2024 14:41:07 +0800 Subject: [PATCH 5/5] refine --- src/onediff/infer_compiler/backends/oneflow/dual_module.py | 4 ++-- .../backends/oneflow/online_quantization_utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/onediff/infer_compiler/backends/oneflow/dual_module.py b/src/onediff/infer_compiler/backends/oneflow/dual_module.py index 3ce4c360e..903d814c7 100644 --- a/src/onediff/infer_compiler/backends/oneflow/dual_module.py +++ b/src/onediff/infer_compiler/backends/oneflow/dual_module.py @@ -102,8 +102,8 @@ def __setattr__(self, name: str, value: Any) -> None: else: # TODO: aviod memory up when set attr _torch_module: torch.nn.Module = self._torch_module if ( - hasattr(_torch_module, "disable_param_update") - and _torch_module.disable_param_update + hasattr(_torch_module, "_disable_param_update") + and _torch_module._disable_param_update ): return diff --git a/src/onediff/infer_compiler/backends/oneflow/online_quantization_utils.py b/src/onediff/infer_compiler/backends/oneflow/online_quantization_utils.py index 20afbb2b6..1a537dfc9 100644 --- a/src/onediff/infer_compiler/backends/oneflow/online_quantization_utils.py +++ b/src/onediff/infer_compiler/backends/oneflow/online_quantization_utils.py @@ -39,7 +39,7 @@ def online_quantize_model( quantized_model, info = module.quantize_with_calibration(*in_args, **in_kwargs) status = module.collect_quantization_status(model, info) for _, layer in quantized_model.named_modules(): - layer.disable_param_update = True + layer._disable_param_update = True return quantized_model, status