Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify after Cancelling Quantitative Model #897

Merged
merged 6 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions onediff_comfy_nodes/extras_nodes/nodes_oneflow_booster.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..modules.oneflow.hijack_samplers import samplers_hijack
from ..modules.oneflow.hijack_comfyui_instantid import comfyui_instantid_hijacker
from ..modules.oneflow.hijack_model_patcher import model_patch_hijacker
from ..modules.oneflow.hijack_utils import comfy_utils_hijack
from ..modules.oneflow import BasicOneFlowBoosterExecutor
from ..modules.oneflow import DeepcacheBoosterExecutor
from ..modules.oneflow import PatchBoosterExecutor
Expand All @@ -35,6 +36,7 @@
ipadapter_plus_hijacker.hijack()
comfyui_instantid_hijacker.hijack()
model_patch_hijacker.hijack()
comfy_utils_hijack.hijack()

import comfy_extras.nodes_video_model
from nodes import CheckpointLoaderSimple
Expand Down
2 changes: 1 addition & 1 deletion onediff_comfy_nodes/modules/oneflow/booster_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _(self, model: ModelPatcher, ckpt_name: Optional[str] = None, **kwargs):
)
set_compiled_options(compiled_model, graph_file)

model.weight_inplace_update = True

return model

@execute.register(ControlNet)
Expand Down
28 changes: 28 additions & 0 deletions onediff_comfy_nodes/modules/oneflow/hijack_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""hijack ComfyUI/comfy/utils.py"""
import torch
from comfy.utils import copy_to_param
from ..sd_hijack_utils import Hijacker


def copy_to_param_of(org_fn, obj, attr, value):
    """Replacement for ``comfy.utils.copy_to_param`` that writes into the
    existing tensor's storage instead of rebinding the attribute.

    Walks the dotted *attr* path on *obj* to reach the target parameter,
    then copies *value* into ``param.data`` in place. int8 (quantized)
    targets are left untouched when *value* arrives with a different dtype,
    so a float write cannot clobber quantized weights.
    """
    *parents, leaf = attr.split(".")
    target = obj
    for part in parents:
        target = getattr(target, part)
    param = getattr(target, leaf)

    # Guard: never write dtype-mismatched data into an int8 tensor.
    if param.data.dtype == torch.int8 and value.dtype != param.data.dtype:
        return

    param.data.copy_(value)


def cond_func(orig_func, *args, **kwargs):
    """Hijack predicate: unconditionally active, so every call to the
    original function is routed through the substitute."""
    return True


# Module-level hijacker that swaps comfy.utils.copy_to_param for the
# in-place variant defined above. Registration only records the mapping;
# the swap takes effect when comfy_utils_hijack.hijack() is called
# (done at startup by the booster node module — see caller).
comfy_utils_hijack = Hijacker()

comfy_utils_hijack.register(
    orig_func=copy_to_param, sub_func=copy_to_param_of, cond_func=cond_func
)
10 changes: 9 additions & 1 deletion src/onediff/infer_compiler/backends/oneflow/dual_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def oneflow_module(self):
logger.debug(f"Convert {type(self._torch_module)} ...")
self._oneflow_module = torch2oflow(self._torch_module)
logger.debug(f"Convert {type(self._torch_module)} done!")

return self._oneflow_module

@oneflow_module.deleter
Expand Down Expand Up @@ -99,14 +100,21 @@ def __setattr__(self, name: str, value: Any) -> None:
if name in ["_torch_module", "_oneflow_module"]:
super().__setattr__(name, value)
else: # TODO: aviod memory up when set attr
_torch_module: torch.nn.Module = self._torch_module
if (
hasattr(_torch_module, "_disable_param_update")
and _torch_module._disable_param_update
):
return

if self._oneflow_module is not None:
v = torch2oflow(value)
if isinstance(v, flow.Tensor):
obj = getattr(self._oneflow_module, name)
obj.copy_(v)
else:
setattr(self._oneflow_module, name, v)
setattr(self._torch_module, name, value)
setattr(_torch_module, name, value)
strint marked this conversation as resolved.
Show resolved Hide resolved

def extra_repr(self) -> str:
return self._torch_module.extra_repr()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def online_quantize_model(
in_args, in_kwargs = patch_input_adapter(input_args, input_kwargs)
quantized_model, info = module.quantize_with_calibration(*in_args, **in_kwargs)
status = module.collect_quantization_status(model, info)
for _, layer in quantized_model.named_modules():
layer._disable_param_update = True

return quantized_model, status

Expand Down