diff --git a/onediff_comfy_nodes/_nodes.py b/onediff_comfy_nodes/_nodes.py
index 99ffcc2d7..65662609c 100644
--- a/onediff_comfy_nodes/_nodes.py
+++ b/onediff_comfy_nodes/_nodes.py
@@ -1,5 +1,7 @@
 import folder_paths
 import torch
+import comfy
+from onediff.utils.chache_utils import LRUCache
 from nodes import CheckpointLoaderSimple, ControlNetLoader
 from ._config import is_disable_oneflow_backend
 from .modules import BoosterScheduler, BoosterExecutor
@@ -175,6 +177,8 @@ def onediff_load_controlnet(self, control_net_name, custom_booster=None):


 class OneDiffCheckpointLoaderSimple(CheckpointLoaderSimple):
+    _cache_map = LRUCache(1)
+
     @classmethod
     def INPUT_TYPES(s):
         return {
@@ -188,24 +192,45 @@ def INPUT_TYPES(s):
     CATEGORY = "OneDiff/Loaders"
     FUNCTION = "onediff_load_checkpoint"

-    @torch.no_grad()
-    def onediff_load_checkpoint(
-        self,
-        ckpt_name,
-        vae_speedup="disable",
-        output_vae=True,
-        output_clip=True,
-        custom_booster: BoosterScheduler = None,
+    @staticmethod
+    def _load_checkpoint(
+        ckpt_name, vae_speedup="disable", custom_booster: BoosterScheduler = None
     ):
-        # CheckpointLoaderSimple.load_checkpoint
-        modelpatcher, clip, vae = self.load_checkpoint(
-            ckpt_name, output_vae, output_clip
+        """Loads a checkpoint, applying speedup techniques."""
+
+        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
+        out = comfy.sd.load_checkpoint_guess_config(
+            ckpt_path,
+            output_vae=True,
+            output_clip=True,
+            embedding_directory=folder_paths.get_folder_paths("embeddings"),
         )
-        if custom_booster is None:
-            custom_booster = BoosterScheduler(BasicBoosterExecutor())
+
+        # Unpack outputs
+        modelpatcher, clip, vae = out[:3]
+
+        # Apply custom booster if provided, otherwise use a basic one
+        custom_booster = custom_booster or BoosterScheduler(BasicBoosterExecutor())
         modelpatcher = custom_booster(modelpatcher, ckpt_name=ckpt_name)
+
+        # Apply VAE speedup if enabled
         if vae_speedup == "enable":
             vae = BoosterScheduler(BasicBoosterExecutor())(vae, ckpt_name=ckpt_name)
-        # set inplace update
+
+        # Set weight inplace update
         modelpatcher.weight_inplace_update = True
+        return modelpatcher, clip, vae
+
+    @torch.inference_mode()
+    def onediff_load_checkpoint(
+        self, ckpt_name, vae_speedup="disable", custom_booster: BoosterScheduler = None,
+    ):
+        cache_key = (ckpt_name, vae_speedup, custom_booster)
+        out = self._cache_map.get(cache_key, None)
+        if out is None:
+            out = self._load_checkpoint(ckpt_name, vae_speedup, custom_booster)
+            self._cache_map.put(cache_key, out)
+
+        # Return the loaded checkpoint (modelpatcher, clip, vae)
+        return out

diff --git a/onediff_comfy_nodes/modules/oneflow/hijack_ipadapter_plus/set_model_patch_replace.py b/onediff_comfy_nodes/modules/oneflow/hijack_ipadapter_plus/set_model_patch_replace.py
index 3a323a2a2..8c2f6bc87 100644
--- a/onediff_comfy_nodes/modules/oneflow/hijack_ipadapter_plus/set_model_patch_replace.py
+++ b/onediff_comfy_nodes/modules/oneflow/hijack_ipadapter_plus/set_model_patch_replace.py
@@ -37,7 +37,9 @@ def split_patch_kwargs(patch_kwargs):
     split1dict = {}
     split2dict = {}
     for k, v in patch_kwargs.items():
-        if k in ["cond", "uncond", "mask", "weight"]:
+        if k in ["cond", "cond_alt", "uncond", "mask", "weight"] or isinstance(
+            v, torch.Tensor
+        ):
             split1dict[k] = v
         else:
             split2dict[k] = v
diff --git a/src/onediff/infer_compiler/backends/oneflow/args_tree_util.py b/src/onediff/infer_compiler/backends/oneflow/args_tree_util.py
index cf0d40de4..efab5da01 100644
--- a/src/onediff/infer_compiler/backends/oneflow/args_tree_util.py
+++ b/src/onediff/infer_compiler/backends/oneflow/args_tree_util.py
@@ -42,12 +42,26 @@ def wrapper(self: "OneflowDeployableModule", *args, **kwargs):
         and self._deployable_module_dpl_graph is not None
         and self._deployable_module_input_structure_key != input_structure_key
     ):
-        logger.warning(
-            "Input structure key has changed. Resetting the deployable module graph."
+        # Retrieve the deployable module graph from cache using the input structure key
+        dpl_graph = self._deployable_module_graph_cache.get(
+            input_structure_key, None
         )
-        self._deployable_module_dpl_graph = None
-        self._load_graph_first_run = True
-        self._deployable_module_input_structure_key = None
+        self._deployable_module_graph_cache.put(
+            self._deployable_module_input_structure_key,
+            self._deployable_module_dpl_graph,
+        )
+
+        # If a cached graph is found, update the deployable module graph and input structure key
+        if dpl_graph is not None:
+            self._deployable_module_dpl_graph = dpl_graph
+            self._deployable_module_input_structure_key = input_structure_key
+        else:
+            logger.warning(
+                f"Input structure key changed from {self._deployable_module_input_structure_key} to {input_structure_key}. Resetting the deployable module graph. This may slow down the process."
+            )
+            self._deployable_module_dpl_graph = None
+            self._deployable_module_input_structure_key = None
+            self._load_graph_first_run = True

     output = func(self, *mapped_args, **mapped_kwargs)
     return process_output(output)
diff --git a/src/onediff/infer_compiler/backends/oneflow/deployable_module.py b/src/onediff/infer_compiler/backends/oneflow/deployable_module.py
index 5fa8956a6..caa9f07cf 100644
--- a/src/onediff/infer_compiler/backends/oneflow/deployable_module.py
+++ b/src/onediff/infer_compiler/backends/oneflow/deployable_module.py
@@ -5,6 +5,7 @@
 import oneflow as flow

 from onediff.utils import logger
+from onediff.utils.chache_utils import LRUCache

 from ..deployable_module import DeployableModule

@@ -70,6 +71,9 @@ def __init__(
             options if options is not None else OneflowCompileOptions()
         )
         self._deployable_module_dpl_graph = None
+        self._deployable_module_graph_cache = LRUCache(
+            self._deployable_module_options.max_cached_graph_size
+        )
         self._is_raw_deployable_module = True
         self._load_graph_first_run = True
         self._deployable_module_input_structure_key = None
@@ -84,6 +88,9 @@ def from_existing(cls, existing_module, dynamic=True, options=None):
         instance._deployable_module_dpl_graph = (
             existing_module._deployable_module_dpl_graph
         )
+        instance._deployable_module_graph_cache = (
+            existing_module._deployable_module_graph_cache
+        )
         instance._load_graph_first_run = existing_module._load_graph_first_run
         instance._deployable_module_input_structure_key = (
             existing_module._deployable_module_input_structure_key
diff --git a/src/onediff/infer_compiler/backends/oneflow/transform/custom_transform.py b/src/onediff/infer_compiler/backends/oneflow/transform/custom_transform.py
index 35fcc1065..2ea176e82 100644
--- a/src/onediff/infer_compiler/backends/oneflow/transform/custom_transform.py
+++ b/src/onediff/infer_compiler/backends/oneflow/transform/custom_transform.py
@@ -37,17 +37,16 @@ def register_torch2oflow_func(func, first_param_type=None, verbose=False):


 def set_default_registry():
-    mocked_packages = transform_mgr.get_mocked_packages()
-
     def import_module_safely(module_path, module_name):
-        nonlocal mocked_packages
-        if module_name in mocked_packages:
+        if module_name in transform_mgr.loaded_modules:
             return

         try:
             import_module_from_path(module_path)
         except Exception as e:
logger.warning(f"Failed to import {module_name} from {module_path}. {e=}") + finally: + transform_mgr.loaded_modules.add(module_name) # compiler_registry_path registry_path = Path(__file__).parents[5] / "infer_compiler_registry" diff --git a/src/onediff/infer_compiler/backends/oneflow/transform/manager.py b/src/onediff/infer_compiler/backends/oneflow/transform/manager.py index 376b1e881..dd572c11d 100644 --- a/src/onediff/infer_compiler/backends/oneflow/transform/manager.py +++ b/src/onediff/infer_compiler/backends/oneflow/transform/manager.py @@ -25,6 +25,7 @@ def __init__(self, debug_mode=False, tmp_dir="./output"): self._oflow_to_torch_cls_map = {} self._setup_logger() self.mocker = LazyMocker(prefix="", suffix="", tmp_dir=None) + self.loaded_modules = set() def _setup_logger(self): name = "ONEDIFF" diff --git a/src/onediff/utils/chache_utils.py b/src/onediff/utils/chache_utils.py new file mode 100644 index 000000000..f3684da29 --- /dev/null +++ b/src/onediff/utils/chache_utils.py @@ -0,0 +1,21 @@ +import collections + + +class LRUCache(collections.OrderedDict): + __slots__ = ["LEN"] + + def __init__(self, capacity: int): + self.LEN = capacity + + def get(self, key: str, default=None) -> any: + if key in self: + self.move_to_end(key) + return self[key] + else: + return default + + def put(self, key: str, value: any) -> None: + self[key] = value + self.move_to_end(key) + if len(self) > self.LEN: + self.popitem(last=False)