From 32c352e20e651e7c79318b1eb0952a0b0850565d Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 21 Aug 2024 20:29:54 -0500 Subject: [PATCH 01/43] Initial scaffolding for supporting FancyVideo, refactored load_motion_module code to manually parse load_state_dict load_result --- animatediff/model_injection.py | 46 ++++++++++++++++++++++++++------- animatediff/motion_module_ad.py | 25 ++++++++++++++++-- animatediff/motion_module_fv.py | 0 3 files changed, 59 insertions(+), 12 deletions(-) create mode 100644 animatediff/motion_module_fv.py diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index 27c7da8..26c4a28 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -1,5 +1,6 @@ import copy from typing import Union, Callable +from collections import namedtuple from einops import rearrange from torch import Tensor @@ -19,7 +20,7 @@ from .ad_settings import AnimateDiffSettings, AdjustPE, AdjustWeight from .adapter_cameractrl import CameraPoseEncoder, CameraEntry, prepare_pose_embedding from .context import ContextOptions, ContextOptions, ContextOptionsGroup -from .motion_module_ad import (AnimateDiffModel, AnimateDiffFormat, EncoderOnlyAnimateDiffModel, VersatileAttention, PerBlock, AllPerBlocks, +from .motion_module_ad import (AnimateDiffModel, AnimateDiffFormat, AnimateDiffInfo, EncoderOnlyAnimateDiffModel, VersatileAttention, PerBlock, AllPerBlocks, has_mid_block, normalize_ad_state_dict, get_position_encoding_max_len) from .logger import logger from .utils_motion import (ADKeyframe, ADKeyframeGroup, MotionCompatibilityError, InputPIA, @@ -1263,9 +1264,8 @@ def load_motion_module_gen1(model_name: str, model: ModelPatcher, motion_lora: M ad_wrapper = AnimateDiffModel(mm_state_dict=mm_state_dict, mm_info=mm_info) ad_wrapper.to(model.model_dtype()) ad_wrapper.to(model.offload_device) - is_animatelcm = mm_info.mm_format==AnimateDiffFormat.ANIMATELCM - load_result = ad_wrapper.load_state_dict(mm_state_dict, strict=not is_animatelcm) - # TODO: report load_result of motion_module loading? + load_result = ad_wrapper.load_state_dict(mm_state_dict, strict=False) + verify_load_result(load_result=load_result, mm_info=mm_info) # wrap motion_module into a ModelPatcher, to allow motion lora patches motion_model = MotionModelPatcher(model=ad_wrapper, load_device=model.load_device, offload_device=model.offload_device) # load motion_lora, if present @@ -1288,18 +1288,44 @@ def load_motion_module_gen2(model_name: str, motion_model_settings: AnimateDiffS ad_wrapper = AnimateDiffModel(mm_state_dict=mm_state_dict, mm_info=mm_info) ad_wrapper.to(comfy.model_management.unet_dtype()) ad_wrapper.to(comfy.model_management.unet_offload_device()) - is_animatelcm = mm_info.mm_format==AnimateDiffFormat.ANIMATELCM - load_result = ad_wrapper.load_state_dict(mm_state_dict, strict=not is_animatelcm) - # TODO: manually check load_results for AnimateLCM models - if is_animatelcm: - pass - # TODO: report load_result of motion_module loading? 
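+        # NOTE: with strict=False, torch's nn.Module.load_state_dict returns a NamedTuple of
+        # (missing_keys, unexpected_keys); verify_load_result (defined below) re-applies strict-style
+        # error reporting after filtering out keys that are expected to be absent for certain
+        # formats (e.g. the pos_encoder.pe keys that AnimateLCM models do not ship with).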
+ load_result = ad_wrapper.load_state_dict(mm_state_dict, strict=False) + verify_load_result(load_result=load_result, mm_info=mm_info) # wrap motion_module into a ModelPatcher, to allow motion lora patches motion_model = MotionModelPatcher(model=ad_wrapper, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) return motion_model +IncompatibleKeys = namedtuple('IncompatibleKeys', ['missing_keys', 'unexpected_keys']) +def verify_load_result(load_result: IncompatibleKeys, mm_info: AnimateDiffInfo): + error_msgs: list[str] = [] + is_animatelcm = mm_info.mm_format==AnimateDiffFormat.ANIMATELCM + + remove_missing_idxs = [] + remove_unexpected_idxs = [] + for idx, key in enumerate(load_result.missing_keys): + # NOTE: AnimateLCM has no pe keys in the model file, so any errors associated with missing pe keys can be ignored + if is_animatelcm and "pos_encoder.pe" in key: + remove_missing_idxs.append(idx) + # remove any keys to ignore in reverse order (to preserve idx correlation) + for idx in reversed(remove_unexpected_idxs): + load_result.unexpected_keys.pop(idx) + for idx in reversed(remove_missing_idxs): + load_result.missing_keys.pop(idx) + # copied over from torch.nn.Module.module class Module's load_state_dict func + if len(load_result.unexpected_keys) > 0: + error_msgs.insert( + 0, 'Unexpected key(s) in state_dict: {}. '.format( + ', '.join(f'"{k}"' for k in load_result.unexpected_keys))) + if len(load_result.missing_keys) > 0: + error_msgs.insert( + 0, 'Missing key(s) in state_dict: {}. '.format( + ', '.join(f'"{k}"' for k in load_result.missing_keys))) + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + mm_info.mm_name, "\n\t".join(error_msgs))) + + def create_fresh_motion_module(motion_model: MotionModelPatcher) -> MotionModelPatcher: ad_wrapper = AnimateDiffModel(mm_state_dict=motion_model.model.state_dict(), mm_info=motion_model.model.mm_info) ad_wrapper.to(comfy.model_management.unet_dtype()) diff --git a/animatediff/motion_module_ad.py b/animatediff/motion_module_ad.py index b66538e..ab46790 100644 --- a/animatediff/motion_module_ad.py +++ b/animatediff/motion_module_ad.py @@ -40,8 +40,9 @@ class AnimateDiffFormat: HOTSHOTXL = "HotshotXL" ANIMATELCM = "AnimateLCM" PIA = "PIA" + FANCYVIDEO = "FancyVideo" - _LIST = [ANIMATEDIFF, HOTSHOTXL, ANIMATELCM, PIA] + _LIST = [ANIMATEDIFF, HOTSHOTXL, ANIMATELCM, PIA, FANCYVIDEO] class AnimateDiffVersion: @@ -141,6 +142,12 @@ def is_pia(mm_state_dict: dict[str, Tensor]) -> bool: return False +def is_fancyvideo(mm_state_dict: dict[str, Tensor]) -> bool: + if 'FancyVideo' in mm_state_dict: + return True + return False + + def get_down_block_max(mm_state_dict: dict[str, Tensor]) -> int: # keep track of biggest down_block count in module biggest_block = 0 @@ -195,7 +202,16 @@ def normalize_ad_state_dict(mm_state_dict: dict[str, Tensor], mm_name: str) -> T # log_name = mm_name.split('\\')[-1] # with open(Path(__file__).parent.parent.parent / rf"keys_{log_name}.txt", "w") as afile: # for key, value in mm_state_dict.items(): - # afile.write(f"{key}:\t{value.shape}\n") + # if key == 'module': + # for inkey, invalue in value.items(): + # if hasattr(invalue, 'shape'): + # afile.write(f"{inkey}:\t{invalue.shape}\n") + # else: + # afile.write(f"{inkey}:\t{invalue}\n") + # elif hasattr(value, 'shape'): + # afile.write(f"{key}:\t{value.shape}\n") + # else: + # afile.write(f"{key}:\t{type(value)}\n") # determine what SD model the motion module is 
intended for sd_type: str = None down_block_max = get_down_block_max(mm_state_dict) @@ -213,6 +229,9 @@ def normalize_ad_state_dict(mm_state_dict: dict[str, Tensor], mm_name: str) -> T mm_format = AnimateDiffFormat.ANIMATELCM if is_pia(mm_state_dict): mm_format = AnimateDiffFormat.PIA + if is_fancyvideo(mm_state_dict): + mm_format = AnimateDiffFormat.FANCYVIDEO + mm_state_dict.pop("FancyVideo") # for AnimateLCM-I2V purposes, check for img_encoder keys contains_img_encoder = has_img_encoder(mm_state_dict) # remove all non-temporal keys (in case model has extra stuff in it) @@ -222,6 +241,8 @@ def normalize_ad_state_dict(mm_state_dict: dict[str, Tensor], mm_name: str) -> T continue if mm_format == AnimateDiffFormat.PIA and key.startswith("conv_in."): continue + if mm_format == AnimateDiffFormat.FANCYVIDEO: + continue del mm_state_dict[key] # determine the model's version mm_version = AnimateDiffVersion.V1 diff --git a/animatediff/motion_module_fv.py b/animatediff/motion_module_fv.py new file mode 100644 index 0000000..e69de29 From 7c2c863038192cd9d9140b5e58c8587a95ef7291 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Wed, 21 Aug 2024 21:07:09 -0500 Subject: [PATCH 02/43] Improved description for Value Scheduling node --- animatediff/nodes_scheduling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/animatediff/nodes_scheduling.py b/animatediff/nodes_scheduling.py index 660b09f..dece973 100644 --- a/animatediff/nodes_scheduling.py +++ b/animatediff/nodes_scheduling.py @@ -158,7 +158,7 @@ def INPUT_TYPES(s): FUNCTION = "create_schedule" Desc = [ - short_desc('Create a list of values, its length matching passed-in latent count.'), + short_desc('Create a list of values with automatic interpolation, its length matching passed-in latent count.'), {'Format': desc_format_values}, {coll('Inputs'): DocHelper.combine(desc_values, desc_latent, desc_print_schedule)}, ] @@ -195,7 +195,7 @@ def INPUT_TYPES(s): FUNCTION = "create_schedule" Desc = [ - short_desc('Create a list of values, its length matching passed-in latent count.'), + short_desc('Create a list of values with automatic interpolation.'), {'Format': desc_format_values}, {coll('Inputs'): DocHelper.combine(desc_values, desc_max_length, desc_print_schedule)}, ] From ae7ae379cdbb4de4c9dab45c5d657a7d0554f3ff Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 23 Aug 2024 04:43:47 -0500 Subject: [PATCH 03/43] In Create Raw Sigma Schedule, renamed lcm_zsnr to zsnr and changed logic to allow applying zsnr for non-lcm --- animatediff/nodes_sigma_schedule.py | 9 +++++---- animatediff/utils_model.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/animatediff/nodes_sigma_schedule.py b/animatediff/nodes_sigma_schedule.py index 828e2ca..a27b36c 100644 --- a/animatediff/nodes_sigma_schedule.py +++ b/animatediff/nodes_sigma_schedule.py @@ -44,7 +44,7 @@ def INPUT_TYPES(s): #"cosine_s": ("FLOAT", {"default": 8e-3, "min": 0.0, "max": 1.0, "step": 0.000001}), "sampling": (ModelSamplingType._FULL_LIST,), "lcm_original_timesteps": ("INT", {"default": 50, "min": 1, "max": 1000}), - "lcm_zsnr": ("BOOLEAN", {"default": False}), + "zsnr": ("BOOLEAN", {"default": False}), } } @@ -53,14 +53,15 @@ def INPUT_TYPES(s): FUNCTION = "get_sigma_schedule" def get_sigma_schedule(self, raw_beta_schedule: str, linear_start: float, linear_end: float,# cosine_s: float, - sampling: str, lcm_original_timesteps: int, lcm_zsnr: bool): + sampling: str, lcm_original_timesteps: int, zsnr: bool, lcm_zsnr: bool=None): + if lcm_zsnr is 
not None: + zsnr = lcm_zsnr new_config = ModelSamplingConfig(beta_schedule=raw_beta_schedule, linear_start=linear_start, linear_end=linear_end) if sampling != ModelSamplingType.LCM: lcm_original_timesteps=None - lcm_zsnr=False model_type = ModelSamplingType.from_alias(sampling) new_model_sampling = BetaSchedules._to_model_sampling(alias=BetaSchedules.AUTOSELECT, model_type=model_type, config_override=new_config, original_timesteps=lcm_original_timesteps) - if lcm_zsnr: + if zsnr: SigmaSchedule.apply_zsnr(new_model_sampling=new_model_sampling) return (SigmaSchedule(model_sampling=new_model_sampling, model_type=model_type),) diff --git a/animatediff/utils_model.py b/animatediff/utils_model.py index 48c432b..930cdc9 100644 --- a/animatediff/utils_model.py +++ b/animatediff/utils_model.py @@ -194,7 +194,7 @@ def to_config(cls, alias: str) -> ModelSamplingConfig: return ModelSamplingConfig(cls.to_name(alias), linear_start=linear_start, linear_end=linear_end) @classmethod - def _to_model_sampling(cls, alias: str, model_type: ModelType, config_override: ModelSamplingConfig=None, original_timesteps: int=None): + def _to_model_sampling(cls, alias: str, model_type: ModelType, config_override: Union[ModelSamplingConfig,None]=None, original_timesteps: Union[int,None]=None): if alias == cls.USE_EXISTING: return None elif config_override != None: From 0837d2b1978f6a9695fff79cb2162b68d6331d99 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 1 Sep 2024 09:31:53 -0500 Subject: [PATCH 04/43] Adjust ADEAUTOSIZE since after description feature was added, it works differently and most adjustments are no longer necessary (and their presence makes the nodes too wide) --- animatediff/nodes_ad_settings.py | 10 +++++----- animatediff/nodes_cameractrl.py | 4 ++-- animatediff/nodes_conditioning.py | 22 +++++++++++----------- animatediff/nodes_context_extras.py | 16 ++++++++-------- animatediff/nodes_multival.py | 10 +++++----- animatediff/nodes_pia.py | 4 ++-- animatediff/nodes_sample.py | 24 ++++++++++++------------ animatediff/nodes_sigma_schedule.py | 8 ++++---- 8 files changed, 49 insertions(+), 49 deletions(-) diff --git a/animatediff/nodes_ad_settings.py b/animatediff/nodes_ad_settings.py index e99364b..3eb0de1 100644 --- a/animatediff/nodes_ad_settings.py +++ b/animatediff/nodes_ad_settings.py @@ -35,7 +35,7 @@ def INPUT_TYPES(s): }, "optional": { "prev_pe_adjust": ("PE_ADJUST",), - "autosize": ("ADEAUTOSIZE", {"padding": 30}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -67,7 +67,7 @@ def INPUT_TYPES(s): }, "optional": { "prev_pe_adjust": ("PE_ADJUST",), - "autosize": ("ADEAUTOSIZE", {"padding": 20}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -95,7 +95,7 @@ def INPUT_TYPES(s): }, "optional": { "prev_pe_adjust": ("PE_ADJUST",), - "autosize": ("ADEAUTOSIZE", {"padding": 30}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -258,7 +258,7 @@ def INPUT_TYPES(s): }, "optional": { "prev_weight_adjust": ("WEIGHT_ADJUST",), - "autosize": ("ADEAUTOSIZE", {"padding": 20}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -305,7 +305,7 @@ def INPUT_TYPES(s): }, "optional": { "prev_weight_adjust": ("WEIGHT_ADJUST",), - "autosize": ("ADEAUTOSIZE", {"padding": 20}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } diff --git a/animatediff/nodes_cameractrl.py b/animatediff/nodes_cameractrl.py index 6ba94e7..9e9a1ab 100644 --- a/animatediff/nodes_cameractrl.py +++ b/animatediff/nodes_cameractrl.py @@ -273,7 +273,7 @@ def INPUT_TYPES(s): "cameractrl_multival": ("MULTIVAL",), 
"inherit_missing": ("BOOLEAN", {"default": True}, ), "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), - "autosize": ("ADEAUTOSIZE", {"padding": 30}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -373,7 +373,7 @@ def INPUT_TYPES(cls): }, "optional": { "prev_poses": ("CAMERACTRL_POSES",), - "autosize": ("ADEAUTOSIZE", {"padding": 30}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } diff --git a/animatediff/nodes_conditioning.py b/animatediff/nodes_conditioning.py index 57f1233..9e3b263 100644 --- a/animatediff/nodes_conditioning.py +++ b/animatediff/nodes_conditioning.py @@ -97,7 +97,7 @@ def INPUT_TYPES(s): "opt_mask": ("MASK", ), "opt_lora_hook": ("LORA_HOOK",), "opt_timesteps": ("TIMESTEPS_COND",), - "autosize": ("ADEAUTOSIZE", {"padding": 70}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -129,7 +129,7 @@ def INPUT_TYPES(s): "opt_mask": ("MASK", ), "opt_lora_hook": ("LORA_HOOK",), "opt_timesteps": ("TIMESTEPS_COND",), - "autosize": ("ADEAUTOSIZE", {"padding": 55}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -158,7 +158,7 @@ def INPUT_TYPES(s): }, "optional": { "opt_lora_hook": ("LORA_HOOK",), - "autosize": ("ADEAUTOSIZE", {"padding": 10}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -256,7 +256,7 @@ def INPUT_TYPES(s): "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) }, "optional": { - "autosize": ("ADEAUTOSIZE", {"padding": 25}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -277,7 +277,7 @@ def INPUT_TYPES(s): "hook_kf": ("LORA_HOOK_KEYFRAMES",), }, "optional": { - "autosize": ("ADEAUTOSIZE", {"padding": 40}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -302,7 +302,7 @@ def INPUT_TYPES(s): }, "optional": { "prev_hook_kf": ("LORA_HOOK_KEYFRAMES",), - "autosize": ("ADEAUTOSIZE", {"padding": 5}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -337,7 +337,7 @@ def INPUT_TYPES(s): }, "optional": { "prev_hook_kf": ("LORA_HOOK_KEYFRAMES",), - "autosize": ("ADEAUTOSIZE", {"padding": 70}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -560,7 +560,7 @@ def INPUT_TYPES(s): "lora_hook": ("LORA_HOOK",), }, "optional": { - "autosize": ("ADEAUTOSIZE", {"padding": 5}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -610,7 +610,7 @@ def INPUT_TYPES(s): "optional": { "lora_hook_A": ("LORA_HOOK",), "lora_hook_B": ("LORA_HOOK",), - "autosize": ("ADEAUTOSIZE", {"padding": 30}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -634,7 +634,7 @@ def INPUT_TYPES(s): "lora_hook_B": ("LORA_HOOK",), "lora_hook_C": ("LORA_HOOK",), "lora_hook_D": ("LORA_HOOK",), - "autosize": ("ADEAUTOSIZE", {"padding": 30}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -664,7 +664,7 @@ def INPUT_TYPES(s): "lora_hook_F": ("LORA_HOOK",), "lora_hook_G": ("LORA_HOOK",), "lora_hook_H": ("LORA_HOOK",), - "autosize": ("ADEAUTOSIZE", {"padding": 30}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } diff --git a/animatediff/nodes_context_extras.py b/animatediff/nodes_context_extras.py index 8a28e8d..10b3b73 100644 --- a/animatediff/nodes_context_extras.py +++ b/animatediff/nodes_context_extras.py @@ -120,7 +120,7 @@ def INPUT_TYPES(s): "optional": { "prev_kf": ("NAIVEREUSE_KEYFRAME",), "mult_multival": ("MULTIVAL",), - "autosize": ("ADEAUTOSIZE", {"padding": 50}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -305,7 +305,7 @@ def INPUT_TYPES(s): "mult_multival": ("MULTIVAL",), "mode_replace": ("CONTEXTREF_MODE",), "tune_replace": ("CONTEXTREF_TUNE",), - "autosize": ("ADEAUTOSIZE", 
{"padding": 50}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -354,7 +354,7 @@ def INPUT_TYPES(s): "mult_multival": ("MULTIVAL",), "mode_replace": ("CONTEXTREF_MODE",), "tune_replace": ("CONTEXTREF_TUNE",), - "autosize": ("ADEAUTOSIZE", {"padding": 50}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -398,7 +398,7 @@ def INPUT_TYPES(s): "required": { }, "optional": { - "autosize": ("ADEAUTOSIZE", {"padding": 25}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), }, } @@ -419,7 +419,7 @@ def INPUT_TYPES(s): }, "optional": { "sliding_width": ("INT", {"default": 2, "min": 2, "max": BIGMAX, "step": 1}), - "autosize": ("ADEAUTOSIZE", {"padding": 42}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -441,7 +441,7 @@ def INPUT_TYPES(s): "optional": { "switch_on_idxs": ("STRING", {"default": ""}), "always_include_0": ("BOOLEAN", {"default": True},), - "autosize": ("ADEAUTOSIZE", {"padding": 50}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), }, } @@ -470,7 +470,7 @@ def INPUT_TYPES(s): "adain_style_fidelity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "adain_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "adain_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "autosize": ("ADEAUTOSIZE", {"padding": 65}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -496,7 +496,7 @@ def INPUT_TYPES(s): "attn_style_fidelity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "attn_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "attn_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "autosize": ("ADEAUTOSIZE", {"padding": 15}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } diff --git a/animatediff/nodes_multival.py b/animatediff/nodes_multival.py index 272b6cd..d2f4a52 100644 --- a/animatediff/nodes_multival.py +++ b/animatediff/nodes_multival.py @@ -22,7 +22,7 @@ def INPUT_TYPES(s): }, "optional": { "mask_optional": ("MASK",), - "autosize": ("ADEAUTOSIZE", {"padding": 10}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -45,7 +45,7 @@ def INPUT_TYPES(s): }, "optional": { "scaling": (ScaleType.LIST,), - "autosize": ("ADEAUTOSIZE", {"padding": 10}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -91,7 +91,7 @@ def INPUT_TYPES(s): }, "optional": { "mask_optional": ("MASK",), - "autosize": ("ADEAUTOSIZE", {"padding": 10}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -132,7 +132,7 @@ def INPUT_TYPES(s): "float_val": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001},), }, "optional": { - "autosize": ("ADEAUTOSIZE", {"padding": 10}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -152,7 +152,7 @@ def INPUT_TYPES(s): "multival": ("MULTIVAL",), }, "optional": { - "autosize": ("ADEAUTOSIZE", {"padding": 10}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } diff --git a/animatediff/nodes_pia.py b/animatediff/nodes_pia.py index 5944485..d5d9beb 100644 --- a/animatediff/nodes_pia.py +++ b/animatediff/nodes_pia.py @@ -201,7 +201,7 @@ def INPUT_TYPES(s): "pia_input": ("PIA_INPUT",), "inherit_missing": ("BOOLEAN", {"default": True}, ), "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), - "autosize": ("ADEAUTOSIZE", {"padding": 5}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -253,7 +253,7 @@ def INPUT_TYPES(s): "optional": { "mult_multival": ("MULTIVAL",), "print_values": ("BOOLEAN", {"default": False},), - "autosize": ("ADEAUTOSIZE", {"padding": 60}), + 
"autosize": ("ADEAUTOSIZE", {"padding": 0}), #"effect_multival": ("MULTIVAL",), } } diff --git a/animatediff/nodes_sample.py b/animatediff/nodes_sample.py index 70f4f1c..b3e86b8 100644 --- a/animatediff/nodes_sample.py +++ b/animatediff/nodes_sample.py @@ -33,7 +33,7 @@ def INPUT_TYPES(s): "custom_cfg": ("CUSTOM_CFG",), "sigma_schedule": ("SIGMA_SCHEDULE",), "image_inject": ("IMAGE_INJECT",), - "autosize": ("ADEAUTOSIZE", {"padding": 10}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -65,7 +65,7 @@ def INPUT_TYPES(s): "prev_noise_layers": ("NOISE_LAYERS",), "mask_optional": ("MASK",), "seed_override": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "forceInput": True}), - "autosize": ("ADEAUTOSIZE", {"padding": 20}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -101,7 +101,7 @@ def INPUT_TYPES(s): "prev_noise_layers": ("NOISE_LAYERS",), "mask_optional": ("MASK",), "seed_override": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "forceInput": True}), - "autosize": ("ADEAUTOSIZE", {"padding": 20}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -140,7 +140,7 @@ def INPUT_TYPES(s): "prev_noise_layers": ("NOISE_LAYERS",), "mask_optional": ("MASK",), "seed_override": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "forceInput": True}), - "autosize": ("ADEAUTOSIZE", {"padding": 10}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -202,7 +202,7 @@ def INPUT_TYPES(s): "optional": { "iter_batch_offset": ("INT", {"default": 0, "min": 0, "max": BIGMAX}), "iter_seed_offset": ("INT", {"default": 1, "min": BIGMIN, "max": BIGMAX}), - "autosize": ("ADEAUTOSIZE", {"padding": 55}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -229,7 +229,7 @@ def INPUT_TYPES(s): }, "optional": { "cfg_extras": ("CFG_EXTRAS",), - "autosize": ("ADEAUTOSIZE", {"padding": 20}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -253,7 +253,7 @@ def INPUT_TYPES(s): }, "optional": { "cfg_extras": ("CFG_EXTRAS",), - "autosize": ("ADEAUTOSIZE", {"padding": 10}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -277,7 +277,7 @@ def INPUT_TYPES(s): "optional": { "prev_custom_cfg": ("CUSTOM_CFG",), "cfg_extras": ("CFG_EXTRAS",), - "autosize": ("ADEAUTOSIZE", {"padding": 80}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -337,7 +337,7 @@ def INPUT_TYPES(s): "optional": { "prev_custom_cfg": ("CUSTOM_CFG",), "cfg_extras": ("CFG_EXTRAS",), - "autosize": ("ADEAUTOSIZE", {"padding": 70}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -424,7 +424,7 @@ def INPUT_TYPES(s): }, "optional": { "prev_extras": ("CFG_EXTRAS",), - "autosize": ("ADEAUTOSIZE", {"padding": 45}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -509,7 +509,7 @@ def INPUT_TYPES(s): }, "optional": { "prev_extras": ("CFG_EXTRAS",), - "autosize": ("ADEAUTOSIZE", {"padding": 45}), + "autosize": ("ADEAUTOSIZE", {"padding": 10}), } } @@ -569,7 +569,7 @@ def INPUT_TYPES(s): "optional": { "composite_x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), "composite_y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "autosize": ("ADEAUTOSIZE", {"padding": 30}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } diff --git a/animatediff/nodes_sigma_schedule.py b/animatediff/nodes_sigma_schedule.py index a27b36c..c6667f2 100644 --- a/animatediff/nodes_sigma_schedule.py +++ b/animatediff/nodes_sigma_schedule.py @@ -76,7 +76,7 @@ def INPUT_TYPES(s): "weight_A": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.001}), }, 
"optional": { - "autosize": ("ADEAUTOSIZE", {"padding": 80}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -104,7 +104,7 @@ def INPUT_TYPES(s): "interpolation": (InterpolationMethod._LIST,), }, "optional": { - "autosize": ("ADEAUTOSIZE", {"padding": 70}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -135,7 +135,7 @@ def INPUT_TYPES(s): "idx_split_percent": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.001}) }, "optional": { - "autosize": ("ADEAUTOSIZE", {"padding": 40}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -164,7 +164,7 @@ def INPUT_TYPES(s): "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), }, "optional": { - "autosize": ("ADEAUTOSIZE", {"padding": 50}), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } From 6a1d498ce01f8faa1631490d4714cd8986ce95ca Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 1 Sep 2024 12:34:07 -0500 Subject: [PATCH 05/43] Started work on NoiseCalibration implementation (pain) --- animatediff/nodes.py | 4 +- animatediff/nodes_sample.py | 31 +++++++++++-- animatediff/sample_settings.py | 82 +++++++++++++++++++++++++++++++++- animatediff/sampling.py | 3 ++ 4 files changed, 114 insertions(+), 6 deletions(-) diff --git a/animatediff/nodes.py b/animatediff/nodes.py index a429789..11cbd13 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -24,7 +24,7 @@ from .nodes_sample import (FreeInitOptionsNode, NoiseLayerAddWeightedNode, SampleSettingsNode, NoiseLayerAddNode, NoiseLayerReplaceNode, IterationOptionsNode, CustomCFGNode, CustomCFGSimpleNode, CustomCFGKeyframeNode, CustomCFGKeyframeSimpleNode, CustomCFGKeyframeInterpolationNode, CustomCFGKeyframeFromListNode, CFGExtrasPAGNode, CFGExtrasPAGSimpleNode, CFGExtrasRescaleCFGNode, CFGExtrasRescaleCFGSimpleNode, - NoisedImageInjectionNode, NoisedImageInjectOptionsNode) + NoisedImageInjectionNode, NoisedImageInjectOptionsNode, NoiseCalibrationNode) from .nodes_sigma_schedule import (SigmaScheduleNode, RawSigmaScheduleNode, WeightedAverageSigmaScheduleNode, InterpolatedWeightedAverageSigmaScheduleNode, SplitAndCombineSigmaScheduleNode, SigmaScheduleToSigmasNode) from .nodes_context import (LegacyLoopedUniformContextOptionsNode, LoopedUniformContextOptionsNode, LoopedUniformViewOptionsNode, StandardUniformContextOptionsNode, StandardStaticContextOptionsNode, BatchedContextOptionsNode, StandardStaticViewOptionsNode, StandardUniformViewOptionsNode, ViewAsContextOptionsNode, @@ -158,6 +158,7 @@ "ADE_SigmaScheduleToSigmas": SigmaScheduleToSigmasNode, "ADE_NoisedImageInjection": NoisedImageInjectionNode, "ADE_NoisedImageInjectOptions": NoisedImageInjectOptionsNode, + "ADE_NoiseCalibration": NoiseCalibrationNode, # Scheduling PromptSchedulingNode.NodeID: PromptSchedulingNode, PromptSchedulingLatentsNode.NodeID: PromptSchedulingLatentsNode, @@ -325,6 +326,7 @@ "ADE_SigmaScheduleToSigmas": "Sigma Schedule To Sigmas πŸŽ­πŸ…πŸ…“", "ADE_NoisedImageInjection": "Image Injection πŸŽ­πŸ…πŸ…“", "ADE_NoisedImageInjectOptions": "Image Injection Options πŸŽ­πŸ…πŸ…“", + "ADE_NoiseCalibration": "Noise Calibration πŸŽ­πŸ…πŸ…“", # Scheduling PromptSchedulingNode.NodeID: PromptSchedulingNode.NodeName, PromptSchedulingLatentsNode.NodeID: PromptSchedulingLatentsNode.NodeName, diff --git a/animatediff/nodes_sample.py b/animatediff/nodes_sample.py index b3e86b8..f274106 100644 --- a/animatediff/nodes_sample.py +++ b/animatediff/nodes_sample.py @@ -7,7 +7,7 @@ from .freeinit import FreeInitFilter from .sample_settings import (FreeInitOptions, 
IterationOptions, NoiseLayerAdd, NoiseLayerAddWeighted, NoiseLayerGroup, NoiseLayerReplace, NoiseLayerType, - SeedNoiseGeneration, SampleSettings, + SeedNoiseGeneration, SampleSettings, NoiseCalibration, CustomCFGKeyframeGroup, CustomCFGKeyframe, CFGExtrasGroup, CFGExtras, NoisedImageToInjectGroup, NoisedImageToInject, NoisedImageInjectOptions) from .utils_model import BIGMIN, BIGMAX, MAX_RESOLUTION, SigmaSchedule, InterpolationMethod @@ -33,6 +33,7 @@ def INPUT_TYPES(s): "custom_cfg": ("CUSTOM_CFG",), "sigma_schedule": ("SIGMA_SCHEDULE",), "image_inject": ("IMAGE_INJECT",), + "noise_calib": ("NOISE_CALIBRATION",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -44,10 +45,11 @@ def INPUT_TYPES(s): def create_settings(self, batch_offset: int, noise_type: str, seed_gen: str, seed_offset: int, noise_layers: NoiseLayerGroup=None, iteration_opts: IterationOptions=None, seed_override: int=None, adapt_denoise_steps=False, - custom_cfg: CustomCFGKeyframeGroup=None, sigma_schedule: SigmaSchedule=None, image_inject: NoisedImageToInjectGroup=None): + custom_cfg: CustomCFGKeyframeGroup=None, sigma_schedule: SigmaSchedule=None, image_inject: NoisedImageToInjectGroup=None, + noise_calib: NoiseCalibration=None): sampling_settings = SampleSettings(batch_offset=batch_offset, noise_type=noise_type, seed_gen=seed_gen, seed_offset=seed_offset, noise_layers=noise_layers, iteration_opts=iteration_opts, seed_override=seed_override, adapt_denoise_steps=adapt_denoise_steps, - custom_cfg=custom_cfg, sigma_schedule=sigma_schedule, image_injection=image_inject) + custom_cfg=custom_cfg, sigma_schedule=sigma_schedule, image_injection=image_inject, noise_calibration=noise_calib) return (sampling_settings,) @@ -220,6 +222,29 @@ def create_iter_opts(self, iterations: int, filter: str, d_s: float, d_t: float, return (iter_opts,) +class NoiseCalibrationNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "calib_iterations": ("INT", {"default": 1, "min": 1, "step": 1}), + "thresh_freq": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.001}), + }, + "optional": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("NOISE_CALIBRATION",) + RETURN_NAMES = ("NOISE_CALIB",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/sample settings" + FUNCTION = "create_noisecalibration" + + def create_noisecalibration(self, calib_iterations: int, thresh_freq: float): + noise_calib = NoiseCalibration(scale=thresh_freq, calib_iterations=calib_iterations) + return (noise_calib,) + + class CustomCFGNode: @classmethod def INPUT_TYPES(s): diff --git a/animatediff/sample_settings.py b/animatediff/sample_settings.py index bff8445..593c21b 100644 --- a/animatediff/sample_settings.py +++ b/animatediff/sample_settings.py @@ -2,6 +2,8 @@ from typing import Union, Callable import torch from torch import Tensor +import torch.fft as fft +from einops import rearrange import comfy.sample import comfy.samplers @@ -56,7 +58,8 @@ class NoiseNormalize: class SampleSettings: def __init__(self, batch_offset: int=0, noise_type: str=None, seed_gen: str=None, seed_offset: int=0, noise_layers: 'NoiseLayerGroup'=None, iteration_opts=None, seed_override:int=None, negative_cond_flipflop=False, adapt_denoise_steps: bool=False, - custom_cfg: 'CustomCFGKeyframeGroup'=None, sigma_schedule: SigmaSchedule=None, image_injection: 'NoisedImageToInjectGroup'=None): + custom_cfg: 'CustomCFGKeyframeGroup'=None, sigma_schedule: SigmaSchedule=None, image_injection: 'NoisedImageToInjectGroup'=None, + noise_calibration: 
'NoiseCalibration'=None): self.batch_offset = batch_offset self.noise_type = noise_type if noise_type is not None else NoiseLayerType.DEFAULT self.seed_gen = seed_gen if seed_gen is not None else SeedNoiseGeneration.COMFY @@ -69,6 +72,7 @@ def __init__(self, batch_offset: int=0, noise_type: str=None, seed_gen: str=None self.custom_cfg = custom_cfg.clone() if custom_cfg else custom_cfg self.sigma_schedule = sigma_schedule self.image_injection = image_injection.clone() if image_injection else NoisedImageToInjectGroup() + self.noise_calibration = noise_calibration def prepare_noise(self, seed: int, latents: Tensor, noise: Tensor, extra_seed_offset=0, extra_args:dict={}, force_create_noise=True): if self.seed_override is not None: @@ -109,7 +113,7 @@ def clone(self): return SampleSettings(batch_offset=self.batch_offset, noise_type=self.noise_type, seed_gen=self.seed_gen, seed_offset=self.seed_offset, noise_layers=self.noise_layers.clone(), iteration_opts=self.iteration_opts, seed_override=self.seed_override, negative_cond_flipflop=self.negative_cond_flipflop, adapt_denoise_steps=self.adapt_denoise_steps, custom_cfg=self.custom_cfg, - sigma_schedule=self.sigma_schedule, image_injection=self.image_injection) + sigma_schedule=self.sigma_schedule, image_injection=self.image_injection, noise_calibration=self.noise_calibration) class NoiseLayer: @@ -503,6 +507,80 @@ def preprocess_latents(self, curr_i: int, model: ModelPatcher, latents: Tensor, raise ValueError(f"FreeInit init_type '{self.init_type}' is not recognized.") +class NoiseCalibration: + def __init__(self, scale: float=0.5, calib_iterations: int=1): + self.scale = scale + self.calib_iterations = calib_iterations + + def perform_calibration(self, sample_func: Callable, model: ModelPatcher, latents: Tensor, noise: Tensor, is_custom: bool, args: list, kwargs: dict): + if is_custom: + return self._perform_calibration_custom(sample_func=sample_func, model=model, latents=latents, noise=noise, _args=args, _kwargs=kwargs) + return self._perform_calibration_not_custom(sample_func=sample_func, model=model, latents=latents, noise=noise, args=args, kwargs=kwargs) + + def _perform_calibration_custom(self, sample_func: Callable, model: ModelPatcher, latents: Tensor, noise: Tensor, _args: list, _kwargs: dict): + args = _args.copy() + kwargs = _kwargs.copy() + # need to get sigmas to be used in sampling and for noise calc + sigmas = args[2] + # use first 2 sigmas as real sigmas (2 sigmas = 1 step) + sigmas = sigmas[:2] + args[2] = sigmas + # divide by scale factor + sigma = sigmas[0] + alpha_cumprod = 1 / ((sigma * sigma) + 1) + sqrt_alpha_prod = alpha_cumprod ** 0.5 + sqrt_one_minus_alpha_prod = (1 - alpha_cumprod) ** 0.5 + zero_noise = torch.zeros_like(noise) + new_latents = latents + #new_latents = latents * (model.model.latent_format.scale_factor) + for _ in range(self.calib_iterations): + # TODO: do i need to use DDIM noising, or will ComfyUI's work? 
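+                # NOTE: alpha_cumprod = 1/(sigma^2 + 1) above assumes ComfyUI's eps-prediction sigma convention,
+                # giving sqrt_alpha_prod = 1/sqrt(1 + sigma^2) and sqrt_one_minus_alpha_prod = sigma/sqrt(1 + sigma^2).
+                # Under that assumption, the DDIM-style noising below, x = x_0 * sqrt_alpha_prod + noise * sqrt_one_minus_alpha_prod,
+                # equals ComfyUI's usual (x_0 + sigma * noise) rescaled by 1/sqrt(1 + sigma^2).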
+ x = new_latents * sqrt_alpha_prod + noise * sqrt_one_minus_alpha_prod + #x = latents + #x = latents + noise * sigma #torch.sqrt(1.0 + sigma ** 2.0) + # replace latents in args with x + args[-1] = x + e_t_theta = sample_func(model, zero_noise, *args, **kwargs) + x_0_t = (x - sqrt_one_minus_alpha_prod * e_t_theta) / sqrt_alpha_prod + freq_delta = (self.get_low_or_high_fft(x_0_t, self.scale, is_low=False) - self.get_low_or_high_fft(new_latents, self.scale, is_low=False)) + noise = e_t_theta + sqrt_alpha_prod / sqrt_one_minus_alpha_prod * freq_delta + #return latents, noise + #x = latents * sqrt_alpha_prod + noise * sqrt_one_minus_alpha_prod + #return zero_noise, x #noise * (model.model.latent_format.scale_factor) + return latents, noise * (model.model.latent_format.scale_factor) + + def _perform_calibration_not_custom(self, sample_func: Callable, model: ModelPatcher, latents: Tensor, noise: Tensor, args: list, kwargs: dict): + return latents, noise + + @staticmethod + # From NoiseCalibration code at https://github.com/yangqy1110/NC-SDEdit/ + def get_low_or_high_fft(x: Tensor, scale: float, is_low=True): + # reshape to match intended dims; starts in b c h w, turn into c b h w + x = rearrange(x, "b c h w -> c b h w") + # FFT + x_freq = fft.fftn(x, dim=(-2, -1)) + x_freq = fft.fftshift(x_freq, dim=(-2, -1)) + C, T, H, W = x_freq.shape + + # extract + if is_low: + mask = torch.zeros((C, T, H, W), device=x.device) + crow, ccol = H // 2, W // 2 + mask[..., crow - int(crow * scale):crow + int(crow * scale), ccol - int(ccol * scale):ccol + int(ccol * scale)] = 1 + else: + mask = torch.ones((C, T, H, W), device=x.device) + crow, ccol = H // 2, W //2 + mask[..., crow - int(crow * scale):crow + int(crow * scale), ccol - int(ccol * scale):ccol + int(ccol * scale)] = 0 + x_freq = x_freq * mask + + # IFFT + x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) + x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real + # rearrange back to ComfyUI expected dims + x_filtered = rearrange(x_filtered, "c b h w -> b c h w") + return x_filtered + + class CFGExtras: def __init__(self, call_fn: Callable): self.call_fn = call_fn diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 8244f2c..7f3ff4c 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -515,6 +515,9 @@ def ad_callback(step, x0, x, total_steps): seed=seed, sample_settings=model.sample_settings, noise_extra_args=noise_extra_args, **iter_kwargs) + if model.sample_settings.noise_calibration is not None: + latents, noise = model.sample_settings.noise_calibration.perform_calibration(sample_func=orig_comfy_sample, model=model, latents=latents, noise=noise, + is_custom=is_custom, args=args, kwargs=kwargs) args[-1] = latents if model.motion_models is not None: From 1a9d70ee4cc4c1526c06e42f62f71e1d0d738c1d Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 3 Sep 2024 15:37:14 -0500 Subject: [PATCH 06/43] Progress on FancyVideo support --- animatediff/adapter_fancyvideo.py | 40 +++++++++++++ animatediff/model_injection.py | 62 +++++++++++++++++--- animatediff/motion_module_ad.py | 87 ++++++++++++++++++++++++++--- animatediff/motion_module_fv.py | 0 animatediff/nodes.py | 5 ++ animatediff/nodes_fancyvideo.py | 62 ++++++++++++++++++++ animatediff/nodes_gen2.py | 1 + animatediff/nodes_sigma_schedule.py | 3 + animatediff/sample_settings.py | 8 +-- animatediff/sampling.py | 61 ++++++++++++-------- 10 files changed, 287 insertions(+), 42 deletions(-) create mode 100644 animatediff/adapter_fancyvideo.py delete mode 100644 
animatediff/motion_module_fv.py create mode 100644 animatediff/nodes_fancyvideo.py diff --git a/animatediff/adapter_fancyvideo.py b/animatediff/adapter_fancyvideo.py new file mode 100644 index 0000000..5ada7ac --- /dev/null +++ b/animatediff/adapter_fancyvideo.py @@ -0,0 +1,40 @@ +from torch import nn + +import comfy.ops + + +FancyVideoKeys = [ + 'fps_embedding.linear.bias', + 'fps_embedding.linear.weight', + 'motion_embedding.linear.bias', + 'motion_embedding.linear.weight', + 'conv_in.bias', + 'conv_in.weight', +] + + +def initialize_weights_to_zero(m): + if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): + nn.init.constant_(m.weight, 0) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +class FancyVideoCondEmbedding(nn.Module): + def __init__(self, in_channels: int, cond_embed_dim: int, act_fn: str = "silu", ops=comfy.ops.disable_weight_init): + super().__init__() + + self.linear = ops.Linear(in_channels, cond_embed_dim) + self.act = None + if act_fn == "silu": + self.act = nn.SiLU() + elif act_fn == "mish": + self.act = nn.Mish() + + def forward(self, sample): + sample = self.linear(sample) + + if self.act is not None: + sample = self.act(sample) + + return sample diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index 051798c..19c7fda 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -793,6 +793,13 @@ def __init__(self, *args, **kwargs): self.prev_current_pia_input: InputPIA = None self.pia_multival: Union[float, Tensor] = None + # FancyVideo + self.orig_fancy_images: Tensor = None + self.fancy_vae: VAE = None + self.cached_fancy_c_concat: comfy.conds.CONDNoiseShape = None # cached + self.prev_fancy_latents_shape: tuple = None + self.fancy_multival: Union[float, Tensor] = None + # temporary variables self.current_used_steps = 0 self.current_keyframe: ADKeyframe = None @@ -930,7 +937,7 @@ def prepare_current_keyframe(self, x: Tensor, t: Tensor): # update previous_t self.previous_t = curr_t - def prepare_img_features(self, x: Tensor, cond_or_uncond: list[int], ad_params: dict[str], latent_format): + def prepare_alcmi2v_features(self, x: Tensor, cond_or_uncond: list[int], ad_params: dict[str], latent_format): # if no img_encoder, done if self.model.img_encoder is None: return @@ -1009,10 +1016,7 @@ def get_pia_c_concat(self, model: BaseModel, x: Tensor) -> Tensor: self.prev_pia_latents_shape = None # otherwise, x shape should be the cached pia_latents_shape # get currently used models so they can be properly reloaded after perfoming VAE Encoding - if hasattr(comfy.model_management, "loaded_models"): - cached_loaded_models = comfy.model_management.loaded_models(only_currently_used=True) - else: - cached_loaded_models: list[ModelPatcherAndInjector] = [x.model for x in comfy.model_management.current_loaded_models] + cached_loaded_models = comfy.model_management.loaded_models(only_currently_used=True) try: b, c, h ,w = x.shape usable_ref = self.orig_pia_images[:b] @@ -1052,9 +1056,53 @@ def get_pia_c_concat(self, model: BaseModel, x: Tensor) -> Tensor: finally: comfy.model_management.load_models_gpu(cached_loaded_models) + def get_fancy_c_concat(self, model: BaseModel, x: Tensor) -> Tensor: + # if have cached shape, check if matches - if so, return cached fancy_latents + if self.prev_fancy_latents_shape is not None: + if self.prev_fancy_latents_shape[0] == x.shape[0] and self.prev_fancy_latents_shape[-2] == x.shape[-2] and self.prev_fancy_latents_shape[-1] == x.shape[-1]: + # TODO: if mask is also the same for this 
timestep, then retucn cached + return self.cached_fancy_c_concat + self.prev_fancy_latents_shape = None + # otherwise, x shape should be the cached fancy_latents_shape + # get currently used models so they can be properly reloaded after performing VAE Encoding + cached_loaded_models = comfy.model_management.loaded_models(only_currently_used=True) + try: + b, c, h, w = x.shape + usable_ref = self.orig_fancy_images[:b] + # resize images to latent's dims + usable_ref = usable_ref.movedim(-1,1) + usable_ref = comfy.utils.common_upscale(samples=usable_ref, width=w*self.fancy_vae.downscale_ratio, height=h*self.fancy_vae.downscale_ratio, + upscale_method="bilinear", crop="center") + usable_ref = usable_ref.movedim(1,-1) + # VAE encode images + logger.info("VAE Encoding FancyVideo input images...") + usable_ref: Tensor = model.process_latent_in(vae_encode_raw_batched(vae=self.fancy_vae, pixels=usable_ref, show_pbar=False)) + logger.info("VAE Encoding FancyVideo input images complete.") + self.prev_fancy_latents_shape = x.shape + # TODO: experiment with indexes that aren't the first + # pad usable_ref with zeros + ref_length = usable_ref.shape[0] + pad_length = b - ref_length + zero_ref = torch.zeros([pad_length, c, h, w], dtype=usable_ref.dtype, device=usable_ref.device) + usable_ref = torch.cat([usable_ref, zero_ref], dim=0) + del zero_ref + # create mask + mask_ones = torch.ones([ref_length, 1, h, w], dtype=usable_ref.dtype, device=usable_ref.device) + mask_zeros = torch.zeros([pad_length, 1, h, w], dtype=usable_ref.dtype, device=usable_ref.device) + mask = torch.cat([mask_ones, mask_zeros], dim=0) + # TODO: experiment with mask strength + # cache fancy c_concat - ref first, then mask + self.cached_fancy_c_concat = comfy.conds.CONDNoiseShape(torch.cat([usable_ref, mask], dim=1)) + return self.cached_fancy_c_concat + finally: + comfy.model_management.load_models_gpu(cached_loaded_models) + def is_pia(self): return self.model.mm_info.mm_format == AnimateDiffFormat.PIA and self.orig_pia_images is not None + def is_fancyvideo(self): + return self.model.mm_info.mm_format == AnimateDiffFormat.FANCYVIDEO + def cleanup(self): if self.model is not None: self.model.cleanup() @@ -1175,10 +1223,10 @@ def prepare_current_keyframe(self, x: Tensor, t: Tensor): for motion_model in self.models: motion_model.prepare_current_keyframe(x=x, t=t) - def get_pia_models(self): + def get_special_models(self): pia_motion_models: list[MotionModelPatcher] = [] for motion_model in self.models: - if motion_model.is_pia(): + if motion_model.is_pia() or motion_model.is_fancyvideo(): pia_motion_models.append(motion_model) return pia_motion_models diff --git a/animatediff/motion_module_ad.py b/animatediff/motion_module_ad.py index ab46790..407c39f 100644 --- a/animatediff/motion_module_ad.py +++ b/animatediff/motion_module_ad.py @@ -11,6 +11,7 @@ from comfy.ldm.modules.attention import FeedForward, SpatialTransformer from comfy.model_patcher import ModelPatcher from comfy.model_base import BaseModel +from comfy.ldm.modules.diffusionmodules.util import timestep_embedding from comfy.ldm.modules.diffusionmodules import openaimodel from comfy.ldm.modules.diffusionmodules.openaimodel import SpatialTransformer from comfy.controlnet import broadcast_image_to @@ -22,6 +23,7 @@ from .adapter_animatelcm_i2v import AdapterEmbed if TYPE_CHECKING: # avoids circular import from .adapter_cameractrl import CameraPoseEncoder +from .adapter_fancyvideo import FancyVideoCondEmbedding, FancyVideoKeys, initialize_weights_to_zero from .utils_motion 
import (CrossAttentionMM, MotionCompatibilityError, DummyNNModule, extend_to_batch_size, extend_list_to_batch_size, prepare_mask_batch, get_combined_multival) from .utils_model import BetaSchedules, ModelTypeSD @@ -135,7 +137,7 @@ def is_animatelcm(mm_state_dict: dict[str, Tensor]) -> bool: return True -def is_pia(mm_state_dict: dict[str, Tensor]) -> bool: +def has_conv_in(mm_state_dict: dict[str, Tensor]) -> bool: # check if conv_in.weight and .bias are present if "conv_in.weight" in mm_state_dict and "conv_in.bias" in mm_state_dict: return True @@ -197,6 +199,20 @@ def has_img_encoder(mm_state_dict: dict[str, Tensor]): return False +def has_fps_embedding(mm_state_dict: dict[str, Tensor]): + for key in mm_state_dict.keys(): + if key.startswith("fps_embedding."): + return True + return False + + +def has_motion_embedding(mm_state_dict: dict[str, Tensor]): + for key in mm_state_dict.keys(): + if key.startswith("motion_embedding."): + return True + return False + + def normalize_ad_state_dict(mm_state_dict: dict[str, Tensor], mm_name: str) -> Tuple[dict[str, Tensor], AnimateDiffInfo]: # from pathlib import Path # log_name = mm_name.split('\\')[-1] @@ -227,7 +243,7 @@ def normalize_ad_state_dict(mm_state_dict: dict[str, Tensor], mm_name: str) -> T mm_format = AnimateDiffFormat.HOTSHOTXL if is_animatelcm(mm_state_dict): mm_format = AnimateDiffFormat.ANIMATELCM - if is_pia(mm_state_dict): + if has_conv_in(mm_state_dict): mm_format = AnimateDiffFormat.PIA if is_fancyvideo(mm_state_dict): mm_format = AnimateDiffFormat.FANCYVIDEO @@ -241,7 +257,7 @@ def normalize_ad_state_dict(mm_state_dict: dict[str, Tensor], mm_name: str) -> T continue if mm_format == AnimateDiffFormat.PIA and key.startswith("conv_in."): continue - if mm_format == AnimateDiffFormat.FANCYVIDEO: + if mm_format == AnimateDiffFormat.FANCYVIDEO and key in FancyVideoKeys: continue del mm_state_dict[key] # determine the model's version @@ -319,11 +335,18 @@ def __init__(self, mm_state_dict: dict[str, Tensor], mm_info: AnimateDiffInfo): self.init_img_encoder() # CameraCtrl stuff self.camera_encoder: 'CameraPoseEncoder' = None - # PIA stuff - create conv_in if keys are present for it + # PIA/FancyVideo stuff - create conv_in if keys are present for it self.conv_in: comfy.ops.disable_weight_init.Conv2d = None self.orig_conv_in: comfy.ops.disable_weight_init.Conv2d = None - if is_pia(mm_state_dict): + if has_conv_in(mm_state_dict): self.init_conv_in(mm_state_dict) + # FancyVideo fps_embedding and motion_embedding + self.fps_embedding: FancyVideoCondEmbedding = None + self.motion_embedding: FancyVideoCondEmbedding = None + if has_fps_embedding(mm_state_dict): + self.init_fps_embedding(mm_state_dict) + if has_motion_embedding(mm_state_dict): + self.init_motion_embedding(mm_state_dict) def init_img_encoder(self): del self.img_encoder @@ -335,7 +358,7 @@ def set_camera_encoder(self, camera_encoder: 'CameraPoseEncoder'): def init_conv_in(self, mm_state_dict: dict[str, Tensor]): ''' - Used for PIA + Used for PIA/FancyVideo ''' del self.conv_in # hardcoded values, for now @@ -347,6 +370,54 @@ def init_conv_in(self, mm_state_dict: dict[str, Tensor]): self.conv_in = self.ops.conv_nd(2, in_channels, model_channels, 3, padding=1, dtype=comfy.model_management.unet_dtype(), device=comfy.model_management.unet_offload_device()) + def init_fps_embedding(self, mm_state_dict: dict[str, Tensor]): + ''' + Used for FancyVideo + ''' + del self.fps_embedding + in_channels = mm_state_dict["fps_embedding.linear.weight"].size(1) # expected to be 320 + 
cond_embed_dim = mm_state_dict["fps_embedding.linear.weight"].size(0) # expected to be 1280 + self.fps_embedding = FancyVideoCondEmbedding(in_channels=in_channels, cond_embed_dim=cond_embed_dim) + self.fps_embedding.apply(initialize_weights_to_zero) + + def init_motion_embedding(self, mm_state_dict: dict[str, Tensor]): + ''' + Used for FancyVideo + ''' + del self.motion_embedding + in_channels = mm_state_dict["motion_embedding.linear.weight"].size(1) # expected to be 320 + cond_embed_dim = mm_state_dict["motion_embedding.linear.weight"].size(0) # expected to be 1280 + self.motion_embedding = FancyVideoCondEmbedding(in_channels=in_channels, cond_embed_dim=cond_embed_dim) + self.motion_embedding.apply(initialize_weights_to_zero) + + def get_fancyvideo_emb_patches(self, dtype, device, fps=16, motion_score=1.0): + patches = [] + if self.fps_embedding is not None: + if fps is not None: + def fps_emb_patch(x: Tensor, emb: Tensor, model_channels: int, transformer_options: dict[str]): + nonlocal fps + if fps is None: + return emb + fps = torch.tensor(fps).to(dtype=emb.dtype, device=emb.device) + fps = fps.expand(x.shape[0]) + fps_emb = timestep_embedding(fps, model_channels, repeat_only=False).to(dtype=emb.dtype) + fps_emb = self.fps_embedding(fps_emb) + return emb + fps_emb + patches.append(fps_emb_patch) + if self.motion_embedding is not None: + if motion_score is not None: + def motion_emb_patch(x: Tensor, emb: Tensor, model_channels: int, transformer_options: dict[str]): + nonlocal motion_score + if motion_score is None: + return emb + motion_score = torch.tensor(motion_score).to(dtype=emb.dtype, device=emb.device) + motion_score = motion_score.expand(x.shape[0]) + motion_emb = timestep_embedding(motion_score, model_channels, repeat_only=False).to(dtype=emb.dtype) + motion_emb = self.motion_embedding(motion_emb) + return emb + motion_emb + patches.append(motion_emb_patch) + return patches + def get_device_debug(self): return self.down_blocks[0].motion_modules[0].temporal_transformer.proj_in.weight.device @@ -456,7 +527,7 @@ def _eject(self, unet_blocks: nn.ModuleList): for idx in sorted(idx_to_pop, reverse=True): block.pop(idx) - def inject_unet_conv_in_pia(self, model: BaseModel): + def inject_unet_conv_in_pia_fancyvideo(self, model: BaseModel): if self.conv_in is None: return # TODO: make sure works with lowvram @@ -480,7 +551,7 @@ def inject_unet_conv_in_pia(self, model: BaseModel): # now can apply combined_conv_in to unet block model.diffusion_model.input_blocks[0][0] = combined_conv_in - def restore_unet_conv_in_pia(self, model: BaseModel): + def restore_unet_conv_in_pia_fancyvideo(self, model: BaseModel): if self.orig_conv_in is not None: model.diffusion_model.input_blocks[0][0] = self.orig_conv_in.to(model.diffusion_model.input_blocks[0][0].weight.device) self.orig_conv_in = None diff --git a/animatediff/motion_module_fv.py b/animatediff/motion_module_fv.py deleted file mode 100644 index e69de29..0000000 diff --git a/animatediff/nodes.py b/animatediff/nodes.py index 11cbd13..be319d6 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -11,6 +11,7 @@ CameraCtrlPoseBasic, CameraCtrlPoseCombo, CameraCtrlPoseAdvanced, CameraCtrlManualAppendPose, CameraCtrlReplaceCameraParameters, CameraCtrlSetOriginalAspectRatio) from .nodes_pia import (ApplyAnimateDiffPIAModel, LoadAnimateDiffAndInjectPIANode, InputPIA_MultivalNode, InputPIA_PaperPresetsNode, PIA_ADKeyframeNode) +from .nodes_fancyvideo import (ApplyAnimateDiffFancyVideo,) from .nodes_multival import MultivalDynamicNode, 
MultivalScaledMaskNode, MultivalDynamicFloatInputNode, MultivalDynamicFloatsNode, MultivalConvertToMaskNode from .nodes_conditioning import (MaskableLoraLoader, MaskableLoraLoaderModelOnly, MaskableSDModelLoader, MaskableSDModelLoaderModelOnly, SetModelLoraHook, SetClipLoraHook, @@ -213,6 +214,8 @@ "ADE_InputPIA_PaperPresets": InputPIA_PaperPresetsNode, "ADE_PIA_AnimateDiffKeyframe": PIA_ADKeyframeNode, "ADE_InjectPIAIntoAnimateDiffModel": LoadAnimateDiffAndInjectPIANode, + # FancyVideo + ApplyAnimateDiffFancyVideo.NodeID: ApplyAnimateDiffFancyVideo, # Deprecated Nodes "AnimateDiffLoaderV1": AnimateDiffLoader_Deprecated, "ADE_AnimateDiffLoaderV1Advanced": AnimateDiffLoaderAdvanced_Deprecated, @@ -381,6 +384,8 @@ "ADE_InputPIA_PaperPresets": "PIA Input [Paper Presets] πŸŽ­πŸ…πŸ…“β‘‘", "ADE_PIA_AnimateDiffKeyframe": "AnimateDiff-PIA Keyframe πŸŽ­πŸ…πŸ…“", "ADE_InjectPIAIntoAnimateDiffModel": "πŸ§ͺInject PIA into AnimateDiff Model πŸŽ­πŸ…πŸ…“β‘‘", + # FancyVideo + ApplyAnimateDiffFancyVideo.NodeID: ApplyAnimateDiffFancyVideo.NodeName, # Deprecated Nodes "AnimateDiffLoaderV1": "🚫AnimateDiff Loader [DEPRECATED] πŸŽ­πŸ…πŸ…“", "ADE_AnimateDiffLoaderV1Advanced": "🚫AnimateDiff Loader (Advanced) [DEPRECATED] πŸŽ­πŸ…πŸ…“", diff --git a/animatediff/nodes_fancyvideo.py b/animatediff/nodes_fancyvideo.py new file mode 100644 index 0000000..451b698 --- /dev/null +++ b/animatediff/nodes_fancyvideo.py @@ -0,0 +1,62 @@ +from typing import Union +import torch +from torch import Tensor +import math + +from comfy.sd import VAE + +from .ad_settings import AnimateDiffSettings +from .logger import logger +from .utils_model import BIGMIN, BIGMAX, get_available_motion_models +from .utils_motion import ADKeyframeGroup, InputPIA, InputPIA_Multival, extend_list_to_batch_size, extend_to_batch_size, prepare_mask_batch +from .motion_lora import MotionLoraList +from .model_injection import MotionModelGroup, MotionModelPatcher, load_motion_module_gen2, inject_pia_conv_in_into_model +from .motion_module_ad import AnimateDiffFormat +from .nodes_gen2 import ApplyAnimateDiffModelNode, ADKeyframeNode + + +class ApplyAnimateDiffFancyVideo: + NodeID = 'ADE_ApplyAnimateDiffFancyVideo' + NodeName = 'Apply AD-FancyVideo Model πŸŽ­πŸ…πŸ…“' + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "motion_model": ("MOTION_MODEL_ADE",), + "image": ("IMAGE",), + "vae": ("VAE",), + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + }, + "optional": { + "motion_lora": ("MOTION_LORA",), + "scale_multival": ("MULTIVAL",), + "effect_multival": ("MULTIVAL",), + "ad_keyframes": ("AD_KEYFRAMES",), + "prev_m_models": ("M_MODELS",), + "per_block": ("PER_BLOCK",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("M_MODELS",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘‘ Gen2 nodes β‘‘/FancyVideo" + FUNCTION = "apply_motion_model" + + def apply_motion_model(self, motion_model: MotionModelPatcher, image: Tensor, vae: VAE, + start_percent: float=0.0, end_percent: float=1.0, + motion_lora: MotionLoraList=None, ad_keyframes: ADKeyframeGroup=None, + scale_multival=None, effect_multival=None, ref_multival=None, per_block=None, + prev_m_models: MotionModelGroup=None,): + new_m_models = ApplyAnimateDiffModelNode.apply_motion_model(self, motion_model, start_percent=start_percent, end_percent=end_percent, + motion_lora=motion_lora, ad_keyframes=ad_keyframes, + scale_multival=scale_multival, 
effect_multival=effect_multival, per_block=per_block, + prev_m_models=prev_m_models) + # most recent added model will always be first in list; + curr_model = new_m_models[0].models[0] + # confirm that model is FancyVideo + if curr_model.model.mm_info.mm_format != AnimateDiffFormat.FANCYVIDEO: + raise Exception(f"Motion model '{curr_model.model.mm_info.mm_name}' is not a FancyVideo model; cannot be used with Apply AD-FancyModel Model node.") + curr_model.orig_fancy_images = image + curr_model.fancy_vae = vae + return new_m_models diff --git a/animatediff/nodes_gen2.py b/animatediff/nodes_gen2.py index 979c10f..1098c77 100644 --- a/animatediff/nodes_gen2.py +++ b/animatediff/nodes_gen2.py @@ -174,6 +174,7 @@ def INPUT_TYPES(s): }, "optional": { "ad_settings": ("AD_SETTINGS",), + "autosize": ("ADEAUTOSIZE", {"padding": 50}), } } diff --git a/animatediff/nodes_sigma_schedule.py b/animatediff/nodes_sigma_schedule.py index c6667f2..e09e217 100644 --- a/animatediff/nodes_sigma_schedule.py +++ b/animatediff/nodes_sigma_schedule.py @@ -45,6 +45,9 @@ def INPUT_TYPES(s): "sampling": (ModelSamplingType._FULL_LIST,), "lcm_original_timesteps": ("INT", {"default": 50, "min": 1, "max": 1000}), "zsnr": ("BOOLEAN", {"default": False}), + }, + "optional": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } diff --git a/animatediff/sample_settings.py b/animatediff/sample_settings.py index 593c21b..898414c 100644 --- a/animatediff/sample_settings.py +++ b/animatediff/sample_settings.py @@ -526,12 +526,12 @@ def _perform_calibration_custom(self, sample_func: Callable, model: ModelPatcher sigmas = sigmas[:2] args[2] = sigmas # divide by scale factor - sigma = sigmas[0] + sigma = sigmas[0] / (model.model.latent_format.scale_factor) alpha_cumprod = 1 / ((sigma * sigma) + 1) sqrt_alpha_prod = alpha_cumprod ** 0.5 sqrt_one_minus_alpha_prod = (1 - alpha_cumprod) ** 0.5 zero_noise = torch.zeros_like(noise) - new_latents = latents + new_latents = latents# / (model.model.latent_format.scale_factor) #new_latents = latents * (model.model.latent_format.scale_factor) for _ in range(self.calib_iterations): # TODO: do i need to use DDIM noising, or will ComfyUI's work? 
@@ -540,14 +540,14 @@ def _perform_calibration_custom(self, sample_func: Callable, model: ModelPatcher #x = latents + noise * sigma #torch.sqrt(1.0 + sigma ** 2.0) # replace latents in args with x args[-1] = x - e_t_theta = sample_func(model, zero_noise, *args, **kwargs) + e_t_theta = sample_func(model, zero_noise, *args, **kwargs) * (model.model.latent_format.scale_factor) x_0_t = (x - sqrt_one_minus_alpha_prod * e_t_theta) / sqrt_alpha_prod freq_delta = (self.get_low_or_high_fft(x_0_t, self.scale, is_low=False) - self.get_low_or_high_fft(new_latents, self.scale, is_low=False)) noise = e_t_theta + sqrt_alpha_prod / sqrt_one_minus_alpha_prod * freq_delta #return latents, noise #x = latents * sqrt_alpha_prod + noise * sqrt_one_minus_alpha_prod #return zero_noise, x #noise * (model.model.latent_format.scale_factor) - return latents, noise * (model.model.latent_format.scale_factor) + return latents, noise# * (model.model.latent_format.scale_factor) def _perform_calibration_not_custom(self, sample_func: Callable, model: ModelPatcher, latents: Tensor, noise: Tensor, args: list, kwargs: dict): return latents, noise diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 7f3ff4c..88295c4 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -75,23 +75,38 @@ def prepare_hooks_current_keyframes(self, timestep: Tensor, hook_groups: list[Lo if self.model_patcher is not None: self.model_patcher.prepare_hooked_patches_current_keyframe(t=timestep, hook_groups=hook_groups) - def perform_special_model_features(self, model: BaseModel, conds: list, x_in: Tensor): + def perform_special_model_features(self, model: BaseModel, conds: list, x_in: Tensor, model_options: dict[str]): if self.motion_models is not None: - pia_models = self.motion_models.get_pia_models() - if len(pia_models) > 0: - for pia_model in pia_models: - if pia_model.model.is_in_effect(): - pia_model.model.inject_unet_conv_in_pia(model) - conds = get_conds_with_c_concat(conds, - pia_model.get_pia_c_concat(model, x_in)) + special_models = self.motion_models.get_special_models() + if len(special_models) > 0: + for special_model in special_models: + if special_model.model.is_in_effect(): + if special_model.is_pia(): + special_model.model.inject_unet_conv_in_pia_fancyvideo(model) + conds = get_conds_with_c_concat(conds, + special_model.get_pia_c_concat(model, x_in)) + elif special_model.is_fancyvideo(): + # TODO: handle other weights + special_model.model.inject_unet_conv_in_pia_fancyvideo(model) + conds = get_conds_with_c_concat(conds, + special_model.get_fancy_c_concat(model, x_in)) + # add fps_embedding/motion_embedding patches + emb_patches = special_model.model.get_fancyvideo_emb_patches(dtype=x_in.dtype, device=x_in.device) + transformer_patches = model_options["transformer_options"].get("patches", {}) + transformer_patches["emb_patch"] = emb_patches + model_options["transformer_options"]["patches"] = transformer_patches return conds def restore_special_model_features(self, model: BaseModel): if self.motion_models is not None: - pia_models = self.motion_models.get_pia_models() - if len(pia_models) > 0: - for pia_model in reversed(pia_models): - pia_model.model.restore_unet_conv_in_pia(model) + special_models = self.motion_models.get_special_models() + if len(special_models) > 0: + for special_model in reversed(special_models): + if special_model.is_pia(): + special_model.model.restore_unet_conv_in_pia_fancyvideo(model) + elif special_model.is_fancyvideo(): + # TODO: fill out + 
special_model.model.restore_unet_conv_in_pia_fancyvideo(model) def reset(self): self.initialized = False @@ -215,7 +230,7 @@ def apply_model_ade_wrapper(self, *args, **kwargs): ad_params = kwargs["transformer_options"]["ad_params"] if ADGS.motion_models is not None: for motion_model in ADGS.motion_models.models: - motion_model.prepare_img_features(x=x, cond_or_uncond=cond_or_uncond, ad_params=ad_params, latent_format=self.latent_format) + motion_model.prepare_alcmi2v_features(x=x, cond_or_uncond=cond_or_uncond, ad_params=ad_params, latent_format=self.latent_format) motion_model.prepare_camera_features(x=x, cond_or_uncond=cond_or_uncond, ad_params=ad_params) del x return orig_apply_model(*args, **kwargs) @@ -612,7 +627,15 @@ def evolved_sampling_function(model, x: Tensor, timestep: Tensor, uncond, cond, ADGS.initialize(model) ADGS.prepare_current_keyframes(x=x, timestep=timestep) try: - cond, uncond = ADGS.perform_special_model_features(model, [cond, uncond], x) + # add AD/evolved-sampling params to model_options (transformer_options) + model_options = model_options.copy() + if "transformer_options" not in model_options: + model_options["transformer_options"] = {} + else: + model_options["transformer_options"] = model_options["transformer_options"].copy() + model_options["transformer_options"]["ad_params"] = ADGS.create_exposed_params() + + cond, uncond = ADGS.perform_special_model_features(model, [cond, uncond], x, model_options) # only use cfg1_optimization if not using custom_cfg or explicitly set to 1.0 uncond_ = uncond @@ -622,15 +645,7 @@ def evolved_sampling_function(model, x: Tensor, timestep: Tensor, uncond, cond, cfg_multival = ADGS.sample_settings.custom_cfg.cfg_multival if type(cfg_multival) != Tensor and math.isclose(cfg_multival, 1.0) and model_options.get("disable_cfg1_optimization", False) == False: uncond_ = None - del cfg_multival - - # add AD/evolved-sampling params to model_options (transformer_options) - model_options = model_options.copy() - if "transformer_options" not in model_options: - model_options["transformer_options"] = {} - else: - model_options["transformer_options"] = model_options["transformer_options"].copy() - model_options["transformer_options"]["ad_params"] = ADGS.create_exposed_params() + del cfg_multival if not ADGS.is_using_sliding_context(): cond_pred, uncond_pred = calc_conds_batch_wrapper(model, [cond, uncond_], x, timestep, model_options) From a77460cfafb440965f8a2accff955e17b2c2862c Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 7 Sep 2024 03:06:53 -0500 Subject: [PATCH 07/43] Updated emb_patch, some upgrades for ModelSamplingConfig --- animatediff/motion_module_ad.py | 10 +++++----- animatediff/nodes_sigma_schedule.py | 7 ++++++- animatediff/utils_model.py | 20 ++++++++++++++++---- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/animatediff/motion_module_ad.py b/animatediff/motion_module_ad.py index 407c39f..dc1dfec 100644 --- a/animatediff/motion_module_ad.py +++ b/animatediff/motion_module_ad.py @@ -390,28 +390,28 @@ def init_motion_embedding(self, mm_state_dict: dict[str, Tensor]): self.motion_embedding = FancyVideoCondEmbedding(in_channels=in_channels, cond_embed_dim=cond_embed_dim) self.motion_embedding.apply(initialize_weights_to_zero) - def get_fancyvideo_emb_patches(self, dtype, device, fps=16, motion_score=1.0): + def get_fancyvideo_emb_patches(self, dtype, device, fps=25, motion_score=3.0): patches = [] if self.fps_embedding is not None: if fps is not None: - def fps_emb_patch(x: Tensor, emb: Tensor, 
model_channels: int, transformer_options: dict[str]): + def fps_emb_patch(emb: Tensor, model_channels: int, transformer_options: dict[str]): nonlocal fps if fps is None: return emb fps = torch.tensor(fps).to(dtype=emb.dtype, device=emb.device) - fps = fps.expand(x.shape[0]) + fps = fps.expand(emb.shape[0]) fps_emb = timestep_embedding(fps, model_channels, repeat_only=False).to(dtype=emb.dtype) fps_emb = self.fps_embedding(fps_emb) return emb + fps_emb patches.append(fps_emb_patch) if self.motion_embedding is not None: if motion_score is not None: - def motion_emb_patch(x: Tensor, emb: Tensor, model_channels: int, transformer_options: dict[str]): + def motion_emb_patch(emb: Tensor, model_channels: int, transformer_options: dict[str]): nonlocal motion_score if motion_score is None: return emb motion_score = torch.tensor(motion_score).to(dtype=emb.dtype, device=emb.device) - motion_score = motion_score.expand(x.shape[0]) + motion_score = motion_score.expand(emb.shape[0]) motion_emb = timestep_embedding(motion_score, model_channels, repeat_only=False).to(dtype=emb.dtype) motion_emb = self.motion_embedding(motion_emb) return emb + motion_emb diff --git a/animatediff/nodes_sigma_schedule.py b/animatediff/nodes_sigma_schedule.py index e09e217..369c3f1 100644 --- a/animatediff/nodes_sigma_schedule.py +++ b/animatediff/nodes_sigma_schedule.py @@ -59,7 +59,12 @@ def get_sigma_schedule(self, raw_beta_schedule: str, linear_start: float, linear sampling: str, lcm_original_timesteps: int, zsnr: bool, lcm_zsnr: bool=None): if lcm_zsnr is not None: zsnr = lcm_zsnr - new_config = ModelSamplingConfig(beta_schedule=raw_beta_schedule, linear_start=linear_start, linear_end=linear_end) + # from pathlib import Path + # log_name = 'enforce_zero_terminal_snr_betas' + # betas_file = Path(__file__).parent.parent / rf"{log_name}.pt" + # given_betas = torch.load(betas_file, weights_only=True) + # given_betas[-1] = 0.0 + new_config = ModelSamplingConfig(beta_schedule=raw_beta_schedule, linear_start=linear_start, linear_end=linear_end)#, given_betas=given_betas) if sampling != ModelSamplingType.LCM: lcm_original_timesteps=None model_type = ModelSamplingType.from_alias(sampling) diff --git a/animatediff/utils_model.py b/animatediff/utils_model.py index 930cdc9..009aff8 100644 --- a/animatediff/utils_model.py +++ b/animatediff/utils_model.py @@ -75,13 +75,16 @@ def vae_decode_raw_batched(vae: VAE, latents: Tensor, per_batch=16, show_pbar=Fa class ModelSamplingConfig: - def __init__(self, beta_schedule: str, linear_start: float=None, linear_end: float=None): + def __init__(self, beta_schedule: str, linear_start: float=None, linear_end: float=None, given_betas: Tensor=None, timesteps: int=None): self.sampling_settings = {"beta_schedule": beta_schedule} if linear_start is not None: self.sampling_settings["linear_start"] = linear_start if linear_end is not None: self.sampling_settings["linear_end"] = linear_end - self.beta_schedule = beta_schedule # keeping this for backwards compatibility + if given_betas is not None: + self.sampling_settings["given_betas"] = given_betas + if timesteps is not None: + self.sampling_settings["timesteps"] = timesteps class ModelSamplingType: @@ -112,7 +115,7 @@ def __init__(self, *args, **kwargs): # based on code in comfy_extras/nodes_model_advanced.py -def evolved_model_sampling(model_config: ModelSamplingConfig, model_type: ModelType, alias: str, original_timesteps: int=None): +def evolved_model_sampling(model_config: ModelSamplingConfig, model_type: ModelType, alias: str, original_timesteps: 
Union[int, None]=None): # if LCM, need to handle manually if BetaSchedules.is_lcm(alias) or original_timesteps is not None: sampling_type = comfy_extras.nodes_model_advanced.LCM @@ -129,7 +132,16 @@ class ModelSamplingAdvancedEvolved(sampling_base, sampling_type): # NOTE: if I want to support zsnr, this is where I would add that code return ModelSamplingAdvancedEvolved(model_config) # otherwise, use vanilla model_sampling function - return model_sampling(model_config, model_type) + ms = model_sampling(model_config, model_type) + if "given_betas" in model_config.sampling_settings: + beta_schedule = model_config.sampling_settings.get("beta_schedule", "linear") + linear_start = model_config.sampling_settings.get("linear_start", 0.00085) + linear_end = model_config.sampling_settings.get("linear_end", 0.012) + timesteps = model_config.sampling_settings.get("timesteps", 1000) + given_betas = model_config.sampling_settings.get("given_betas", None) + ms._register_schedule(given_betas=given_betas, beta_schedule=beta_schedule, + timesteps=timesteps, linear_start=linear_start, linear_end=linear_end) + return ms class BetaSchedules: From e6a9638cef6a317a2bcc087cabadfbcbd5dad801 Mon Sep 17 00:00:00 2001 From: Kosinkadink Date: Sat, 21 Sep 2024 16:33:05 +0900 Subject: [PATCH 08/43] Started porting over code to use upcoming ModelPatcher features in ComfyUI to get rid of a bunch of hacky code (and to enable hookable patches) --- animatediff/model_injection.py | 790 +++--------------------------- animatediff/nodes.py | 4 +- animatediff/nodes_conditioning.py | 3 +- animatediff/nodes_context.py | 7 +- animatediff/nodes_deprecated.py | 2 +- animatediff/nodes_gen1.py | 2 +- animatediff/nodes_gen2.py | 74 ++- animatediff/sampling.py | 238 +++++++-- 8 files changed, 358 insertions(+), 762 deletions(-) diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index 19c7fda..1857b78 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -13,7 +13,7 @@ import comfy.lora import comfy.model_management import comfy.utils -from comfy.model_patcher import ModelPatcher +from comfy.model_patcher import ModelPatcher, PatcherInjection, WrappersMP from comfy.model_base import BaseModel from comfy.sd import CLIP, VAE @@ -32,729 +32,86 @@ from .sample_settings import SampleSettings, SeedNoiseGeneration -# some motion_model casts here might fail if model becomes metatensor or is not castable; -# should not really matter if it fails, so ignore raised Exceptions -class ModelPatcherAndInjector(ModelPatcher): - def __init__(self, m: ModelPatcher): - # replicate ModelPatcher.clone() to initialize ModelPatcherAndInjector - super().__init__(m.model, m.load_device, m.offload_device, m.size, weight_inplace_update=m.weight_inplace_update) - self.patches = {} - for k in m.patches: - self.patches[k] = m.patches[k][:] - if hasattr(m, "patches_uuid"): - self.patches_uuid = m.patches_uuid - - self.object_patches = m.object_patches.copy() - self.model_options = copy.deepcopy(m.model_options) - if hasattr(m, "model_keys"): - self.model_keys = m.model_keys - if hasattr(m, "backup"): - self.backup = m.backup - if hasattr(m, "object_patches_backup"): - self.object_patches_backup = m.object_patches_backup - - # lora hook stuff - self.hooked_patches: dict[HookRef] = {} # binds LoraHook to specific keys - self.hooked_backup: dict[str, tuple[Tensor, torch.device]] = {} - self.cached_hooked_patches: dict[LoraHookGroup, dict[str, Tensor]] = {} # binds LoraHookGroup to pre-calculated weights (speed 
optimization) - self.current_lora_hooks = None - self.lora_hook_mode = LoraHookMode.MAX_SPEED - self.model_params_lowvram = False - self.model_params_lowvram_keys = {} # keeps track of keys with applied 'weight_function' or 'bias_function' - # injection stuff - self.currently_injected = False - self.skip_injection = False - self.motion_injection_params: InjectionParams = InjectionParams() - self.sample_settings: SampleSettings = SampleSettings() - self.motion_models: MotionModelGroup = None - # backwards-compatible calculate_weight - if hasattr(comfy.lora, "calculate_weight"): - self.do_calculate_weight = comfy.lora.calculate_weight - else: - self.do_calculate_weight = self.calculate_weight +class ModelPatcherHelper: + MOTION_MODELS = "ADE_motion_models" + SAMPLE_SETTINGS = "ADE_sample_settings" + PARAMS = "ADE_params" - - def clone(self, hooks_only=False): - cloned = ModelPatcherAndInjector(self) - # copy lora hooks - for hook_ref in self.hooked_patches: - cloned.hooked_patches[hook_ref] = {} - for k in self.hooked_patches[hook_ref]: - cloned.hooked_patches[hook_ref][k] = self.hooked_patches[hook_ref][k][:] - # copy pre-calc weights bound to LoraHookGroups - for group in self.cached_hooked_patches: - cloned.cached_hooked_patches[group] = {} - for k in self.cached_hooked_patches[group]: - cloned.cached_hooked_patches[group][k] = self.cached_hooked_patches[group][k] - cloned.hooked_backup = self.hooked_backup - cloned.current_lora_hooks = self.current_lora_hooks - cloned.currently_injected = self.currently_injected - cloned.lora_hook_mode = self.lora_hook_mode - if not hooks_only: - cloned.motion_models = self.motion_models.clone() if self.motion_models else self.motion_models - cloned.sample_settings = self.sample_settings - cloned.motion_injection_params = self.motion_injection_params.clone() if self.motion_injection_params else self.motion_injection_params - return cloned - - @classmethod - def create_from(cls, model: Union[ModelPatcher, 'ModelPatcherAndInjector'], hooks_only=False) -> 'ModelPatcherAndInjector': - if isinstance(model, ModelPatcherAndInjector): - return model.clone(hooks_only=hooks_only) - else: - return ModelPatcherAndInjector(model) - - def clone_has_same_weights(self, clone: 'ModelPatcherCLIPHooks'): - returned = super().clone_has_same_weights(clone) - if not returned: - return returned - # currently, hook patches require that model gets loaded when sampled, so always say is not a clone if hooks present - if len(self.hooked_patches) > 0: - return False - if type(self) != type(clone): - return False - if self.current_lora_hooks != clone.current_lora_hooks: - return False - if self.hooked_patches.keys() != clone.hooked_patches.keys(): - return False - return returned - - def set_lora_hook_mode(self, lora_hook_mode: str): - self.lora_hook_mode = lora_hook_mode - - def prepare_hooked_patches_current_keyframe(self, t: Tensor, hook_groups: list[LoraHookGroup]): - curr_t = t[0] - for hook_group in hook_groups: - for hook in hook_group.hooks: - changed = hook.lora_keyframe.prepare_current_keyframe(curr_t=curr_t) - # if keyframe changed, remove any cached LoraHookGroups that contain hook with the same hook_ref; - # this will cause the weights to be recalculated when sampling - if changed: - # reset current_lora_hooks if contains lora hook that changed - if self.current_lora_hooks is not None: - for current_hook in self.current_lora_hooks.hooks: - if current_hook == hook: - self.current_lora_hooks = None - break - for cached_group in list(self.cached_hooked_patches.keys()): - if 
cached_group.contains(hook): - self.cached_hooked_patches.pop(cached_group) - - def clean_hooks(self): - self.unpatch_hooked() - self.clear_cached_hooked_weights() - # for lora_hook in self.hooked_patches: - # lora_hook.reset() - - def add_hooked_patches(self, lora_hook: LoraHook, patches, strength_patch=1.0, strength_model=1.0): - ''' - Based on add_patches, but for hooked weights. - ''' - current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {}) - p = set() - model_sd = self.model.state_dict() - for k in patches: - offset = None - function = None - if isinstance(k, str): - key = k - else: - offset = k[1] - key = k[0] - if len(k) > 2: - function = k[2] - - if key in model_sd: - p.add(k) - current_patches: list[tuple] = current_hooked_patches.get(key, []) - current_patches.append((strength_patch, patches[k], strength_model, offset, function)) - current_hooked_patches[key] = current_patches - self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches - # since should care about these patches too to determine if same model, reroll patches_uuid - self.patches_uuid = uuid.uuid4() - return list(p) - - def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches: dict, strength_patch=1.0, strength_model=1.0): - ''' - Based on add_hooked_patches, but intended for using a model's weights as lora hook. - ''' - current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {}) - p = set() - model_sd = self.model.state_dict() - for k in patches: - offset = None - function = None - if isinstance(k, str): - key = k - else: - offset = k[1] - key = k[0] - if len(k) > 2: - function = k[2] - - if key in model_sd: - p.add(k) - current_patches: list[tuple] = current_hooked_patches.get(key, []) - # take difference between desired weight and existing weight to get diff - # TODO: create fix for fp8 - current_patches.append((strength_patch, (patches[k]-comfy.utils.get_attr(self.model, key),), strength_model, offset, function)) - current_hooked_patches[key] = current_patches - self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches - # since should care about these patches too to determine if same model, reroll patches_uuid - self.patches_uuid = uuid.uuid4() - return list(p) - - def get_combined_hooked_patches(self, lora_hooks: LoraHookGroup): - ''' - Returns patches for selected lora_hooks. 
- ''' - # combined_patches will contain weights of all relevant lora_hooks, per key - combined_patches = {} - if lora_hooks is not None: - for hook in lora_hooks.hooks: - hook_patches: dict = self.hooked_patches.get(hook.hook_ref, {}) - for key in hook_patches.keys(): - current_patches: list[tuple] = combined_patches.get(key, []) - if math.isclose(hook.strength, 1.0): - # if hook strength is 1.0, can just add it directly - current_patches.extend(hook_patches[key]) - else: - # otherwise, need to multiply original patch strength by hook strength - # patches are stored as tuples: (strength_patch, (tuple_with_weights,), strength_model) - for patch in hook_patches[key]: - new_patch = list(patch) - new_patch[0] *= hook.strength - current_patches.append(tuple(new_patch)) - combined_patches[key] = current_patches - return combined_patches - - def patch_model(self, *args, **kwargs): - was_injected = False - if self.currently_injected: - self.eject_model() - was_injected = True - # first, perform model patching - patched_model = super().patch_model(*args, **kwargs) - # bring injection back to original state - if was_injected and not self.currently_injected: - self.inject_model() - return patched_model - - def load(self, device_to=None, lowvram_model_memory=0, *args, **kwargs): - self.eject_model() - try: - return super().load(device_to=device_to, lowvram_model_memory=lowvram_model_memory, *args, **kwargs) - finally: - self.inject_model() - if lowvram_model_memory > 0: - self._patch_lowvram_extras() - - def _patch_lowvram_extras(self): - # check if any modules have weight_function or bias_function that is not None - # NOTE: this serves no purpose currently, but I have it here for future reasons - self.model_params_lowvram = False - self.model_params_lowvram_keys.clear() - for n, m in self.model.named_modules(): - if not hasattr(m, "comfy_cast_weights"): - continue - if getattr(m, "weight_function", None) is not None: - self.model_params_lowvram = True - self.model_params_lowvram_keys[f"{n}.weight"] = n - if getattr(m, "bias_function", None) is not None: - self.model_params_lowvram = True - self.model_params_lowvram_keys[f"{n}.bias"] = n - - def unpatch_model(self, device_to=None, unpatch_weights=True): - # first, eject motion model from unet - self.eject_model() - # finally, do normal model unpatching - if unpatch_weights: - # handle hooked_patches first - self.clean_hooks() - try: - return super().unpatch_model(device_to, unpatch_weights) - finally: - self.model_params_lowvram = False - self.model_params_lowvram_keys.clear() + def __init__(self, model: ModelPatcher): + self.model = model - def partially_load(self, *args, **kwargs): - # partially_load calls patch_model, but we don't want to inject model in the intermediate call; - # make sure to eject before performing partial load, then inject - was_injected = self.currently_injected - try: - self.eject_model() - try: - self.skip_injection = True - to_return = super().partially_load(*args, **kwargs) - self.skip_injection = False - self.inject_model() - return to_return - finally: - self.skip_injection = False - finally: - if was_injected and not self.currently_injected: - self.inject_model() - - def partially_unload(self, *args, **kwargs): - if not self.currently_injected: - return super().partially_unload(*args, **kwargs) - # make sure to eject before performing partial unload, then inject again - self.eject_model() - try: - return super().partially_unload(*args, **kwargs) - finally: - self.inject_model() + def get_adgs(self): + pass - def 
inject_model(self): - if self.skip_injection: # make it possible to skip injection for intermediate calls (partial load) - return - if self.motion_models is not None: - for motion_model in self.motion_models.models: - self.currently_injected = True - motion_model.model.inject(self) - - def eject_model(self): - if self.motion_models is not None: - for motion_model in self.motion_models.models: - motion_model.model.eject(self) - self.currently_injected = False - - def apply_lora_hooks(self, lora_hooks: LoraHookGroup): - # first, determine if need to reapply patches - if self.current_lora_hooks == lora_hooks: - return - # patch hooks - self.patch_hooked(lora_hooks=lora_hooks) - - def patch_hooked(self, lora_hooks: LoraHookGroup) -> None: - # first, unpatch any previous patches - self.unpatch_hooked() - # eject model, if needed - was_injected = self.currently_injected - if was_injected: - self.eject_model() - - model_sd = self.model_state_dict() - # if have cached weights for lora_hooks, use it - cached_weights = self.cached_hooked_patches.get(lora_hooks, None) - if cached_weights is not None: - for key in cached_weights: - if key not in model_sd: - logger.warning(f"Cached LoraHook could not patch. key doesn't exist in model: {key}") - self.patch_cached_hooked_weight(cached_weights=cached_weights, key=key) - else: - # get combined patches of relevant lora_hooks - relevant_patches = self.get_combined_hooked_patches(lora_hooks=lora_hooks) - for key in relevant_patches: - if key not in model_sd: - logger.warning(f"LoraHook could not patch. key doesn't exist in model: {key}") - continue - self.patch_hooked_weight_to_device(lora_hooks=lora_hooks, combined_patches=relevant_patches, key=key) - self.current_lora_hooks = lora_hooks - # reinject model, if needed - if was_injected: - self.inject_model() - - def patch_cached_hooked_weight(self, cached_weights: dict, key: str): - # TODO: handle model_params_lowvram stuff if necessary - inplace_update = self.weight_inplace_update - if key not in self.hooked_backup: - weight: Tensor = comfy.utils.get_attr(self.model, key) - target_device = self.offload_device - if self.lora_hook_mode == LoraHookMode.MAX_SPEED: - target_device = weight.device - self.hooked_backup[key] = (weight.to(device=target_device, copy=inplace_update), weight.device) - if inplace_update: - comfy.utils.copy_to_param(self.model, key, cached_weights[key]) - else: - comfy.utils.set_attr_param(self.model, key, cached_weights[key]) + def get_motion_models(self) -> list['MotionModelPatcher']: + return self.model.additional_models.get(self.MOTION_MODELS, []) + + def set_motion_models(self, motion_models: list['MotionModelPatcher']): + self.model.set_additional_models(self.MOTION_MODELS, motion_models) + self.model.set_injections(self.MOTION_MODELS, + [PatcherInjection(inject=inject_motion_models, eject=eject_motion_models)]) + + def remove_motion_models(self): + if self.MOTION_MODELS in self.model.additional_models: + self.model.additional_models.pop(self.MOTION_MODELS) + self.model.injections.pop(self.MOTION_MODELS) + def cleanup_motion_models(self): + for motion_model in self.get_motion_models(): + motion_model.cleanup() - def clear_cached_hooked_weights(self): - self.cached_hooked_patches.clear() - self.current_lora_hooks = None + ########################## + # motion models helpers + def set_video_length(self, video_length: int, full_length: int): + for motion_model in self.get_motion_models(): + motion_model.model.set_video_length(video_length=video_length, full_length=full_length) + + def 
get_name_string(self, show_version=False): + identifiers = [] + for motion_model in self.get_motion_models(): + id = motion_model.model.mm_info.mm_name + if show_version: + id += f":{motion_model.model.mm_info.mm_version}" + identifiers.append(id) + return ", ".join(identifiers) + ########################## - def patch_hooked_weight_to_device(self, lora_hooks: LoraHookGroup, combined_patches: dict, key: str): - if key not in combined_patches: - return - inplace_update = self.weight_inplace_update - weight: Tensor = comfy.utils.get_attr(self.model, key) - if key not in self.hooked_backup: - target_device = self.offload_device - if self.lora_hook_mode == LoraHookMode.MAX_SPEED: - target_device = weight.device - self.hooked_backup[key] = (weight.to(device=target_device, copy=inplace_update), weight.device) - - # TODO: handle model_params_lowvram stuff if necessary - temp_weight = comfy.model_management.cast_to_device(weight, weight.device, torch.float32, copy=True) - out_weight = self.do_calculate_weight(combined_patches[key], temp_weight, key).to(weight.dtype) - if self.lora_hook_mode == LoraHookMode.MAX_SPEED: - self.cached_hooked_patches.setdefault(lora_hooks, {}) - self.cached_hooked_patches[lora_hooks][key] = out_weight - if inplace_update: - comfy.utils.copy_to_param(self.model, key, out_weight) - else: - comfy.utils.set_attr_param(self.model, key, out_weight) - - def patch_hooked_replace_weight_to_device(self, lora_hooks: LoraHookGroup, model_sd: dict, replace_patches: dict): - # first handle replace_patches - for key in replace_patches: - if key not in model_sd: - logger.warning(f"LoraHook could not replace patch. key doesn't exist in model: {key}") - continue - - inplace_update = self.weight_inplace_update - weight: Tensor = comfy.utils.get_attr(self.model, key) - if key not in self.hooked_backup: - # TODO: handle model_params_lowvram stuff if necessary - target_device = self.offload_device - if self.lora_hook_mode == LoraHookMode.MAX_SPEED: - target_device = weight.device - self.hooked_backup[key] = (weight.to(device=target_device, copy=inplace_update), weight.device) - - out_weight = replace_patches[key].to(weight.device) - if self.lora_hook_mode == LoraHookMode.MAX_SPEED: - self.cached_hooked_patches.setdefault(lora_hooks, {}) - self.cached_hooked_patches[lora_hooks][key] = out_weight - if inplace_update: - comfy.utils.copy_to_param(self.model, key, out_weight) - else: - comfy.utils.set_attr_param(self.model, key, out_weight) + def get_sample_settings(self) -> SampleSettings: + return self.model.attachments.get(self.SAMPLE_SETTINGS, None) + + def set_sample_settings(self, sample_settings: SampleSettings): + self.model.set_attachments(self.SAMPLE_SETTINGS, sample_settings) + - def unpatch_hooked(self) -> None: - # if no backups from before hook, then nothing to unpatch - if len(self.hooked_backup) == 0: - return - was_injected = self.currently_injected - if was_injected: - self.eject_model() - # TODO: handle model_params_lowvram stuff if necessary - keys = list(self.hooked_backup.keys()) - if self.weight_inplace_update: - for k in keys: - if self.lora_hook_mode == LoraHookMode.MAX_SPEED: # does not need to be casted - cache device matches needed device - comfy.utils.copy_to_param(self.model, k, self.hooked_backup[k][0]) - else: # should be casted as may not match needed device - comfy.utils.copy_to_param(self.model, k, self.hooked_backup[k][0].to(device=self.hooked_backup[k][1])) - else: - for k in keys: - if self.lora_hook_mode == LoraHookMode.MAX_SPEED: - 
comfy.utils.set_attr_param(self.model, k, self.hooked_backup[k][0]) - else: # should be casted as may not match needed device - comfy.utils.set_attr_param(self.model, k, self.hooked_backup[k][0].to(device=self.hooked_backup[k][1])) - # clear hooked_backup - self.hooked_backup.clear() - self.current_lora_hooks = None - # reinject model, if necessary - if was_injected: - self.inject_model() - - -class CLIPWithHooks(CLIP): - def __init__(self, clip: Union[CLIP, 'CLIPWithHooks']): - super().__init__(no_init=True) - self.patcher = ModelPatcherCLIPHooks.create_from(clip.patcher) - self.cond_stage_model = clip.cond_stage_model - self.tokenizer = clip.tokenizer - self.layer_idx = clip.layer_idx - self.desired_hooks: LoraHookGroup = None - if hasattr(clip, "desired_hooks"): - self.set_desired_hooks(clip.desired_hooks) + def get_params(self) -> 'InjectionParams': + return self.model.attachments.get(self.PARAMS) - def clone(self): - cloned = CLIPWithHooks(clip=self) - return cloned + def set_params(self, params: 'InjectionParams'): + self.model.set_attachments(self.PARAMS, params) - def set_desired_hooks(self, lora_hooks: LoraHookGroup): - self.desired_hooks = lora_hooks - self.patcher.set_desired_hooks(lora_hooks=lora_hooks) + def set_outer_sample_wrapper(self, wrapper): + self.model.add_wrapper(WrappersMP.OUTER_SAMPLE, wrapper) - def add_hooked_patches(self, lora_hook: LoraHook, patches, strength_patch=1.0, strength_model=1.0): - return self.patcher.add_hooked_patches(lora_hook=lora_hook, patches=patches, strength_patch=strength_patch, strength_model=strength_model) - def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches, strength_patch=1.0, strength_model=1.0): - return self.patcher.add_hooked_patches_as_diffs(lora_hook=lora_hook, patches=patches, strength_patch=strength_patch, strength_model=strength_model) - - -class ModelPatcherCLIPHooks(ModelPatcher): - def __init__(self, m: ModelPatcher): - # replicate ModelPatcher.clone() to initialize - super().__init__(m.model, m.load_device, m.offload_device, m.size, weight_inplace_update=m.weight_inplace_update) - self.patches = {} - for k in m.patches: - self.patches[k] = m.patches[k][:] - if hasattr(m, "patches_uuid"): - self.patches_uuid = m.patches_uuid - - self.object_patches = m.object_patches.copy() - self.model_options = copy.deepcopy(m.model_options) - if hasattr(m, "model_keys"): - self.model_keys = m.model_keys - if hasattr(m, "backup"): - self.backup = m.backup - if hasattr(m, "object_patches_backup"): - self.object_patches_backup = m.object_patches_backup - # lora hook stuff - self.hooked_patches: dict[HookRef] = {} # binds LoraHook to specific keys - self.patches_backup = {} - self.hooked_backup: dict[str, tuple[Tensor, torch.device]] = {} - - self.current_lora_hooks = None - self.desired_lora_hooks = None - self.lora_hook_mode = LoraHookMode.MAX_SPEED - - self.model_params_lowvram = False - self.model_params_lowvram_keys = {} # keeps track of keys with applied 'weight_function' or 'bias_function' + def pre_run(self): + # TODO: could implement this as a ModelPatcher ON_PRE_RUN callback + for motion_model in self.get_motion_models(): + motion_model.pre_run(self.model) + self.get_sample_settings().pre_run(self.model) - def clone(self): - cloned = ModelPatcherCLIPHooks(self) - # copy lora hooks - for hook in self.hooked_patches: - cloned.hooked_patches[hook] = {} - for k in self.hooked_patches[hook]: - cloned.hooked_patches[hook][k] = self.hooked_patches[hook][k][:] - cloned.patches_backup = self.patches_backup - 
cloned.hooked_backup = self.hooked_backup - cloned.current_lora_hooks = self.current_lora_hooks - cloned.desired_lora_hooks = self.desired_lora_hooks - cloned.lora_hook_mode = self.lora_hook_mode - return cloned - @classmethod - def create_from(cls, model: Union[ModelPatcher, 'ModelPatcherCLIPHooks']): - if isinstance(model, ModelPatcherCLIPHooks): - return model.clone() - return ModelPatcherCLIPHooks(model) - - def clone_has_same_weights(self, clone: 'ModelPatcherCLIPHooks'): - returned = super().clone_has_same_weights(clone) - if not returned: - return returned - if type(self) != type(clone): - return False - if self.desired_lora_hooks != clone.desired_lora_hooks: - return False - if self.current_lora_hooks != clone.current_lora_hooks: - return False - if self.hooked_patches.keys() != clone.hooked_patches.keys(): - return False - return returned - - def set_desired_hooks(self, lora_hooks: LoraHookGroup): - self.desired_lora_hooks = lora_hooks - - def add_hooked_patches(self, lora_hook: LoraHook, patches, strength_patch=1.0, strength_model=1.0): - ''' - Based on add_patches, but for hooked weights. - ''' - current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {}) - p = set() - model_sd = self.model.state_dict() - for k in patches: - offset = None - function = None - if isinstance(k, str): - key = k - else: - offset = k[1] - key = k[0] - if len(k) > 2: - function = k[2] - - if key in model_sd: - p.add(k) - current_patches: list[tuple] = current_hooked_patches.get(key, []) - current_patches.append((strength_patch, patches[k], strength_model, offset, function)) - current_hooked_patches[key] = current_patches - self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches - # since should care about these patches too to determine if same model, reroll patches_uuid - self.patches_uuid = uuid.uuid4() - return list(p) - - def add_hooked_patches_as_diffs(self, lora_hook: LoraHook, patches: dict, strength_patch=1.0, strength_model=1.0): - ''' - Based on add_hooked_patches, but intended for using a model's weights as lora hook. - ''' - current_hooked_patches: dict[str,list] = self.hooked_patches.get(lora_hook.hook_ref, {}) - p = set() - model_sd = self.model.state_dict() - for k in patches: - offset = None - function = None - if isinstance(k, str): - key = k - else: - offset = k[1] - key = k[0] - if len(k) > 2: - function = k[2] - - if key in model_sd: - p.add(k) - current_patches: list[tuple] = current_hooked_patches.get(key, []) - # take difference between desired weight and existing weight to get diff - # TODO: create fix for fp8 - current_patches.append((strength_patch, (patches[k]-comfy.utils.get_attr(self.model, key),), strength_model, offset, function)) - current_hooked_patches[key] = current_patches - self.hooked_patches[lora_hook.hook_ref] = current_hooked_patches - # since should care about these patches too to determine if same model, reroll patches_uuid - self.patches_uuid = uuid.uuid4() - return list(p) - - def get_combined_hooked_patches(self, lora_hooks: LoraHookGroup): - ''' - Returns patches for selected lora_hooks. 
- ''' - # combined_patches will contain weights of all relevant lora_hooks, per key - combined_patches = {} - if lora_hooks is not None: - for hook in lora_hooks.hooks: - hook_patches: dict = self.hooked_patches.get(hook.hook_ref, {}) - for key in hook_patches.keys(): - current_patches: list[tuple] = combined_patches.get(key, []) - current_patches.extend(hook_patches[key]) - combined_patches[key] = current_patches - return combined_patches - - def patch_hooked_replace_weight_to_device(self, model_sd: dict, replace_patches: dict): - # first handle replace_patches - for key in replace_patches: - if key not in model_sd: - logger.warning(f"CLIP LoraHook could not replace patch. key doesn't exist in model: {key}") - continue - weight: Tensor = comfy.utils.get_attr(self.model, key) - inplace_update = self.weight_inplace_update - target_device = weight.device - - if key not in self.hooked_backup: - self.hooked_backup[key] = (weight.to(device=target_device, copy=inplace_update), weight.device) - out_weight = replace_patches[key].to(target_device) - if inplace_update: - comfy.utils.copy_to_param(self.model, key, out_weight) - else: - comfy.utils.set_attr_param(self.model, key, out_weight) - - def patch_model(self, device_to=None, *args, **kwargs): - if self.desired_lora_hooks is not None: - self.patches_backup = self.patches.copy() - relevant_patches = self.get_combined_hooked_patches(lora_hooks=self.desired_lora_hooks) - for key in relevant_patches: - self.patches.setdefault(key, []) - self.patches[key].extend(relevant_patches[key]) - self.current_lora_hooks = self.desired_lora_hooks - return super().patch_model(device_to, *args, **kwargs) +def inject_motion_models(patcher: ModelPatcher): + helper = ModelPatcherHelper(patcher) + motion_models = helper.get_motion_models() + for mm in motion_models: + mm.model.inject(patcher) - def load(self, device_to=None, lowvram_model_memory=0, *args, **kwargs): - try: - return super().load(device_to=device_to, lowvram_model_memory=lowvram_model_memory, *args, **kwargs) - finally: - if lowvram_model_memory > 0: - self._patch_lowvram_extras() - - def _patch_lowvram_extras(self): - # check if any modules have weight_function or bias_function that is not None - # NOTE: this serves no purpose currently, but I have it here for future reasons - self.model_params_lowvram = False - self.model_params_lowvram_keys.clear() - for n, m in self.model.named_modules(): - if not hasattr(m, "comfy_cast_weights"): - continue - if getattr(m, "weight_function", None) is not None: - self.model_params_lowvram = True - self.model_params_lowvram_keys[f"{n}.weight"] = n - if getattr(m, "bias_function", None) is not None: - self.model_params_lowvram = True - self.model_params_lowvram_keys[f"{n}.weight"] = n - - def unpatch_model(self, device_to=None, unpatch_weights=True, *args, **kwargs): - try: - return super().unpatch_model(device_to, unpatch_weights, *args, **kwargs) - finally: - self.patches = self.patches_backup.copy() - self.patches_backup.clear() - # handle replace patches - keys = list(self.hooked_backup.keys()) - if self.weight_inplace_update: - for k in keys: - comfy.utils.copy_to_param(self.model, k, self.hooked_backup[k][0].to(device=self.hooked_backup[k][1])) - else: - for k in keys: - comfy.utils.set_attr_param(self.model, k, self.hooked_backup[k][0].to(device=self.hooked_backup[k][1])) - self.model_params_lowvram = False - self.model_params_lowvram_keys.clear() - # clear hooked_backup - self.hooked_backup.clear() - self.current_lora_hooks = None - - -def 
load_hooked_lora_for_models(model: Union[ModelPatcher, ModelPatcherAndInjector], clip: CLIP, lora: dict[str, Tensor], lora_hook: LoraHook, - strength_model: float, strength_clip: float): - key_map = {} - if model is not None: - key_map = comfy.lora.model_lora_keys_unet(model.model, key_map) - if clip is not None: - key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map) - - loaded: dict[str] = comfy.lora.load_lora(lora, key_map) - if model is not None: - new_modelpatcher = ModelPatcherAndInjector.create_from(model) - k = new_modelpatcher.add_hooked_patches(lora_hook=lora_hook, patches=loaded, strength_patch=strength_model) - else: - k = () - new_modelpatcher = None - - if clip is not None: - new_clip = CLIPWithHooks(clip) - k1 = new_clip.add_hooked_patches(lora_hook=lora_hook, patches=loaded, strength_patch=strength_clip) - else: - k1 = () - new_clip = None - k = set(k) - k1 = set(k1) - for x in loaded: - if (x not in k) and (x not in k1): - logger.warning(f"NOT LOADED {x}") - return (new_modelpatcher, new_clip) - - -def load_model_as_hooked_lora_for_models(model: Union[ModelPatcher, ModelPatcherAndInjector], clip: CLIP, model_loaded: ModelPatcher, clip_loaded: CLIP, lora_hook: LoraHook, - strength_model: float, strength_clip: float): - if model is not None and model_loaded is not None: - new_modelpatcher = ModelPatcherAndInjector.create_from(model) - comfy.model_management.unload_model_clones(new_modelpatcher) - expected_model_keys = set(model_loaded.model.state_dict().keys()) - patches_model: dict[str, Tensor] = model_loaded.model.state_dict() - # do not include ANY model_sampling components of the model that should act as a patch - for key in list(patches_model.keys()): - if key.startswith("model_sampling"): - expected_model_keys.discard(key) - patches_model.pop(key, None) - k = new_modelpatcher.add_hooked_patches_as_diffs(lora_hook=lora_hook, patches=patches_model, strength_patch=strength_model) - else: - k = () - new_modelpatcher = None - - if clip is not None and clip_loaded is not None: - new_clip = CLIPWithHooks(clip) - comfy.model_management.unload_model_clones(new_clip.patcher) - expected_clip_keys = clip_loaded.patcher.model.state_dict().copy() - patches_clip: dict[str, Tensor] = clip_loaded.cond_stage_model.state_dict() - k1 = new_clip.add_hooked_patches_as_diffs(lora_hook=lora_hook, patches=patches_clip, strength_patch=strength_clip) - else: - k1 = () - new_clip = None - - k = set(k) - k1 = set(k1) - if model is not None and model_loaded is not None: - for key in expected_model_keys: - if key not in k: - logger.warning(f"MODEL-AS-LORA NOT LOADED {key}") - if clip is not None and clip_loaded is not None: - for key in expected_clip_keys: - if key not in k1: - logger.warning(f"CLIP-AS-LORA NOT LOADED {key}") - - return (new_modelpatcher, new_clip) + +def eject_motion_models(patcher: ModelPatcher): + helper = ModelPatcherHelper(patcher) + motion_models = helper.get_motion_models() + for mm in motion_models: + mm.model.eject(patcher) class MotionModelPatcher(ModelPatcher): @@ -854,7 +211,7 @@ def _handle_float8_pe_tensors(self): break comfy.utils.set_attr(self.model, key, comfy.utils.get_attr(self.model, key).half()) - def pre_run(self, model: ModelPatcherAndInjector): + def pre_run(self, model: ModelPatcher): self.cleanup() self.model.set_scale(self.scale_multival, self.per_block_list) self.model.set_effect(self.effect_multival, self.per_block_list) @@ -1174,7 +531,11 @@ class MotionModelGroup: def __init__(self, init_motion_model: MotionModelPatcher=None): 
self.models: list[MotionModelPatcher] = [] if init_motion_model is not None: - self.add(init_motion_model) + if isinstance(init_motion_model, list): + for m in init_motion_model: + self.add(m) + else: + self.add(init_motion_model) def add(self, mm: MotionModelPatcher): # add to end of list @@ -1211,7 +572,7 @@ def initialize_timesteps(self, model: BaseModel): for motion_model in self.models: motion_model.initialize_timesteps(model) - def pre_run(self, model: ModelPatcherAndInjector): + def pre_run(self, model: ModelPatcher): for motion_model in self.models: motion_model.pre_run(model) @@ -1653,3 +1014,6 @@ def clone(self) -> 'InjectionParams': new_params.set_context(self.context_options) new_params.set_motion_model_settings(self.motion_model_settings) # Gen1 return new_params + + def on_model_patcher_clone(self): + return self.clone() diff --git a/animatediff/nodes.py b/animatediff/nodes.py index be319d6..40fa7a0 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -51,8 +51,8 @@ from .logger import logger # override comfy_sample.sample with animatediff-support version -comfy_sample.sample = motion_sample_factory(comfy_sample.sample) -comfy_sample.sample_custom = motion_sample_factory(comfy_sample.sample_custom, is_custom=True) +#comfy_sample.sample = motion_sample_factory(comfy_sample.sample) +#comfy_sample.sample_custom = motion_sample_factory(comfy_sample.sample_custom, is_custom=True) NODE_CLASS_MAPPINGS = { diff --git a/animatediff/nodes_conditioning.py b/animatediff/nodes_conditioning.py index 9e3b263..6dfda75 100644 --- a/animatediff/nodes_conditioning.py +++ b/animatediff/nodes_conditioning.py @@ -11,7 +11,6 @@ from .conditioning import (COND_CONST, TimestepsCond, set_mask_conds, set_mask_and_combine_conds, set_unmasked_and_combine_conds, LoraHook, LoraHookGroup, LoraHookKeyframe, LoraHookKeyframeGroup) -from .model_injection import ModelPatcherAndInjector, CLIPWithHooks, load_hooked_lora_for_models, load_model_as_hooked_lora_for_models from .utils_model import BIGMAX, InterpolationMethod from .logger import logger @@ -443,7 +442,7 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/register lora hooks" FUNCTION = "load_lora" - def load_lora(self, model: Union[ModelPatcher, ModelPatcherAndInjector], clip: CLIP, lora_name: str, strength_model: float, strength_clip: float): + def load_lora(self, model: Union[ModelPatcher], clip: CLIP, lora_name: str, strength_model: float, strength_clip: float): if strength_model == 0 and strength_clip == 0: return (model, clip) diff --git a/animatediff/nodes_context.py b/animatediff/nodes_context.py index fd42683..c14a8b8 100644 --- a/animatediff/nodes_context.py +++ b/animatediff/nodes_context.py @@ -6,7 +6,6 @@ from .context import (ContextFuseMethod, ContextOptions, ContextOptionsGroup, ContextSchedules, generate_context_visualization) -from .model_injection import ModelPatcherAndInjector from .utils_model import BIGMAX, MAX_RESOLUTION @@ -379,7 +378,7 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/visualize" FUNCTION = "visualize" - def visualize(self, model: ModelPatcherAndInjector, context_opts: ContextOptionsGroup, sampler_name: str, scheduler: str, + def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sampler_name: str, scheduler: str, visual_width: 1280, latents_length=32, steps=20, start_step=0, end_step=20): images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length, 
sampler_name=sampler_name, scheduler=scheduler, @@ -409,7 +408,7 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/visualize" FUNCTION = "visualize" - def visualize(self, model: ModelPatcherAndInjector, context_opts: ContextOptionsGroup, sampler_name: str, scheduler: str, + def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sampler_name: str, scheduler: str, visual_width: 1280, latents_length=32, steps=20, denoise=1.0): images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length, sampler_name=sampler_name, scheduler=scheduler, @@ -436,7 +435,7 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/visualize" FUNCTION = "visualize" - def visualize(self, model: ModelPatcherAndInjector, context_opts: ContextOptionsGroup, sigmas, + def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sigmas, visual_width: 1280, latents_length=32): images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length, sigmas=sigmas) diff --git a/animatediff/nodes_deprecated.py b/animatediff/nodes_deprecated.py index 7b6759c..baa1ddf 100644 --- a/animatediff/nodes_deprecated.py +++ b/animatediff/nodes_deprecated.py @@ -16,7 +16,7 @@ from .context import ContextOptionsGroup, ContextOptions, ContextSchedules from .logger import logger from .utils_model import Folders, BetaSchedules, get_available_motion_models -from .model_injection import ModelPatcherAndInjector, InjectionParams, MotionModelGroup, load_motion_module_gen1 +from .model_injection import InjectionParams, MotionModelGroup, load_motion_module_gen1 class AnimateDiffLoader_Deprecated: diff --git a/animatediff/nodes_gen1.py b/animatediff/nodes_gen1.py index 574c735..75dbced 100644 --- a/animatediff/nodes_gen1.py +++ b/animatediff/nodes_gen1.py @@ -11,7 +11,7 @@ from .utils_motion import ADKeyframeGroup, get_combined_multival from .motion_lora import MotionLoraInfo, MotionLoraList from .motion_module_ad import AllPerBlocks -from .model_injection import (InjectionParams, ModelPatcherAndInjector, MotionModelGroup, +from .model_injection import (InjectionParams, MotionModelGroup, load_motion_lora_as_patches, load_motion_module_gen1, load_motion_module_gen2, validate_model_compatibility_gen2, validate_per_block_compatibility) from .sample_settings import SampleSettings, SeedNoiseGeneration diff --git a/animatediff/nodes_gen2.py b/animatediff/nodes_gen2.py index 1098c77..8970c54 100644 --- a/animatediff/nodes_gen2.py +++ b/animatediff/nodes_gen2.py @@ -10,12 +10,84 @@ from .utils_motion import ADKeyframeGroup, ADKeyframe, InputPIA from .motion_lora import MotionLoraList from .motion_module_ad import AllPerBlocks -from .model_injection import (InjectionParams, ModelPatcherAndInjector, MotionModelGroup, MotionModelPatcher, create_fresh_motion_module, +from .model_injection import (ModelPatcherHelper, + InjectionParams, MotionModelGroup, MotionModelPatcher, create_fresh_motion_module, load_motion_module_gen2, load_motion_lora_as_patches, validate_model_compatibility_gen2, validate_per_block_compatibility) from .sample_settings import SampleSettings +from .sampling import outer_sample_wrapper class UseEvolvedSamplingNode: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "model": ("MODEL",), + "beta_schedule": (BetaSchedules.ALIAS_LIST, {"default": BetaSchedules.AUTOSELECT}), + }, + "optional": { + "m_models": ("M_MODELS",), + 
"context_options": ("CONTEXT_OPTIONS",), + "sample_settings": ("SAMPLE_SETTINGS",), + } + } + + RETURN_TYPES = ("MODEL",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘‘ Gen2 nodes β‘‘" + FUNCTION = "use_evolved_sampling" + + def use_evolved_sampling(self, model: ModelPatcher, beta_schedule: str, m_models: MotionModelGroup=None, context_options: ContextOptionsGroup=None, + sample_settings: SampleSettings=None): + model = model.clone() + helper = ModelPatcherHelper(model) + if m_models is not None: + m_models = m_models.clone() + # for each motion model, confirm that it is compatible with SD model + for motion_model in m_models.models: + validate_model_compatibility_gen2(model=model, motion_model=motion_model) + # create injection params + model_name_list = [motion_model.model.mm_info.mm_name for motion_model in m_models.models] + model_names = ",".join(model_name_list) + # TODO: check if any apply_v2_properly is set to False + params = InjectionParams(unlimited_area_hack=False, model_name=model_names) + helper.set_motion_models(m_models.models.copy()) + else: + params = InjectionParams() + helper.remove_motion_models() + # apply context options + if context_options: + params.set_context(context_options) + + sample_settings = sample_settings if sample_settings is not None else SampleSettings() + # attach sample settings and params to model + helper.set_sample_settings(sample_settings) + helper.set_params(params) + helper.set_outer_sample_wrapper(outer_sample_wrapper) + + if sample_settings.custom_cfg is not None: + logger.info("[Sample Settings] custom_cfg is set; will override any KSampler cfg values or patches.") + + if sample_settings.sigma_schedule is not None: + logger.info("[Sample Settings] sigma_schedule is set; will override beta_schedule.") + model.add_object_patch("model_sampling", sample_settings.sigma_schedule.clone().model_sampling) + else: + # save model_sampling from BetaSchedule as object patch + # if autoselect, get suggested beta_schedule from motion model + if beta_schedule == BetaSchedules.AUTOSELECT: + if helper.get_motion_models(): + beta_schedule = helper.get_motion_models()[0].model.get_best_beta_schedule(log=True) + else: + beta_schedule = BetaSchedules.USE_EXISTING + + new_model_sampling = BetaSchedules.to_model_sampling(beta_schedule, model) + if new_model_sampling is not None: + model.add_object_patch("model_sampling", new_model_sampling) + + del m_models + return (model,) + + +class UseEvolvedSamplingNodeOld: @classmethod def INPUT_TYPES(s): return { diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 88295c4..ff845e5 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -16,6 +16,7 @@ import comfy.utils from comfy.controlnet import ControlBase from comfy.model_base import BaseModel +from comfy.model_patcher import ModelPatcher import comfy.conds import comfy.ops @@ -25,7 +26,7 @@ from .sample_settings import IterationOptions, SampleSettings, SeedNoiseGeneration, NoisedImageToInject from .utils_model import ModelTypeSD, MachineState, vae_encode_raw_batched, vae_decode_raw_batched from .utils_motion import composite_extend, get_combined_multival, prepare_mask_batch, extend_to_batch_size -from .model_injection import InjectionParams, ModelPatcherAndInjector, MotionModelGroup, MotionModelPatcher +from .model_injection import InjectionParams, ModelPatcherHelper, MotionModelGroup, MotionModelPatcher from .motion_module_ad import AnimateDiffFormat, AnimateDiffInfo, AnimateDiffVersion, VanillaTemporalModule from .logger import logger @@ 
-35,7 +36,7 @@ # Global variable to use to more conveniently hack variable access into samplers class AnimateDiffHelper_GlobalState: def __init__(self): - self.model_patcher: ModelPatcherAndInjector = None + self.model_patcher: ModelPatcher = None self.motion_models: MotionModelGroup = None self.params: InjectionParams = None self.sample_settings: SampleSettings = None @@ -54,15 +55,6 @@ def initialize(self, model: BaseModel): if self.sample_settings.custom_cfg is not None: self.sample_settings.custom_cfg.initialize_timesteps(model) - def hooks_initialize(self, model: BaseModel, hook_groups: list[LoraHookGroup]): - # this function is to be run the first time all gathered - if not self.hooks_initialized: - self.hooks_initialized = True - for hook_group in hook_groups: - for hook in hook_group.hooks: - hook.reset() - hook.initialize_timesteps(model) - def prepare_current_keyframes(self, x: Tensor, timestep: Tensor): if self.motion_models is not None: self.motion_models.prepare_current_keyframe(x=x, t=timestep) @@ -247,7 +239,7 @@ def diffusion_model_forward_groupnormed(*args, **kwargs): ################################################################################## -def apply_params_to_motion_models(motion_models: MotionModelGroup, params: InjectionParams): +def apply_params_to_motion_models_old(motion_models: MotionModelGroup, params: InjectionParams): params = params.clone() for context in params.context_options.contexts: if context.context_schedule == ContextSchedules.VIEW_AS_CONTEXT: @@ -287,66 +279,106 @@ def apply_params_to_motion_models(motion_models: MotionModelGroup, params: Injec return params +def apply_params_to_motion_models(helper: ModelPatcherHelper, params: InjectionParams): + params = params.clone() + for context in params.context_options.contexts: + if context.context_schedule == ContextSchedules.VIEW_AS_CONTEXT: + context.context_length = params.full_length + # TODO: check (and message) should be different based on use_on_equal_length setting + if params.context_options.context_length: + pass + + allow_equal = params.context_options.use_on_equal_length + if params.context_options.context_length: + enough_latents = params.full_length >= params.context_options.context_length if allow_equal else params.full_length > params.context_options.context_length + else: + enough_latents = False + if params.context_options.context_length and enough_latents: + logger.info(f"Sliding context window sampling activated - latents passed in ({params.full_length}) greater than context_length {params.context_options.context_length}.") + else: + logger.info(f"Regular sampling activated - latents passed in ({params.full_length}) less or equal to context_length {params.context_options.context_length}.") + params.reset_context() + if helper.get_motion_models(): + # if no context_length, treat video length as intended AD frame window + if not params.context_options.context_length: + for motion_model in helper.get_motion_models(): + if not motion_model.model.is_length_valid_for_encoding_max_len(params.full_length): + raise ValueError(f"Without a context window, AnimateDiff model {motion_model.model.mm_info.mm_name} has upper limit of {motion_model.model.encoding_max_len} frames, but received {params.full_length} latents.") + helper.set_video_length(params.full_length, params.full_length) + # otherwise, treat context_length as intended AD frame window + else: + for motion_model in helper.get_motion_models(): + view_options = params.context_options.view_options + context_length = 
view_options.context_length if view_options else params.context_options.context_length + if not motion_model.model.is_length_valid_for_encoding_max_len(context_length): + raise ValueError(f"AnimateDiff model {motion_model.model.mm_info.mm_name} has upper limit of {motion_model.model.encoding_max_len} frames for a context window, but received context length of {params.context_options.context_length}.") + helper.set_video_length(params.context_options.context_length, params.full_length) + # inject model + module_str = "modules" if len(helper.get_motion_models()) > 1 else "module" + logger.info(f"Using motion {module_str} {helper.get_name_string(show_version=True)}.") + return params + + class FunctionInjectionHolder: def __init__(self): self.temp_uninjector: GroupnormUninjectHelper = GroupnormUninjectHelper() self.groupnorm_injector: GroupnormInjectHelper = GroupnormInjectHelper() - def inject_functions(self, model: ModelPatcherAndInjector, params: InjectionParams): + def inject_functions(self, helper: ModelPatcherHelper, params: InjectionParams): # Save Original Functions - order must match between here and restore_functions self.orig_forward_timestep_embed = openaimodel.forward_timestep_embed # needed to account for VanillaTemporalModule - self.orig_memory_required = model.model.memory_required # allows for "unlimited area hack" to prevent halving of conds/unconds + self.orig_memory_required = helper.model.model.memory_required # allows for "unlimited area hack" to prevent halving of conds/unconds self.orig_groupnorm_forward = torch.nn.GroupNorm.forward # used to normalize latents to remove "flickering" of colors/brightness between frames self.orig_groupnorm_forward_comfy_cast_weights = comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights - self.orig_diffusion_model_forward = model.model.diffusion_model.forward + self.orig_diffusion_model_forward = helper.model.model.diffusion_model.forward self.orig_sampling_function = comfy.samplers.sampling_function # used to support sliding context windows in samplers - self.orig_get_area_and_mult = comfy.samplers.get_area_and_mult - self.orig_get_additional_models = comfy.sampler_helpers.get_additional_models - self.orig_apply_model = model.model.apply_model + #self.orig_get_area_and_mult = comfy.samplers.get_area_and_mult + #self.orig_get_additional_models = comfy.sampler_helpers.get_additional_models + self.orig_apply_model = helper.model.model.apply_model # Inject Functions openaimodel.forward_timestep_embed = forward_timestep_embed_factory() if params.unlimited_area_hack: - model.model.memory_required = unlimited_memory_required - if model.motion_models is not None: + helper.model.model.memory_required = unlimited_memory_required + if helper.get_motion_models(): # only apply groupnorm hack if PIA, v2 and not properly applied, or v1 - info: AnimateDiffInfo = model.motion_models[0].model.mm_info + info: AnimateDiffInfo = helper.get_motion_models()[0].model.mm_info if ((info.mm_format == AnimateDiffFormat.PIA) or (info.mm_version == AnimateDiffVersion.V2 and not params.apply_v2_properly) or (info.mm_version == AnimateDiffVersion.V1)): self.inject_groupnorm_forward = groupnorm_mm_factory(params) self.inject_groupnorm_forward_comfy_cast_weights = groupnorm_mm_factory(params, manual_cast=True) self.groupnorm_injector = GroupnormInjectHelper(self) - model.model.diffusion_model.forward = diffusion_model_forward_groupnormed_factory(self.orig_diffusion_model_forward, self.groupnorm_injector) + helper.model.model.diffusion_model.forward = 
diffusion_model_forward_groupnormed_factory(self.orig_diffusion_model_forward, self.groupnorm_injector) # if mps device (Apple Silicon), disable batched conds to avoid black images with groupnorm hack try: - if model.load_device.type == "mps": - model.model.memory_required = unlimited_memory_required + if helper.model.load_device.type == "mps": + helper.model.model.memory_required = unlimited_memory_required except Exception: pass # if img_encoder or camera_encoder present, inject apply_model to handle correctly - for motion_model in model.motion_models: + for motion_model in helper.get_motion_models(): if (motion_model.model.img_encoder is not None) or (motion_model.model.camera_encoder is not None): - model.model.apply_model = apply_model_factory(self.orig_apply_model).__get__(model.model, type(model.model)) + helper.model.model.apply_model = apply_model_factory(self.orig_apply_model).__get__(helper.model.model, type(helper.model.model)) break del info comfy.samplers.sampling_function = evolved_sampling_function - comfy.samplers.get_area_and_mult = get_area_and_mult_ADE - comfy.sampler_helpers.get_additional_models = get_additional_models_factory(self.orig_get_additional_models, model.motion_models) + #comfy.samplers.get_area_and_mult = get_area_and_mult_ADE + #comfy.sampler_helpers.get_additional_models = get_additional_models_factory(self.orig_get_additional_models, model.motion_models) # create temp_uninjector to help facilitate uninjecting functions self.temp_uninjector = GroupnormUninjectHelper(self) - def restore_functions(self, model: ModelPatcherAndInjector): + def restore_functions(self, helper: ModelPatcherHelper): # Restoration try: - model.model.memory_required = self.orig_memory_required + helper.model.model.memory_required = self.orig_memory_required openaimodel.forward_timestep_embed = self.orig_forward_timestep_embed torch.nn.GroupNorm.forward = self.orig_groupnorm_forward comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.orig_groupnorm_forward_comfy_cast_weights - model.model.diffusion_model.forward = self.orig_diffusion_model_forward + helper.model.model.diffusion_model.forward = self.orig_diffusion_model_forward comfy.samplers.sampling_function = self.orig_sampling_function - comfy.samplers.get_area_and_mult = self.orig_get_area_and_mult - comfy.sampler_helpers.get_additional_models = self.orig_get_additional_models - model.model.apply_model = self.orig_apply_model + #comfy.samplers.get_area_and_mult = self.orig_get_area_and_mult + #comfy.sampler_helpers.get_additional_models = self.orig_get_additional_models + helper.model.model.apply_model = self.orig_apply_model except AttributeError: logger.error("Encountered AttributeError while attempting to restore functions - likely, an error occured while trying " + \ "to save original functions before injection, and a more specific error was thrown by ComfyUI.") @@ -428,10 +460,140 @@ def can_concat_cond_contextref_injection(c1, c2, *args, **kwargs): return can_concat_cond_contextref_injection +def outer_sample_wrapper(executor, guider: comfy.samplers.CFGGuider, *args, **kwargs): + # NOTE: OUTER_SAMPLE wrapper patch in ModelPatcher + latents = None + cached_latents = None + cached_noise = None + function_injections = FunctionInjectionHolder() + + try: + helper = ModelPatcherHelper(guider.model_patcher) + args = list(args) + # clone params from model + params = helper.get_params().clone() + # get amount of latents passed in, and store in params + noise: Tensor = args[0] + latents: Tensor = args[1] + 
params.full_length = latents.size(0) + # reset global state - TODO: remove global state + ADGS.reset() + + # apply custom noise, if needed + disable_noise = math.isclose(noise.max(), 0.0) + seed = args[-1] + + # apply params to motion model + # TODO: fill out + params = apply_params_to_motion_models(helper, params) + + # store and inject funtions + function_injections.inject_functions(helper, params) + + # prepare noise_extra_args for noise generation purposes + noise_extra_args = {"disable_noise": disable_noise} + params.set_noise_extra_args(noise_extra_args) + # if noise is not disabled, do noise stuff + if not disable_noise: + noise = helper.get_sample_settings().prepare_noise(seed, latents, noise, extra_args=noise_extra_args, force_create_noise=False) + + # callback setup + original_callback = args[-3] + def ad_callback(step, x0, x, total_steps): + if original_callback is not None: + original_callback(step, x0, x, total_steps) + # store denoised latents if image_injection will be used + if not helper.get_sample_settings().image_injection.is_empty(): + ADGS.callback_output_dict["x0"] = x0 + # update GLOBALSTATE for next iteration + ADGS.current_step = ADGS.start_step + step + 1 + args[-3] = ad_callback + ADGS.model_patcher = helper.model + ADGS.motion_models = MotionModelGroup(helper.get_motion_models()) + ADGS.sample_settings = helper.get_sample_settings() + ADGS.function_injections = function_injections + + # apply adapt_denoise_steps - does not work here! would need to mess with this elsewhere... + # TODO: implement proper wrapper to handle this feature... + + iter_opts = helper.get_sample_settings().iteration_opts + iter_opts.initialize(latents) + # cache initial noise and latents, if needed + if iter_opts.cache_init_latents: + cached_latents = latents.clone() + if iter_opts.cache_init_noise: + cached_noise = noise.clone() + # prepare iter opts preprocess kwargs, if needed + iter_kwargs = {} + # NOTE: original KSampler stuff is not doable here, so skipping + + for curr_i in range(iter_opts.iterations): + # handle GLOBALSTATE vars and step tally + ADGS.update_with_inject_params(params) + ADGS.start_step = kwargs.get("start_step") or 0 + ADGS.current_step = ADGS.start_step + ADGS.last_step = kwargs.get("last_step") or 0 + if iter_opts.iterations > 1: + logger.info(f"Iteration {curr_i+1}/{iter_opts.iterations}") + # perform any iter_opts preprocessing on latents + latents, noise = iter_opts.preprocess_latents(curr_i=curr_i, model=helper.model, latents=latents, noise=noise, + cached_latents=cached_latents, cached_noise=cached_noise, + seed=seed, + sample_settings=helper.get_sample_settings(), noise_extra_args=noise_extra_args, + **iter_kwargs) + if helper.get_sample_settings().noise_calibration is not None: + latents, noise = helper.get_sample_settings().noise_calibration.perform_calibration(sample_func=executor, model=helper.model, latents=latents, noise=noise, + is_custom=True, args=args, kwargs=kwargs) + # finalize latent_image in args + args[1] = latents + + helper.pre_run() + + if ADGS.sample_settings.image_injection.is_empty(): + latents = executor(guider, *tuple(args), **kwargs) + else: + ADGS.sample_settings.image_injection.initialize_timesteps(helper.model.model) + sigmas = args[3] + sigmas_list, injection_list = ADGS.sample_settings.image_injection.custom_ksampler_get_injections(helper.model, sigmas) + # useful logging + if len(injection_list) > 0: + inj_str = "s" if len(injection_list) > 1 else "" + logger.info(f"Found {len(injection_list)} applicable image 
injection{inj_str}; sampling will be split into {len(sigmas_list)}.") + else: + logger.info(f"Found 0 applicable image injections within the step bounds of this sampler; sampling unaffected.") + is_first = True + new_noise = noise + for i in range(len(sigmas_list)): + args[0] = new_noise + args[1] = latents + args[3] = sigmas_list[i] + latents = executor(guider, *tuple(args), **kwargs) + if is_first: + new_noise = torch.zeros_like(latents) + # if injection expected, perform injection + if i < len(injection_list): + to_inject = injection_list[i] + latents = perform_image_injection(helper.model.model, latents, to_inject) + return latents + finally: + del noise + del latents + del cached_latents + del cached_noise + # reset global state + ADGS.reset() + # clean motion_models + helper.cleanup_motion_models() + # restore injected functions + function_injections.restore_functions(helper) + del function_injections + del helper + + def motion_sample_factory(orig_comfy_sample: Callable, is_custom: bool=False) -> Callable: - def motion_sample(model: ModelPatcherAndInjector, noise: Tensor, *args, **kwargs): + def motion_sample(model: ModelPatcher, noise: Tensor, *args, **kwargs): # check if model is intended for injecting - if type(model) != ModelPatcherAndInjector: + if type(model) != ModelPatcher: return orig_comfy_sample(model, noise, *args, **kwargs) # otherwise, injection time latents = None @@ -452,7 +614,7 @@ def motion_sample(model: ModelPatcherAndInjector, noise: Tensor, *args, **kwargs seed = kwargs["seed"] # apply params to motion model - params = apply_params_to_motion_models(model.motion_models, params) + params = apply_params_to_motion_models_old(model.motion_models, params) # store and inject functions function_injections.inject_functions(model, params) @@ -672,7 +834,7 @@ def perform_image_injection(model: BaseModel, latents: Tensor, to_inject: Noised if hasattr(comfy.model_management, "loaded_models"): cached_loaded_models = comfy.model_management.loaded_models(only_currently_used=True) else: - cached_loaded_models: list[ModelPatcherAndInjector] = [x.model for x in comfy.model_management.current_loaded_models] + cached_loaded_models: list[ModelPatcher] = [x.model for x in comfy.model_management.current_loaded_models] try: orig_device = latents.device orig_dtype = latents.dtype From 3e3d29729fc0f1bc9df62800837b04869b53db2c Mon Sep 17 00:00:00 2001 From: Kosinkadink Date: Sat, 21 Sep 2024 21:48:24 +0900 Subject: [PATCH 09/43] Fixed FreeInit for new implementation --- animatediff/freeinit.py | 4 +++- animatediff/sample_settings.py | 8 ++++---- animatediff/sampling.py | 4 ++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/animatediff/freeinit.py b/animatediff/freeinit.py index ac9edfa..4b3bbc2 100644 --- a/animatediff/freeinit.py +++ b/animatediff/freeinit.py @@ -24,7 +24,7 @@ class FreeInitFilter: LIST = [GAUSSIAN, BUTTERWORTH, IDEAL, BOX] -def freq_mix_3d(x, noise, LPF): +def freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor): """ Noise reinitialization. 
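# --- Editor's illustrative sketch (not part of the patch) ---------------------
# The freeinit.py hunks around this point only add dtype/device coercion; the
# underlying idea of freq_mix_3d is FreeInit's frequency mixing: keep the
# low-frequency band of the re-noised latents `x` and take the high-frequency
# band from freshly sampled `noise`, using a low-pass mask `LPF` in the shifted
# FFT domain. The helper below is a self-contained approximation that mirrors
# the dims used in these hunks (FFT over frames, height, width); all names here
# are hypothetical stand-ins, not the repository's actual implementation.
import torch
import torch.fft as fft

def sketch_freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, lpf: torch.Tensor) -> torch.Tensor:
    # match dtype/device, as the lines added in this commit do
    noise = noise.to(dtype=x.dtype, device=x.device)
    lpf = lpf.to(dtype=x.dtype, device=x.device)
    # expects video latents shaped (frames, channels, height, width)
    x_freq = fft.fftshift(fft.fftn(x, dim=(-4, -2, -1)), dim=(-4, -2, -1))
    noise_freq = fft.fftshift(fft.fftn(noise, dim=(-4, -2, -1)), dim=(-4, -2, -1))
    # low frequencies come from x, high frequencies from the fresh noise
    mixed_freq = x_freq * lpf + noise_freq * (1 - lpf)
    # back to the spatial domain; keep the real part
    return fft.ifftn(fft.ifftshift(mixed_freq, dim=(-4, -2, -1)), dim=(-4, -2, -1)).real

if __name__ == "__main__":
    frames, channels, height, width = 8, 4, 32, 32
    latents = torch.randn(frames, channels, height, width)
    fresh_noise = torch.randn(frames, channels, height, width)
    # toy low-pass mask: ones in a small centered box of the shifted spectrum
    mask = torch.zeros(frames, channels, height, width)
    mask[2:6, :, 12:20, 12:20] = 1.0
    print(sketch_freq_mix_3d(latents, fresh_noise, mask).shape)  # torch.Size([8, 4, 32, 32])
# ------------------------------------------------------------------------------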
@@ -33,6 +33,8 @@ def freq_mix_3d(x, noise, LPF): noise: randomly sampled noise LPF: low pass filter """ + noise = noise.to(dtype=x.dtype, device=x.device) + LPF = LPF.to(dtype=x.dtype, device=x.device) # FFT x_freq = fft.fftn(x, dim=(-4, -2, -1)) x_freq = fft.fftshift(x_freq, dim=(-4, -2, -1)) diff --git a/animatediff/sample_settings.py b/animatediff/sample_settings.py index 898414c..b7eeb38 100644 --- a/animatediff/sample_settings.py +++ b/animatediff/sample_settings.py @@ -476,7 +476,7 @@ def preprocess_latents(self, curr_i: int, model: ModelPatcher, latents: Tensor, alpha_cumprod = 1 / ((sigma * sigma) + 1) sqrt_alpha_prod = alpha_cumprod ** 0.5 sqrt_one_minus_alpha_prod = (1 - alpha_cumprod) ** 0.5 - noised_latents = latents * sqrt_alpha_prod + noise * sqrt_one_minus_alpha_prod + noised_latents = latents * sqrt_alpha_prod + noise.to(dtype=latents.dtype, device=latents.device) * sqrt_one_minus_alpha_prod # 2. create random noise z_rand for high frequency temp_sample_settings = sample_settings.clone() temp_sample_settings.batch_offset += self.iter_batch_offset * curr_i @@ -484,7 +484,7 @@ def preprocess_latents(self, curr_i: int, model: ModelPatcher, latents: Tensor, z_rand = temp_sample_settings.prepare_noise(seed=seed, latents=latents, noise=None, extra_args=noise_extra_args, force_create_noise=True) # 3. noise reinitialization - combines low freq. noise from noised_latents and high freq. noise from z_rand - noised_latents = freeinit.freq_mix_3d(x=noised_latents, noise=z_rand.to(dtype=latents.dtype, device=latents.device), LPF=self.freq_filter) + noised_latents = freeinit.freq_mix_3d(x=noised_latents, noise=z_rand, LPF=self.freq_filter) return cached_latents, noised_latents elif self.init_type == self.DINKINIT_V1: # NOTE: This was my first attempt at implementing FreeInit; it sorta works due to my alpha_cumprod shenanigans, @@ -492,7 +492,7 @@ def preprocess_latents(self, curr_i: int, model: ModelPatcher, latents: Tensor, # 1. apply initial noise with appropriate step sigma sigma = self.get_sigma(model, self.step-1000).to(latents.device) alpha_cumprod = 1 / ((sigma * sigma) + 1) #1 / ((sigma * sigma)) # 1 / ((sigma * sigma) + 1) - noised_latents = (latents + (cached_noise * sigma)) * alpha_cumprod + noised_latents = (latents + (cached_noise.to(dtype=latents.dtype, device=latents.device) * sigma)) * alpha_cumprod # 2. create random noise z_rand for high frequency temp_sample_settings = sample_settings.clone() temp_sample_settings.batch_offset += self.iter_batch_offset * curr_i @@ -501,7 +501,7 @@ def preprocess_latents(self, curr_i: int, model: ModelPatcher, latents: Tensor, extra_args=noise_extra_args, force_create_noise=True) ####z_rand = torch.randn_like(latents, dtype=latents.dtype, device=latents.device) # 3. noise reinitialization - combines low freq. noise from noised_latents and high freq. 
noise from z_rand - noised_latents = freeinit.freq_mix_3d(x=noised_latents, noise=z_rand.to(dtype=latents.dtype, device=latents.device), LPF=self.freq_filter) + noised_latents = freeinit.freq_mix_3d(x=noised_latents, noise=z_rand, LPF=self.freq_filter) return cached_latents, noised_latents else: raise ValueError(f"FreeInit init_type '{self.init_type}' is not recognized.") diff --git a/animatediff/sampling.py b/animatediff/sampling.py index ff845e5..1a5dae8 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -484,7 +484,6 @@ def outer_sample_wrapper(executor, guider: comfy.samplers.CFGGuider, *args, **kw seed = args[-1] # apply params to motion model - # TODO: fill out params = apply_params_to_motion_models(helper, params) # store and inject funtions @@ -525,7 +524,7 @@ def ad_callback(step, x0, x, total_steps): cached_noise = noise.clone() # prepare iter opts preprocess kwargs, if needed iter_kwargs = {} - # NOTE: original KSampler stuff is not doable here, so skipping + # NOTE: original KSampler stuff is not doable here, so skipping... for curr_i in range(iter_opts.iterations): # handle GLOBALSTATE vars and step tally @@ -545,6 +544,7 @@ def ad_callback(step, x0, x, total_steps): latents, noise = helper.get_sample_settings().noise_calibration.perform_calibration(sample_func=executor, model=helper.model, latents=latents, noise=noise, is_custom=True, args=args, kwargs=kwargs) # finalize latent_image in args + args[0] = noise args[1] = latents helper.pre_run() From a1b03d685144b6fb476218cad5cd19e7ed28baa2 Mon Sep 17 00:00:00 2001 From: Kosinkadink Date: Sun, 22 Sep 2024 16:38:57 +0900 Subject: [PATCH 10/43] Ported Gen1 nodes to use upcoming ComfyUI features, refactored some code to match ComfyUI changes --- animatediff/model_injection.py | 3 ++- animatediff/nodes_gen1.py | 42 +++++++++++++++++++--------------- animatediff/sampling.py | 6 +++-- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index 1857b78..70af9c2 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -36,6 +36,7 @@ class ModelPatcherHelper: MOTION_MODELS = "ADE_motion_models" SAMPLE_SETTINGS = "ADE_sample_settings" PARAMS = "ADE_params" + ADE = "ADE" def __init__(self, model: ModelPatcher): self.model = model @@ -90,7 +91,7 @@ def set_params(self, params: 'InjectionParams'): self.model.set_attachments(self.PARAMS, params) def set_outer_sample_wrapper(self, wrapper): - self.model.add_wrapper(WrappersMP.OUTER_SAMPLE, wrapper) + self.model.add_wrapper_with_key(WrappersMP.OUTER_SAMPLE, self.ADE, wrapper) def pre_run(self): diff --git a/animatediff/nodes_gen1.py b/animatediff/nodes_gen1.py index 75dbced..f9916eb 100644 --- a/animatediff/nodes_gen1.py +++ b/animatediff/nodes_gen1.py @@ -11,11 +11,11 @@ from .utils_motion import ADKeyframeGroup, get_combined_multival from .motion_lora import MotionLoraInfo, MotionLoraList from .motion_module_ad import AllPerBlocks -from .model_injection import (InjectionParams, MotionModelGroup, +from .model_injection import (ModelPatcherHelper, InjectionParams, MotionModelGroup, load_motion_lora_as_patches, load_motion_module_gen1, load_motion_module_gen2, validate_model_compatibility_gen2, validate_per_block_compatibility) from .sample_settings import SampleSettings, SeedNoiseGeneration -from .sampling import motion_sample_factory +from .sampling import outer_sample_wrapper class AnimateDiffLoaderGen1: @@ -82,24 +82,26 @@ def load_mm_and_inject_params(self, if 
params.motion_model_settings.mask_attn_scale is not None: motion_model.scale_multival = get_combined_multival(scale_multival, (params.motion_model_settings.mask_attn_scale * params.motion_model_settings.attn_scale)) + sample_settings = sample_settings if sample_settings is not None else SampleSettings() # need to use a ModelPatcher that supports injection of motion modules into unet - # need to use a ModelPatcher that supports injection of motion modules into unet - model = ModelPatcherAndInjector.create_from(model, hooks_only=True) - model.motion_models = MotionModelGroup(motion_model) - model.sample_settings = sample_settings if sample_settings is not None else SampleSettings() - model.motion_injection_params = params + model = model.clone() + helper = ModelPatcherHelper(model) + helper.set_motion_models([motion_model]) + helper.set_sample_settings(sample_settings) + helper.set_params(params) + helper.set_outer_sample_wrapper(outer_sample_wrapper) - if model.sample_settings.custom_cfg is not None: + if sample_settings.custom_cfg is not None: logger.info("[Sample Settings] custom_cfg is set; will override any KSampler cfg values or patches.") - if model.sample_settings.sigma_schedule is not None: + if sample_settings.sigma_schedule is not None: logger.info("[Sample Settings] sigma_schedule is set; will override beta_schedule.") - model.add_object_patch("model_sampling", model.sample_settings.sigma_schedule.clone().model_sampling) + model.add_object_patch("model_sampling", sample_settings.sigma_schedule.clone().model_sampling) else: # save model sampling from BetaSchedule as object patch # if autoselect, get suggested beta_schedule from motion model - if beta_schedule == BetaSchedules.AUTOSELECT and not model.motion_models.is_empty(): - beta_schedule = model.motion_models[0].model.get_best_beta_schedule(log=True) + if beta_schedule == BetaSchedules.AUTOSELECT and helper.get_motion_models(): + beta_schedule = helper.get_motion_models()[0].model.get_best_beta_schedule(log=True) new_model_sampling = BetaSchedules.to_model_sampling(beta_schedule, model) if new_model_sampling is not None: model.add_object_patch("model_sampling", new_model_sampling) @@ -165,15 +167,19 @@ def load_mm_and_inject_params(self, motion_model.keyframes = ad_keyframes.clone() if ad_keyframes else ADKeyframeGroup() - model = ModelPatcherAndInjector.create_from(model, hooks_only=True) - model.motion_models = MotionModelGroup(motion_model) - model.sample_settings = sample_settings if sample_settings is not None else SampleSettings() - model.motion_injection_params = params + sample_settings = sample_settings if sample_settings is not None else SampleSettings() + # need to use a ModelPatcher that supports injection of motion modules into unet + model = model.clone() + helper = ModelPatcherHelper() + helper.set_motion_models([motion_model]) + helper.set_sample_settings(sample_settings) + helper.set_params(params) + helper.set_outer_sample_wrapper(outer_sample_wrapper) # save model sampling from BetaSchedule as object patch # if autoselect, get suggested beta_schedule from motion model - if beta_schedule == BetaSchedules.AUTOSELECT and not model.motion_models.is_empty(): - beta_schedule = model.motion_models[0].model.get_best_beta_schedule(log=True) + if beta_schedule == BetaSchedules.AUTOSELECT and helper.get_motion_models(): + beta_schedule = helper.get_motion_models()[0].model.get_best_beta_schedule(log=True) new_model_sampling = BetaSchedules.to_model_sampling(beta_schedule, model) if new_model_sampling is not None: 
model.add_object_patch("model_sampling", new_model_sampling) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 1a5dae8..154d2e9 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -528,6 +528,8 @@ def ad_callback(step, x0, x, total_steps): for curr_i in range(iter_opts.iterations): # handle GLOBALSTATE vars and step tally + # NOTE: only KSampler/KSampler (Advanced) would have steps; + # explore modifying ComfyUI to provide this when possible? ADGS.update_with_inject_params(params) ADGS.start_step = kwargs.get("start_step") or 0 ADGS.current_step = ADGS.start_step @@ -810,7 +812,7 @@ def evolved_sampling_function(model, x: Tensor, timestep: Tensor, uncond, cond, del cfg_multival if not ADGS.is_using_sliding_context(): - cond_pred, uncond_pred = calc_conds_batch_wrapper(model, [cond, uncond_], x, timestep, model_options) + cond_pred, uncond_pred = comfy.samplers.calc_cond_batch(model, [cond, uncond_], x, timestep, model_options) else: cond_pred, uncond_pred = sliding_calc_conds_batch(model, [cond, uncond_], x, timestep, model_options) @@ -887,7 +889,7 @@ def wrapped_cfg_sliding_calc_cond_batch(model, conds, x_in, timestep, model_opti # when inside sliding_calc_conds_batch, should return to original calc_cond_batch comfy.samplers.calc_cond_batch = orig_calc_cond_batch if not ADGS.is_using_sliding_context(): - return calc_conds_batch_wrapper(model, conds, x_in, timestep, model_options) + return comfy.samplers.calc_cond_batch(model, conds, x_in, timestep, model_options) else: return sliding_calc_conds_batch(model, conds, x_in, timestep, model_options) finally: From f003299708c27e66c16a0d49ca916c6d7619020f Mon Sep 17 00:00:00 2001 From: Kosinkadink Date: Tue, 24 Sep 2024 14:05:46 +0900 Subject: [PATCH 11/43] Adapted sliding_calc_cond_batch and forward_timestep_embed to use new ModelPatcher wrappers/patches, removed no-longer used code, removed model_name from InjectionParams --- animatediff/model_injection.py | 76 +++- animatediff/nodes.py | 6 - animatediff/nodes_context.py | 1 + animatediff/nodes_deprecated.py | 4 - animatediff/nodes_gen1.py | 3 +- animatediff/nodes_gen2.py | 92 +---- animatediff/sampling.py | 604 +------------------------------- 7 files changed, 86 insertions(+), 700 deletions(-) diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index 70af9c2..44eff90 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -21,7 +21,7 @@ from .adapter_cameractrl import CameraPoseEncoder, CameraEntry, prepare_pose_embedding from .context import ContextOptions, ContextOptions, ContextOptionsGroup from .motion_module_ad import (AnimateDiffModel, AnimateDiffFormat, AnimateDiffInfo, EncoderOnlyAnimateDiffModel, VersatileAttention, PerBlock, AllPerBlocks, - has_mid_block, normalize_ad_state_dict, get_position_encoding_max_len) + VanillaTemporalModule, has_mid_block, normalize_ad_state_dict, get_position_encoding_max_len) from .logger import logger from .utils_motion import (ADKeyframe, ADKeyframeGroup, MotionCompatibilityError, InputPIA, get_combined_multival, get_combined_input, get_combined_input_effect_multival, @@ -33,7 +33,6 @@ class ModelPatcherHelper: - MOTION_MODELS = "ADE_motion_models" SAMPLE_SETTINGS = "ADE_sample_settings" PARAMS = "ADE_params" ADE = "ADE" @@ -41,25 +40,58 @@ class ModelPatcherHelper: def __init__(self, model: ModelPatcher): self.model = model + def set_all_properties(self, outer_sampler_wrapper: Callable, calc_cond_batch_wrapper: Callable, + params: 'InjectionParams', 
sample_settings: SampleSettings=None, motion_models: 'MotionModelGroup'=None): + self.set_outer_sample_wrapper(outer_sampler_wrapper) + self.set_calc_cond_batch_wrapper(calc_cond_batch_wrapper) + self.set_sample_settings(sample_settings = sample_settings if sample_settings is not None else SampleSettings()) + self.set_params(params) + if motion_models is not None: + self.set_motion_models(motion_models.models.copy()) + self.set_forward_timestep_embed_patch() + else: + self.remove_motion_models() + self.remove_forward_timestep_embed_patch() + def get_adgs(self): pass def get_motion_models(self) -> list['MotionModelPatcher']: - return self.model.additional_models.get(self.MOTION_MODELS, []) - + return self.model.additional_models.get(self.ADE, []) + def set_motion_models(self, motion_models: list['MotionModelPatcher']): - self.model.set_additional_models(self.MOTION_MODELS, motion_models) - self.model.set_injections(self.MOTION_MODELS, + self.model.set_additional_models(self.ADE, motion_models) + self.model.set_injections(self.ADE, [PatcherInjection(inject=inject_motion_models, eject=eject_motion_models)]) def remove_motion_models(self): - if self.MOTION_MODELS in self.model.additional_models: - self.model.additional_models.pop(self.MOTION_MODELS) - self.model.injections.pop(self.MOTION_MODELS) + self.model.remove_additional_models(self.ADE) + self.model.remove_injections(self.ADE) + def cleanup_motion_models(self): for motion_model in self.get_motion_models(): motion_model.cleanup() + + def set_forward_timestep_embed_patch(self): + self.remove_forward_timestep_embed_patch() + self.model.set_model_forward_timestep_embed_patch(create_forward_timestep_embed_patch()) + + def remove_forward_timestep_embed_patch(self): + if "transformer_options" in self.model.model_options: + transformer_options = self.model.model_options["transformer_options"] + if "patches" in transformer_options: + patches = transformer_options["patches"] + if "forward_timestep_embed_patch" in patches: + forward_timestep_patches: list = patches["forward_timestep_embed_patch"] + to_remove = [] + for idx, patch in enumerate(forward_timestep_patches): + if patch[1] == forward_timestep_embed_patch_ade: + to_remove.append(idx) + for idx in to_remove: + forward_timestep_patches.pop(idx) + + ########################## # motion models helpers def set_video_length(self, video_length: int, full_length: int): @@ -90,10 +122,19 @@ def get_params(self) -> 'InjectionParams': def set_params(self, params: 'InjectionParams'): self.model.set_attachments(self.PARAMS, params) - def set_outer_sample_wrapper(self, wrapper): - self.model.add_wrapper_with_key(WrappersMP.OUTER_SAMPLE, self.ADE, wrapper) + def set_outer_sample_wrapper(self, wrapper: Callable): + self.model.remove_wrappers_with_key(WrappersMP.OUTER_SAMPLE, self.ADE) + self.model.add_wrapper_with_key(WrappersMP.OUTER_SAMPLE, self.ADE, wrapper) + def set_calc_cond_batch_wrapper(self, wrapper: Callable): + self.model.remove_wrappers_with_key(WrappersMP.CALC_COND_BATCH, self.ADE) + self.model.add_wrapper_with_key(WrappersMP.CALC_COND_BATCH, self.ADE, wrapper) + + def remove_wrappers(self): + self.model.remove_wrappers_with_key(WrappersMP.OUTER_SAMPLE, self.ADE) + self.model.remove_wrappers_with_key(WrappersMP.CALC_COND_BATCH, self.ADE) + def pre_run(self): # TODO: could implement this as a ModelPatcher ON_PRE_RUN callback for motion_model in self.get_motion_models(): @@ -115,6 +156,13 @@ def eject_motion_models(patcher: ModelPatcher): mm.model.eject(patcher) +def 
create_forward_timestep_embed_patch(): + return (VanillaTemporalModule, forward_timestep_embed_patch_ade) + +def forward_timestep_embed_patch_ade(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator, *args, **kwargs): + return layer(x, context) + + class MotionModelPatcher(ModelPatcher): # Mostly here so that type hints work in IDEs def __init__(self, *args, **kwargs): @@ -977,12 +1025,11 @@ def apply_mm_settings(model_dict: dict[str, Tensor], mm_settings: AnimateDiffSet class InjectionParams: - def __init__(self, unlimited_area_hack: bool=False, apply_mm_groupnorm_hack: bool=True, model_name: str="", + def __init__(self, unlimited_area_hack: bool=False, apply_mm_groupnorm_hack: bool=True, apply_v2_properly: bool=True) -> None: self.full_length = None self.unlimited_area_hack = unlimited_area_hack self.apply_mm_groupnorm_hack = apply_mm_groupnorm_hack - self.model_name = model_name self.apply_v2_properly = apply_v2_properly self.context_options: ContextOptionsGroup = ContextOptionsGroup.default() self.motion_model_settings = AnimateDiffSettings() # Gen1 @@ -1008,8 +1055,7 @@ def reset_context(self): def clone(self) -> 'InjectionParams': new_params = InjectionParams( - self.unlimited_area_hack, self.apply_mm_groupnorm_hack, - self.model_name, apply_v2_properly=self.apply_v2_properly, + self.unlimited_area_hack, self.apply_mm_groupnorm_hack, apply_v2_properly=self.apply_v2_properly, ) new_params.full_length = self.full_length new_params.set_context(self.context_options) diff --git a/animatediff/nodes.py b/animatediff/nodes.py index 40fa7a0..17cf97c 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -1,7 +1,5 @@ import comfy.sample as comfy_sample -from .sampling import motion_sample_factory - from .nodes_gen1 import (AnimateDiffLoaderGen1, LegacyAnimateDiffLoaderWithContext) from .nodes_gen2 import (UseEvolvedSamplingNode, ApplyAnimateDiffModelNode, ApplyAnimateDiffModelBasicNode, ADKeyframeNode, LoadAnimateDiffModelNode) @@ -50,10 +48,6 @@ from .logger import logger -# override comfy_sample.sample with animatediff-support version -#comfy_sample.sample = motion_sample_factory(comfy_sample.sample) -#comfy_sample.sample_custom = motion_sample_factory(comfy_sample.sample_custom, is_custom=True) - NODE_CLASS_MAPPINGS = { # Unencapsulated diff --git a/animatediff/nodes_context.py b/animatediff/nodes_context.py index c14a8b8..f7bc6d8 100644 --- a/animatediff/nodes_context.py +++ b/animatediff/nodes_context.py @@ -236,6 +236,7 @@ def INPUT_TYPES(s): "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), "prev_context": ("CONTEXT_OPTIONS",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } diff --git a/animatediff/nodes_deprecated.py b/animatediff/nodes_deprecated.py index baa1ddf..c14eea7 100644 --- a/animatediff/nodes_deprecated.py +++ b/animatediff/nodes_deprecated.py @@ -50,8 +50,6 @@ def load_mm_and_inject_params( # set injection params params = InjectionParams( unlimited_area_hack=unlimited_area_hack, - apply_mm_groupnorm_hack=True, - model_name=model_name, apply_v2_properly=False, ) # inject for use in sampling code @@ -108,8 +106,6 @@ def load_mm_and_inject_params(self, # set injection params params = InjectionParams( unlimited_area_hack=unlimited_area_hack, - apply_mm_groupnorm_hack=True, - model_name=model_name, apply_v2_properly=False, ) context_group = ContextOptionsGroup() diff --git a/animatediff/nodes_gen1.py 
b/animatediff/nodes_gen1.py index f9916eb..7d54386 100644 --- a/animatediff/nodes_gen1.py +++ b/animatediff/nodes_gen1.py @@ -67,7 +67,7 @@ def load_mm_and_inject_params(self, motion_model.keyframes = ad_keyframes.clone() if ad_keyframes else ADKeyframeGroup() # create injection params - params = InjectionParams(unlimited_area_hack=False, model_name=motion_model.model.mm_info.mm_name) + params = InjectionParams(unlimited_area_hack=False) # apply context options if context_options: params.set_context(context_options) @@ -149,7 +149,6 @@ def load_mm_and_inject_params(self, # set injection params params = InjectionParams( unlimited_area_hack=False, - model_name=model_name, apply_v2_properly=apply_v2_models_properly, ) if context_options: diff --git a/animatediff/nodes_gen2.py b/animatediff/nodes_gen2.py index 8970c54..b97865e 100644 --- a/animatediff/nodes_gen2.py +++ b/animatediff/nodes_gen2.py @@ -14,7 +14,7 @@ InjectionParams, MotionModelGroup, MotionModelPatcher, create_fresh_motion_module, load_motion_module_gen2, load_motion_lora_as_patches, validate_model_compatibility_gen2, validate_per_block_compatibility) from .sample_settings import SampleSettings -from .sampling import outer_sample_wrapper +from .sampling import outer_sample_wrapper, sliding_calc_cond_batch class UseEvolvedSamplingNode: @@ -40,30 +40,27 @@ def use_evolved_sampling(self, model: ModelPatcher, beta_schedule: str, m_models sample_settings: SampleSettings=None): model = model.clone() helper = ModelPatcherHelper(model) + params = InjectionParams() if m_models is not None: m_models = m_models.clone() # for each motion model, confirm that it is compatible with SD model for motion_model in m_models.models: validate_model_compatibility_gen2(model=model, motion_model=motion_model) - # create injection params - model_name_list = [motion_model.model.mm_info.mm_name for motion_model in m_models.models] - model_names = ",".join(model_name_list) # TODO: check if any apply_v2_properly is set to False - params = InjectionParams(unlimited_area_hack=False, model_name=model_names) - helper.set_motion_models(m_models.models.copy()) - else: - params = InjectionParams() - helper.remove_motion_models() # apply context options if context_options: params.set_context(context_options) - sample_settings = sample_settings if sample_settings is not None else SampleSettings() - # attach sample settings and params to model - helper.set_sample_settings(sample_settings) - helper.set_params(params) - helper.set_outer_sample_wrapper(outer_sample_wrapper) - + # attach all properties to model to enable AnimateDiff functionality + helper.set_all_properties( + outer_sampler_wrapper=outer_sample_wrapper, + calc_cond_batch_wrapper=sliding_calc_cond_batch, + params=params, + sample_settings=sample_settings, + motion_models=m_models, + ) + + sample_settings = helper.get_sample_settings() if sample_settings.custom_cfg is not None: logger.info("[Sample Settings] custom_cfg is set; will override any KSampler cfg values or patches.") @@ -87,71 +84,6 @@ def use_evolved_sampling(self, model: ModelPatcher, beta_schedule: str, m_models return (model,) -class UseEvolvedSamplingNodeOld: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL",), - "beta_schedule": (BetaSchedules.ALIAS_LIST, {"default": BetaSchedules.AUTOSELECT}), - }, - "optional": { - "m_models": ("M_MODELS",), - "context_options": ("CONTEXT_OPTIONS",), - "sample_settings": ("SAMPLE_SETTINGS",), - #"beta_schedule_override": ("BETA_SCHEDULE",), - } - } - - RETURN_TYPES = 
("MODEL",) - CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘‘ Gen2 nodes β‘‘" - FUNCTION = "use_evolved_sampling" - - def use_evolved_sampling(self, model: ModelPatcher, beta_schedule: str, m_models: MotionModelGroup=None, context_options: ContextOptionsGroup=None, - sample_settings: SampleSettings=None, beta_schedule_override=None): - if m_models is not None: - m_models = m_models.clone() - # for each motion model, confirm that it is compatible with SD model - for motion_model in m_models.models: - validate_model_compatibility_gen2(model=model, motion_model=motion_model) - # create injection params - model_name_list = [motion_model.model.mm_info.mm_name for motion_model in m_models.models] - model_names = ",".join(model_name_list) - # TODO: check if any apply_v2_properly is set to False - params = InjectionParams(unlimited_area_hack=False, model_name=model_names) - else: - params = InjectionParams() - # apply context options - if context_options: - params.set_context(context_options) - # need to use a ModelPatcher that supports injection of motion modules into unet - model = ModelPatcherAndInjector.create_from(model, hooks_only=True) - model.motion_models = m_models - model.sample_settings = sample_settings if sample_settings is not None else SampleSettings() - model.motion_injection_params = params - - if model.sample_settings.custom_cfg is not None: - logger.info("[Sample Settings] custom_cfg is set; will override any KSampler cfg values or patches.") - - if model.sample_settings.sigma_schedule is not None: - logger.info("[Sample Settings] sigma_schedule is set; will override beta_schedule.") - model.add_object_patch("model_sampling", model.sample_settings.sigma_schedule.clone().model_sampling) - else: - # save model_sampling from BetaSchedule as object patch - # if autoselect, get suggested beta_schedule from motion model - if beta_schedule == BetaSchedules.AUTOSELECT: - if model.motion_models is None or model.motion_models.is_empty(): - beta_schedule = BetaSchedules.USE_EXISTING - else: - beta_schedule = model.motion_models[0].model.get_best_beta_schedule(log=True) - new_model_sampling = BetaSchedules.to_model_sampling(beta_schedule, model) - if new_model_sampling is not None: - model.add_object_patch("model_sampling", new_model_sampling) - - del m_models - return (model,) - - class ApplyAnimateDiffModelNode: @classmethod def INPUT_TYPES(s): diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 154d2e9..4905a61 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -20,7 +20,6 @@ import comfy.conds import comfy.ops -from .conditioning import COND_CONST, LoraHookGroup, conditioning_set_values from .context import ContextFuseMethod, ContextSchedules, get_context_weights, get_context_windows from .context_extras import ContextRefMode from .sample_settings import IterationOptions, SampleSettings, SeedNoiseGeneration, NoisedImageToInject @@ -63,10 +62,6 @@ def prepare_current_keyframes(self, x: Tensor, timestep: Tensor): if self.sample_settings.custom_cfg is not None: self.sample_settings.custom_cfg.prepare_current_keyframe(t=timestep) - def prepare_hooks_current_keyframes(self, timestep: Tensor, hook_groups: list[LoraHookGroup]): - if self.model_patcher is not None: - self.model_patcher.prepare_hooked_patches_current_keyframe(t=timestep, hook_groups=hook_groups) - def perform_special_model_features(self, model: BaseModel, conds: list, x_in: Tensor, model_options: dict[str]): if self.motion_models is not None: special_models = self.motion_models.get_special_models() 
@@ -151,36 +146,6 @@ def create_exposed_params(self): ################################################################################## #### Code Injection ################################################## -# refer to forward_timestep_embed in comfy/ldm/modules/diffusionmodules/openaimodel.py -def forward_timestep_embed_factory() -> Callable: - def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None, time_context=None, num_video_frames=None, image_only_indicator=None): - for layer in ts: - if isinstance(layer, openaimodel.VideoResBlock): - x = layer(x, emb, num_video_frames, image_only_indicator) - elif isinstance(layer, openaimodel.TimestepBlock): - x = layer(x, emb) - elif isinstance(layer, VanillaTemporalModule): - x = layer(x, context) - elif isinstance(layer, attention.SpatialVideoTransformer): - x = layer(x, context, time_context, num_video_frames, image_only_indicator, transformer_options) - if "transformer_index" in transformer_options: - transformer_options["transformer_index"] += 1 - if "current_index" in transformer_options: # keep this for backward compat, for now - transformer_options["current_index"] += 1 - elif isinstance(layer, attention.SpatialTransformer): - x = layer(x, context, transformer_options) - if "transformer_index" in transformer_options: - transformer_options["transformer_index"] += 1 - if "current_index" in transformer_options: # keep this for backward compat, for now - transformer_options["current_index"] += 1 - elif isinstance(layer, openaimodel.Upsample): - x = layer(x, output_shape=output_shape) - else: - x = layer(x) - return x - return forward_timestep_embed - - def unlimited_memory_required(*args, **kwargs): return 0 @@ -204,17 +169,6 @@ def groupnorm_mm_forward(self, input: Tensor) -> Tensor: return input return groupnorm_mm_forward - -def get_additional_models_factory(orig_get_additional_models: Callable, motion_models: MotionModelGroup): - def get_additional_models_with_motion(*args, **kwargs): - models, inference_memory = orig_get_additional_models(*args, **kwargs) - if motion_models is not None: - for motion_model in motion_models.models: - models.append(motion_model) - # TODO: account for inference memory as well? 
- return models, inference_memory - return get_additional_models_with_motion - def apply_model_factory(orig_apply_model: Callable): def apply_model_ade_wrapper(self, *args, **kwargs): x: Tensor = args[0] @@ -326,17 +280,13 @@ def __init__(self): def inject_functions(self, helper: ModelPatcherHelper, params: InjectionParams): # Save Original Functions - order must match between here and restore_functions - self.orig_forward_timestep_embed = openaimodel.forward_timestep_embed # needed to account for VanillaTemporalModule self.orig_memory_required = helper.model.model.memory_required # allows for "unlimited area hack" to prevent halving of conds/unconds self.orig_groupnorm_forward = torch.nn.GroupNorm.forward # used to normalize latents to remove "flickering" of colors/brightness between frames self.orig_groupnorm_forward_comfy_cast_weights = comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights self.orig_diffusion_model_forward = helper.model.model.diffusion_model.forward self.orig_sampling_function = comfy.samplers.sampling_function # used to support sliding context windows in samplers - #self.orig_get_area_and_mult = comfy.samplers.get_area_and_mult - #self.orig_get_additional_models = comfy.sampler_helpers.get_additional_models self.orig_apply_model = helper.model.model.apply_model # Inject Functions - openaimodel.forward_timestep_embed = forward_timestep_embed_factory() if params.unlimited_area_hack: helper.model.model.memory_required = unlimited_memory_required if helper.get_motion_models(): @@ -362,8 +312,6 @@ def inject_functions(self, helper: ModelPatcherHelper, params: InjectionParams): break del info comfy.samplers.sampling_function = evolved_sampling_function - #comfy.samplers.get_area_and_mult = get_area_and_mult_ADE - #comfy.sampler_helpers.get_additional_models = get_additional_models_factory(self.orig_get_additional_models, model.motion_models) # create temp_uninjector to help facilitate uninjecting functions self.temp_uninjector = GroupnormUninjectHelper(self) @@ -371,13 +319,10 @@ def restore_functions(self, helper: ModelPatcherHelper): # Restoration try: helper.model.model.memory_required = self.orig_memory_required - openaimodel.forward_timestep_embed = self.orig_forward_timestep_embed torch.nn.GroupNorm.forward = self.orig_groupnorm_forward comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.orig_groupnorm_forward_comfy_cast_weights helper.model.model.diffusion_model.forward = self.orig_diffusion_model_forward comfy.samplers.sampling_function = self.orig_sampling_function - #comfy.samplers.get_area_and_mult = self.orig_get_area_and_mult - #comfy.sampler_helpers.get_additional_models = self.orig_get_additional_models helper.model.model.apply_model = self.orig_apply_model except AttributeError: logger.error("Encountered AttributeError while attempting to restore functions - likely, an error occured while trying " + \ @@ -592,201 +537,6 @@ def ad_callback(step, x0, x, total_steps): del helper -def motion_sample_factory(orig_comfy_sample: Callable, is_custom: bool=False) -> Callable: - def motion_sample(model: ModelPatcher, noise: Tensor, *args, **kwargs): - # check if model is intended for injecting - if type(model) != ModelPatcher: - return orig_comfy_sample(model, noise, *args, **kwargs) - # otherwise, injection time - latents = None - cached_latents = None - cached_noise = None - function_injections = FunctionInjectionHolder() - try: - # clone params from model - params = model.motion_injection_params.clone() - # get amount of latents 
passed in, and store in params - latents: Tensor = args[-1] - params.full_length = latents.size(0) - # reset global state - ADGS.reset() - - # apply custom noise, if needed - disable_noise = kwargs.get("disable_noise") or False - seed = kwargs["seed"] - - # apply params to motion model - params = apply_params_to_motion_models_old(model.motion_models, params) - - # store and inject functions - function_injections.inject_functions(model, params) - - # prepare noise_extra_args for noise generation purposes - noise_extra_args = {"disable_noise": disable_noise} - params.set_noise_extra_args(noise_extra_args) - # if noise is not disabled, do noise stuff - if not disable_noise: - noise = model.sample_settings.prepare_noise(seed, latents, noise, extra_args=noise_extra_args, force_create_noise=False) - - # callback setup - original_callback = kwargs.get("callback", None) - def ad_callback(step, x0, x, total_steps): - if original_callback is not None: - original_callback(step, x0, x, total_steps) - # store denoised latents if image_injection will be used - if not model.sample_settings.image_injection.is_empty(): - ADGS.callback_output_dict["x0"] = x0 - # update GLOBALSTATE for next iteration - ADGS.current_step = ADGS.start_step + step + 1 - kwargs["callback"] = ad_callback - ADGS.model_patcher = model - ADGS.motion_models = model.motion_models - ADGS.sample_settings = model.sample_settings - ADGS.function_injections = function_injections - - # apply adapt_denoise_steps - args = list(args) - if model.sample_settings.adapt_denoise_steps and not is_custom: - # only applicable when denoise and steps are provided (from simple KSampler nodes) - denoise = kwargs.get("denoise", None) - steps = args[0] - if denoise is not None and type(steps) == int: - args[0] = max(int(denoise * steps), 1) - - - iter_opts = IterationOptions() - if model.sample_settings is not None: - iter_opts = model.sample_settings.iteration_opts - iter_opts.initialize(latents) - # cache initial noise and latents, if needed - if iter_opts.cache_init_latents: - cached_latents = latents.clone() - if iter_opts.cache_init_noise: - cached_noise = noise.clone() - # prepare iter opts preprocess kwargs, if needed - iter_kwargs = {} - if iter_opts.need_sampler: - # -5 for sampler_name (not custom) and sampler (custom) - if is_custom: - iter_kwargs[IterationOptions.SAMPLER] = None #args[-5] - else: - iter_model = model - current_device = model.model.device - iter_kwargs[IterationOptions.SAMPLER] = comfy.samplers.KSampler( - iter_model, steps=999, #steps=args[-7], - device=current_device, sampler=args[-5], - scheduler=args[-4], denoise=kwargs.get("denoise", None), - model_options=model.model_options) - del iter_model - - for curr_i in range(iter_opts.iterations): - # handle GLOBALSTATE vars and step tally - ADGS.update_with_inject_params(params) - ADGS.start_step = kwargs.get("start_step") or 0 - ADGS.current_step = ADGS.start_step - ADGS.last_step = kwargs.get("last_step") or 0 - ADGS.hooks_initialized = False - if iter_opts.iterations > 1: - logger.info(f"Iteration {curr_i+1}/{iter_opts.iterations}") - # perform any iter_opts preprocessing on latents - latents, noise = iter_opts.preprocess_latents(curr_i=curr_i, model=model, latents=latents, noise=noise, - cached_latents=cached_latents, cached_noise=cached_noise, - seed=seed, - sample_settings=model.sample_settings, noise_extra_args=noise_extra_args, - **iter_kwargs) - if model.sample_settings.noise_calibration is not None: - latents, noise = 
model.sample_settings.noise_calibration.perform_calibration(sample_func=orig_comfy_sample, model=model, latents=latents, noise=noise, - is_custom=is_custom, args=args, kwargs=kwargs) - args[-1] = latents - - if model.motion_models is not None: - model.motion_models.pre_run(model) - if model.sample_settings is not None: - model.sample_settings.pre_run(model) - - if ADGS.sample_settings.image_injection.is_empty(): - latents = orig_comfy_sample(model, noise, *args, **kwargs) - else: - ADGS.sample_settings.image_injection.initialize_timesteps(model.model) - # separate handling for KSampler vs Custom KSampler - if is_custom: - sigmas = args[2] - sigmas_list, injection_list = ADGS.sample_settings.image_injection.custom_ksampler_get_injections(model, sigmas) - # useful logging - if len(injection_list) > 0: - inj_str = "s" if len(injection_list) > 1 else "" - logger.info(f"Found {len(injection_list)} applicable image injection{inj_str}; sampling will be split into {len(sigmas_list)}.") - else: - logger.info(f"Found 0 applicable image injections within the step bounds of this sampler; sampling unaffected.") - is_first = True - new_noise = noise - for i in range(len(sigmas_list)): - args[2] = sigmas_list[i] - args[-1] = latents - latents = orig_comfy_sample(model, new_noise, *args, **kwargs) - if is_first: - new_noise = torch.zeros_like(latents) - # if injection expected, perform injection - if i < len(injection_list): - to_inject = injection_list[i] - latents = perform_image_injection(model.model, latents, to_inject) - else: - is_ksampler_advanced = kwargs.get("start_step", None) is not None - # force_full_denoise should be respected on final sampling - should be True for normal KSampler - final_force_full_denoise = kwargs.get("force_full_denoise", False) - new_kwargs = kwargs.copy() - if not is_ksampler_advanced: - final_force_full_denoise = True - new_kwargs["start_step"] = 0 - new_kwargs["last_step"] = 10000 - steps_list, injection_list = ADGS.sample_settings.image_injection.ksampler_get_injections(model, scheduler=args[-4], sampler_name=args[-5], denoise=kwargs["denoise"], force_full_denoise=final_force_full_denoise, - start_step=new_kwargs["start_step"], last_step=new_kwargs["last_step"], total_steps=args[0]) - # useful logging - if len(injection_list) > 0: - inj_str = "s" if len(injection_list) > 1 else "" - logger.info(f"Found {len(injection_list)} applicable image injection{inj_str}; sampling will be split into {len(steps_list)}.") - else: - logger.info(f"Found 0 applicable image injections within the step bounds of this sampler; sampling unaffected.") - is_first = True - new_noise = noise - for i in range(len(steps_list)): - steps_range = steps_list[i] - args[-1] = latents - # first run will respect original disable_noise, but should have no effect on anything - # as disable_noise only does something in the functions that call this one - if not is_first: - new_kwargs["disable_noise"] = True - new_kwargs["start_step"] = steps_range[0] - new_kwargs["last_step"] = steps_range[1] - # if is last, respect original sampler's force_full_denoise - if i == len(steps_list)-1: - new_kwargs["force_full_denoise"] = final_force_full_denoise - else: - new_kwargs["force_full_denoise"] = False - latents = orig_comfy_sample(model, new_noise, *args, **new_kwargs) - if is_first: - new_noise = torch.zeros_like(latents) - # if injection expected, perform injection - if i < len(injection_list): - to_inject = injection_list[i] - latents = perform_image_injection(model.model, latents, to_inject) - return latents - 
finally: - del latents - del noise - del cached_latents - del cached_noise - # reset global state - ADGS.reset() - # clean motion_models - if model.motion_models is not None: - model.motion_models.cleanup() - # restore injected functions - function_injections.restore_functions(model) - del function_injections - return motion_sample - - def evolved_sampling_function(model, x: Tensor, timestep: Tensor, uncond, cond, cond_scale, model_options: dict={}, seed=None): ADGS.initialize(model) ADGS.prepare_current_keyframes(x=x, timestep=timestep) @@ -811,21 +561,13 @@ def evolved_sampling_function(model, x: Tensor, timestep: Tensor, uncond, cond, uncond_ = None del cfg_multival - if not ADGS.is_using_sliding_context(): - cond_pred, uncond_pred = comfy.samplers.calc_cond_batch(model, [cond, uncond_], x, timestep, model_options) - else: - cond_pred, uncond_pred = sliding_calc_conds_batch(model, [cond, uncond_], x, timestep, model_options) + cond_pred, uncond_pred = comfy.samplers.calc_cond_batch(model, [cond, uncond_], x, timestep, model_options) if ADGS.sample_settings.custom_cfg is not None: cond_scale = ADGS.sample_settings.custom_cfg.get_cfg_scale(cond_pred) model_options = ADGS.sample_settings.custom_cfg.get_model_options(model_options) - try: - cached_calc_cond_batch = comfy.samplers.calc_cond_batch - # support hooks and sliding context for PAG/other sampler_post_cfg_function tech that may use calc_cond_batch - comfy.samplers.calc_cond_batch = wrapped_cfg_sliding_calc_cond_batch_factory(cached_calc_cond_batch) - return comfy.samplers.cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options, cond, uncond) - finally: - comfy.samplers.calc_cond_batch = cached_calc_cond_batch + + return comfy.samplers.cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_options, cond, uncond) finally: ADGS.restore_special_model_features(model) @@ -891,16 +633,19 @@ def wrapped_cfg_sliding_calc_cond_batch(model, conds, x_in, timestep, model_opti if not ADGS.is_using_sliding_context(): return comfy.samplers.calc_cond_batch(model, conds, x_in, timestep, model_options) else: - return sliding_calc_conds_batch(model, conds, x_in, timestep, model_options) + return sliding_calc_cond_batch(model, conds, x_in, timestep, model_options) finally: # make sure calc_cond_batch will become wrapped again comfy.samplers.calc_cond_batch = current_calc_cond_batch return wrapped_cfg_sliding_calc_cond_batch -# sliding_calc_conds_batch inspired by ashen's initial hack for 16-frame sliding context: +# initial sliding_calc_conds_batch inspired by ashen's initial hack for 16-frame sliding context: # https://github.com/comfyanonymous/ComfyUI/compare/master...ashen-sensored:ComfyUI:master -def sliding_calc_conds_batch(model, conds, x_in: Tensor, timestep, model_options): +def sliding_calc_cond_batch(executor: Callable, model, conds: list[list[dict]], x_in: Tensor, timestep, model_options): + if not ADGS.is_using_sliding_context(): + return executor(model, conds, x_in, timestep, model_options) + def prepare_control_objects(control: ControlBase, full_idxs: list[int]): if control.previous_controlnet is not None: prepare_control_objects(control.previous_controlnet, full_idxs) @@ -1066,7 +811,7 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list model_options["transformer_options"][CONTEXTREF_MACHINE_STATE] = MachineState.OFF #logger.info(f"window: {curr_window_idx} - {model_options['transformer_options'][CONTEXTREF_MACHINE_STATE]}") - sub_conds_out = 
calc_conds_batch_wrapper(model, sub_conds, sub_x, sub_timestep, model_options) + sub_conds_out = executor(model, sub_conds, sub_x, sub_timestep, model_options) if ADGS.params.context_options.fuse_method == ContextFuseMethod.RELATIVE: full_length = ADGS.params.full_length @@ -1149,331 +894,4 @@ def get_conds_with_c_concat(conds: list[dict], c_concat: comfy.conds.CONDNoiseSh resized_actual_cond[key] = new_model_conds resized_cond.append(resized_actual_cond) new_conds.append(resized_cond) - return new_conds - - -def calc_conds_batch_wrapper(model, conds: list[dict], x_in: Tensor, timestep, model_options): - # check if conds or unconds contain lora_hook or default_cond - contains_lora_hooks = False - has_default_cond = False - hook_groups = [] - for cond_uncond in conds: - if cond_uncond is None: - continue - for t in cond_uncond: - if COND_CONST.KEY_LORA_HOOK in t: - contains_lora_hooks = True - hook_groups.append(t[COND_CONST.KEY_LORA_HOOK]) - if COND_CONST.KEY_DEFAULT_COND in t: - has_default_cond = True - # if contains_lora_hooks: - # break - if contains_lora_hooks or has_default_cond: - ADGS.hooks_initialize(model, hook_groups=hook_groups) - ADGS.prepare_hooks_current_keyframes(timestep, hook_groups=hook_groups) - return calc_conds_batch_lora_hook(model, conds, x_in, timestep, model_options, has_default_cond) - return comfy.samplers.calc_cond_batch(model, conds, x_in, timestep, model_options) - - -# modified from comfy.samplers.get_area_and_mult -COND_OBJ = collections.namedtuple('cond_obj', ['input_x', 'mult', 'conditioning', 'area', 'control', 'patches']) -def get_area_and_mult_ADE(conds, x_in, timestep_in): - area = (x_in.shape[2], x_in.shape[3], 0, 0) - strength = 1.0 - - if 'timestep_start' in conds: - timestep_start = conds['timestep_start'] - if timestep_in[0] > timestep_start: - return None - if 'timestep_end' in conds: - timestep_end = conds['timestep_end'] - if timestep_in[0] < timestep_end: - return None - if 'area' in conds: - area = conds['area'] - if 'strength' in conds: - strength = conds['strength'] - - input_x = x_in[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] - if 'mask' in conds: - # Scale the mask to the size of the input - # The mask should have been resized as we began the sampling process - mask_strength = 1.0 - if "mask_strength" in conds: - mask_strength = conds["mask_strength"] - mask = conds['mask'] - assert(mask.shape[1] == x_in.shape[2]) - assert(mask.shape[2] == x_in.shape[3]) - # make sure mask is capped at input_shape batch length to prevent 0 as dimension - mask = mask[:input_x.shape[0], area[2]:area[0] + area[2], area[3]:area[1] + area[3]] * mask_strength - mask = mask.unsqueeze(1).repeat(input_x.shape[0] // mask.shape[0], input_x.shape[1], 1, 1) - else: - mask = torch.ones_like(input_x) - mult = mask * strength - - if 'mask' not in conds: - rr = 8 - if area[2] != 0: - for t in range(rr): - mult[:,:,t:1+t,:] *= ((1.0/rr) * (t + 1)) - if (area[0] + area[2]) < x_in.shape[2]: - for t in range(rr): - mult[:,:,area[0] - 1 - t:area[0] - t,:] *= ((1.0/rr) * (t + 1)) - if area[3] != 0: - for t in range(rr): - mult[:,:,:,t:1+t] *= ((1.0/rr) * (t + 1)) - if (area[1] + area[3]) < x_in.shape[3]: - for t in range(rr): - mult[:,:,:,area[1] - 1 - t:area[1] - t] *= ((1.0/rr) * (t + 1)) - - conditioning = {} - model_conds = conds["model_conds"] - for c in model_conds: - conditioning[c] = model_conds[c].process_cond(batch_size=x_in.shape[0], device=x_in.device, area=area) - - control = conds.get('control', None) - - patches = None - if 'gligen' in conds: - 
gligen = conds['gligen'] - patches = {} - gligen_type = gligen[0] - gligen_model = gligen[1] - if gligen_type == "position": - gligen_patch = gligen_model.model.set_position(input_x.shape, gligen[2], input_x.device) - elif gligen_type == "position_batched": - try: - gligen_model.model.set_position_batched_ADE = MethodType(gligen_batch_set_position_ADE, gligen_model.model) - gligen_patch = gligen_model.model.set_position_batched_ADE(input_x.shape, gligen[2], input_x.device) - finally: - delattr(gligen_model.model, "set_position_batched_ADE") - else: - gligen_patch = gligen_model.model.set_empty(input_x.shape, input_x.device) - - patches['middle_patch'] = [gligen_patch] - - return COND_OBJ(input_x, mult, conditioning, area, control, patches) - - -def separate_default_conds(conds: list[dict]): - normal_conds = [] - default_conds = [] - for i in range(len(conds)): - c = [] - default_c = [] - # if cond is None, make normal/default_conds reflect that too - if conds[i] is None: - c = None - default_c = [] - else: - for t in conds[i]: - # check if cond is a default cond - if COND_CONST.KEY_DEFAULT_COND in t: - default_c.append(t) - else: - c.append(t) - normal_conds.append(c) - default_conds.append(default_c) - return normal_conds, default_conds - - -def finalize_default_conds(hooked_to_run: dict[LoraHookGroup,list[tuple[COND_OBJ,int]]], default_conds: list[list[dict]], x_in: Tensor, timestep): - # need to figure out remaining unmasked area for conds - default_mults = [] - for d in default_conds: - default_mults.append(torch.ones_like(x_in)) - - # look through each finalized cond in hooked_to_run for 'mult' and subtract it from each cond - for lora_hooks, to_run in hooked_to_run.items(): - for cond_obj, i in to_run: - # if no default_cond for cond_type, do nothing - if len(default_conds[i]) == 0: - continue - area: list[int] = cond_obj.area - default_mults[i][:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] -= cond_obj.mult - - # for each default_mult, ReLU to make negatives=0, and then check for any nonzeros - for i, mult in enumerate(default_mults): - # if no default_cond for cond type, do nothing - if len(default_conds[i]) == 0: - continue - torch.nn.functional.relu(mult, inplace=True) - # if mult is all zeros, then don't add default_cond - if torch.max(mult) == 0.0: - continue - - cond = default_conds[i] - for x in cond: - # do get_area_and_mult to get all the expected values - p = comfy.samplers.get_area_and_mult(x, x_in, timestep) - if p is None: - continue - # replace p's mult with calculated mult - p = p._replace(mult=mult) - hook: LoraHookGroup = x.get(COND_CONST.KEY_LORA_HOOK, None) - hooked_to_run.setdefault(hook, list()) - hooked_to_run[hook] += [(p, i)] - - -# based on comfy.samplers.calc_conds_batch -def calc_conds_batch_lora_hook(model: BaseModel, conds: list[list[dict]], x_in: Tensor, timestep, model_options: dict, has_default_cond=False): - out_conds = [] - out_counts = [] - # separate conds by matching lora_hooks - hooked_to_run: dict[LoraHookGroup,list[tuple[collections.namedtuple,int]]] = {} - - # separate out default_conds, if needed - if has_default_cond: - conds, default_conds = separate_default_conds(conds) - - # cond is i=0, uncond is i=1 - for i in range(len(conds)): - out_conds.append(torch.zeros_like(x_in)) - out_counts.append(torch.ones_like(x_in) * 1e-37) - - cond = conds[i] - if cond is not None: - for x in cond: - p = comfy.samplers.get_area_and_mult(x, x_in, timestep) - if p is None: - continue - hook: LoraHookGroup = x.get(COND_CONST.KEY_LORA_HOOK, None) - 
hooked_to_run.setdefault(hook, list()) - hooked_to_run[hook] += [(p, i)] - - # finalize default_conds, if needed - if has_default_cond: - finalize_default_conds(hooked_to_run, default_conds, x_in, timestep) - - # run every hooked_to_run separately - for lora_hooks, to_run in hooked_to_run.items(): - while len(to_run) > 0: - first = to_run[0] - first_shape = first[0][0].shape - to_batch_temp = [] - for x in range(len(to_run)): - if comfy.samplers.can_concat_cond(to_run[x][0], first[0]): - to_batch_temp += [x] - - to_batch_temp.reverse() - to_batch = to_batch_temp[:1] - - free_memory = comfy.model_management.get_free_memory(x_in.device) - for i in range(1, len(to_batch_temp) + 1): - batch_amount = to_batch_temp[:len(to_batch_temp)//i] - input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] - if model.memory_required(input_shape) < free_memory: - to_batch = batch_amount - break - ADGS.model_patcher.apply_lora_hooks(lora_hooks=lora_hooks) - - input_x = [] - mult = [] - c = [] - cond_or_uncond = [] - area = [] - control = None - patches = None - for x in to_batch: - o = to_run.pop(x) - p = o[0] - input_x.append(p.input_x) - mult.append(p.mult) - c.append(p.conditioning) - area.append(p.area) - cond_or_uncond.append(o[1]) - control = p.control - patches = p.patches - - batch_chunks = len(cond_or_uncond) - input_x = torch.cat(input_x) - c = comfy.samplers.cond_cat(c) - timestep_ = torch.cat([timestep] * batch_chunks) - - if control is not None: - c['control'] = control.get_control(input_x, timestep_, c, len(cond_or_uncond)) - - transformer_options = {} - if 'transformer_options' in model_options: - transformer_options = model_options['transformer_options'].copy() - - if patches is not None: - if "patches" in transformer_options: - cur_patches = transformer_options["patches"].copy() - for p in patches: - if p in cur_patches: - cur_patches[p] = cur_patches[p] + patches[p] - else: - cur_patches[p] = patches[p] - transformer_options["patches"] = cur_patches - else: - transformer_options["patches"] = patches - - transformer_options["cond_or_uncond"] = cond_or_uncond[:] - transformer_options["sigmas"] = timestep - - c['transformer_options'] = transformer_options - - if 'model_function_wrapper' in model_options: - output = model_options['model_function_wrapper'](model.apply_model, {"input": input_x, "timestep": timestep_, "c": c, "cond_or_uncond": cond_or_uncond}).chunk(batch_chunks) - else: - output = model.apply_model(input_x, timestep_, **c).chunk(batch_chunks) - - for o in range(batch_chunks): - cond_index = cond_or_uncond[o] - out_conds[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += output[o] * mult[o] - out_counts[cond_index][:,:,area[o][2]:area[o][0] + area[o][2],area[o][3]:area[o][1] + area[o][3]] += mult[o] - - for i in range(len(out_conds)): - out_conds[i] /= out_counts[i] - - return out_conds - -def gligen_batch_set_position_ADE(self, latent_image_shape: torch.Size, position_params_batch: list[list[tuple[Tensor, int, int, int, int]]], device): - batch, c, h, w = latent_image_shape - - all_boxes = [] - all_masks = [] - all_conds = [] - - # make sure there are enough position_params to match expected amount - if len(position_params_batch) < ADGS.params.full_length: - position_params_batch = position_params_batch.copy() - for _ in range(ADGS.params.full_length-len(position_params_batch)): - position_params_batch.append(position_params_batch[-1]) - - for batch_idx in range(batch): - if ADGS.params.sub_idxs is not None: - 
position_params = position_params_batch[ADGS.params.sub_idxs[batch_idx]] - else: - position_params = position_params_batch[batch_idx] - masks = torch.zeros([self.max_objs], device="cpu") - boxes = [] - positive_embeddings = [] - - for p in position_params: - x1 = (p[4]) / w - y1 = (p[3]) / h - x2 = (p[4] + p[2]) / w - y2 = (p[3] + p[1]) / h - masks[len(boxes)] = 1.0 - boxes.append(torch.tensor((x1, y1, x2, y2)).unsqueeze(0)) - positive_embeddings.append(p[0]) - - if len(boxes) < self.max_objs: - append_boxes = torch.zeros([self.max_objs - len(boxes), 4], device="cpu") - append_conds = torch.zeros([self.max_objs - len(boxes), self.key_dim], device="cpu") - boxes = torch.cat(boxes + [append_boxes]) - conds = torch.cat(positive_embeddings + [append_conds]) - else: - boxes = torch.cat(boxes) - conds = torch.cat(positive_embeddings) - all_boxes.append(boxes) - all_masks.append(masks) - all_conds.append(conds) - - box_out = torch.stack(all_boxes).to(device) - masks_out = torch.stack(all_masks).to(device) - conds_out = torch.stack(all_conds).to(device) - - return self._set_position(box_out, masks_out, conds_out) + return new_conds \ No newline at end of file From 8c3c948a2caf94bc76aec4b7e6d87f0356668353 Mon Sep 17 00:00:00 2001 From: Kosinkadink Date: Tue, 24 Sep 2024 16:03:38 +0900 Subject: [PATCH 12/43] Updated gen1 nodes and deprecated nodes to use new ModelPatcher system properly --- animatediff/nodes_deprecated.py | 32 ++++++++++++++----- animatediff/nodes_gen1.py | 54 ++++++++++++++++++++------------- 2 files changed, 58 insertions(+), 28 deletions(-) diff --git a/animatediff/nodes_deprecated.py b/animatediff/nodes_deprecated.py index c14eea7..a26e138 100644 --- a/animatediff/nodes_deprecated.py +++ b/animatediff/nodes_deprecated.py @@ -16,7 +16,8 @@ from .context import ContextOptionsGroup, ContextOptions, ContextSchedules from .logger import logger from .utils_model import Folders, BetaSchedules, get_available_motion_models -from .model_injection import InjectionParams, MotionModelGroup, load_motion_module_gen1 +from .model_injection import ModelPatcherHelper, InjectionParams, MotionModelGroup, load_motion_module_gen1 +from .sampling import outer_sample_wrapper, sliding_calc_cond_batch class AnimateDiffLoader_Deprecated: @@ -36,6 +37,7 @@ def INPUT_TYPES(s): RETURN_TYPES = ("MODEL", "LATENT") CATEGORY = "" FUNCTION = "load_mm_and_inject_params" + DEPRECATED = True def load_mm_and_inject_params( self, @@ -53,9 +55,14 @@ def load_mm_and_inject_params( apply_v2_properly=False, ) # inject for use in sampling code - model = ModelPatcherAndInjector.create_from(model, hooks_only=True) - model.motion_models = MotionModelGroup(motion_model) - model.motion_injection_params = params + model = model.clone() + helper = ModelPatcherHelper(model) + helper.set_all_properties( + outer_sampler_wrapper=outer_sample_wrapper, + calc_cond_batch_wrapper=sliding_calc_cond_batch, + params=params, + motion_models=MotionModelGroup(motion_model), + ) # save model sampling from BetaSchedule as object patch # if autoselect, get suggested beta_schedule from motion model @@ -91,6 +98,7 @@ def INPUT_TYPES(s): RETURN_TYPES = ("MODEL", "LATENT") CATEGORY = "" FUNCTION = "load_mm_and_inject_params" + DEPRECATED = True def load_mm_and_inject_params(self, model: ModelPatcher, @@ -121,9 +129,14 @@ def load_mm_and_inject_params(self, # set context settings params.set_context(context_options=context_group) # inject for use in sampling code - model = ModelPatcherAndInjector.create_from(model, hooks_only=True) - 
model.motion_models = MotionModelGroup(motion_model) - model.motion_injection_params = params + model = model.clone() + helper = ModelPatcherHelper(model) + helper.set_all_properties( + outer_sampler_wrapper=outer_sample_wrapper, + calc_cond_batch_wrapper=sliding_calc_cond_batch, + params=params, + motion_models=MotionModelGroup(motion_model), + ) # save model sampling from BetaSchedule as object patch # if autoselect, get suggested beta_schedule from motion model @@ -175,6 +188,7 @@ def INPUT_TYPES(s): OUTPUT_NODE = True CATEGORY = "" FUNCTION = "generate_gif" + DEPRECATED = True def generate_gif( self, @@ -295,6 +309,7 @@ def INPUT_TYPES(s): RETURN_TYPES = ("AD_SETTINGS",) CATEGORY = "" #"Animate Diff πŸŽ­πŸ…πŸ…“/β‘  Gen1 nodes β‘ /motion settings" FUNCTION = "get_motion_model_settings" + DEPRECATED = True def get_motion_model_settings(self, mask_motion_scale: torch.Tensor=None, min_motion_scale: float=1.0, max_motion_scale: float=1.0): motion_model_settings = AnimateDiffSettings( @@ -324,6 +339,7 @@ def INPUT_TYPES(s): RETURN_TYPES = ("AD_SETTINGS",) CATEGORY = "" #"Animate Diff πŸŽ­πŸ…πŸ…“/β‘  Gen1 nodes β‘ /motion settings/experimental" FUNCTION = "get_motion_model_settings" + DEPRECATED = True def get_motion_model_settings(self, motion_pe_stretch: int, mask_motion_scale: torch.Tensor=None, min_motion_scale: float=1.0, max_motion_scale: float=1.0): @@ -363,6 +379,7 @@ def INPUT_TYPES(s): RETURN_TYPES = ("AD_SETTINGS",) CATEGORY = "" #"Animate Diff πŸŽ­πŸ…πŸ…“/β‘  Gen1 nodes β‘ /motion settings/experimental" FUNCTION = "get_motion_model_settings" + DEPRECATED = True def get_motion_model_settings(self, pe_strength: float, attn_strength: float, other_strength: float, motion_pe_stretch: int, @@ -418,6 +435,7 @@ def INPUT_TYPES(s): RETURN_TYPES = ("AD_SETTINGS",) CATEGORY = "" #"Animate Diff πŸŽ­πŸ…πŸ…“/β‘  Gen1 nodes β‘ /motion settings/experimental" FUNCTION = "get_motion_model_settings" + DEPRECATED = True def get_motion_model_settings(self, pe_strength: float, attn_strength: float, attn_q_strength: float, diff --git a/animatediff/nodes_gen1.py b/animatediff/nodes_gen1.py index 7d54386..e0ac0db 100644 --- a/animatediff/nodes_gen1.py +++ b/animatediff/nodes_gen1.py @@ -15,7 +15,7 @@ load_motion_lora_as_patches, load_motion_module_gen1, load_motion_module_gen2, validate_model_compatibility_gen2, validate_per_block_compatibility) from .sample_settings import SampleSettings, SeedNoiseGeneration -from .sampling import outer_sample_wrapper +from .sampling import outer_sample_wrapper, sliding_calc_cond_batch class AnimateDiffLoaderGen1: @@ -82,15 +82,18 @@ def load_mm_and_inject_params(self, if params.motion_model_settings.mask_attn_scale is not None: motion_model.scale_multival = get_combined_multival(scale_multival, (params.motion_model_settings.mask_attn_scale * params.motion_model_settings.attn_scale)) - sample_settings = sample_settings if sample_settings is not None else SampleSettings() # need to use a ModelPatcher that supports injection of motion modules into unet model = model.clone() helper = ModelPatcherHelper(model) - helper.set_motion_models([motion_model]) - helper.set_sample_settings(sample_settings) - helper.set_params(params) - helper.set_outer_sample_wrapper(outer_sample_wrapper) + helper.set_all_properties( + outer_sampler_wrapper=outer_sample_wrapper, + calc_cond_batch_wrapper=sliding_calc_cond_batch, + params=params, + sample_settings=sample_settings, + motion_models=MotionModelGroup(motion_model), + ) + sample_settings = helper.get_sample_settings() if 
sample_settings.custom_cfg is not None: logger.info("[Sample Settings] custom_cfg is set; will override any KSampler cfg values or patches.") @@ -135,7 +138,6 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/β‘  Gen1 nodes β‘ " FUNCTION = "load_mm_and_inject_params" - def load_mm_and_inject_params(self, model: ModelPatcher, model_name: str, beta_schedule: str,# apply_mm_groupnorm_hack: bool, @@ -166,22 +168,32 @@ def load_mm_and_inject_params(self, motion_model.keyframes = ad_keyframes.clone() if ad_keyframes else ADKeyframeGroup() - sample_settings = sample_settings if sample_settings is not None else SampleSettings() # need to use a ModelPatcher that supports injection of motion modules into unet model = model.clone() - helper = ModelPatcherHelper() - helper.set_motion_models([motion_model]) - helper.set_sample_settings(sample_settings) - helper.set_params(params) - helper.set_outer_sample_wrapper(outer_sample_wrapper) - - # save model sampling from BetaSchedule as object patch - # if autoselect, get suggested beta_schedule from motion model - if beta_schedule == BetaSchedules.AUTOSELECT and helper.get_motion_models(): - beta_schedule = helper.get_motion_models()[0].model.get_best_beta_schedule(log=True) - new_model_sampling = BetaSchedules.to_model_sampling(beta_schedule, model) - if new_model_sampling is not None: - model.add_object_patch("model_sampling", new_model_sampling) + helper = ModelPatcherHelper(model) + helper.set_all_properties( + outer_sampler_wrapper=outer_sample_wrapper, + calc_cond_batch_wrapper=sliding_calc_cond_batch, + params=params, + sample_settings=sample_settings, + motion_models=MotionModelGroup(motion_model), + ) + + sample_settings = helper.get_sample_settings() + if sample_settings.custom_cfg is not None: + logger.info("[Sample Settings] custom_cfg is set; will override any KSampler cfg values or patches.") + + if sample_settings.sigma_schedule is not None: + logger.info("[Sample Settings] sigma_schedule is set; will override beta_schedule.") + model.add_object_patch("model_sampling", sample_settings.sigma_schedule.clone().model_sampling) + else: + # save model sampling from BetaSchedule as object patch + # if autoselect, get suggested beta_schedule from motion model + if beta_schedule == BetaSchedules.AUTOSELECT and helper.get_motion_models(): + beta_schedule = helper.get_motion_models()[0].model.get_best_beta_schedule(log=True) + new_model_sampling = BetaSchedules.to_model_sampling(beta_schedule, model) + if new_model_sampling is not None: + model.add_object_patch("model_sampling", new_model_sampling) del motion_model return (model,) From 10e6b0bbae2e685eebc1d48d060537f392d627a4 Mon Sep 17 00:00:00 2001 From: Kosinkadink Date: Tue, 24 Sep 2024 21:49:35 +0900 Subject: [PATCH 13/43] Made Visualize Context Options nodes work after the refactor --- animatediff/context.py | 11 +++++++++-- animatediff/model_injection.py | 4 ++-- animatediff/nodes_context.py | 24 ++++++++++++------------ 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/animatediff/context.py b/animatediff/context.py index 8ba6dcf..df116d1 100644 --- a/animatediff/context.py +++ b/animatediff/context.py @@ -604,9 +604,14 @@ def draw_view(window: list[int], gd: GridDisplay): draw_subidxs(window=window, gd=gd, y_grid_offset=2, color=gd.vs.view_color) -def generate_context_visualization(context_opts: ContextOptionsGroup, model: ModelPatcher, sampler_name: str=None, scheduler: str=None, +def generate_context_visualization(model: ModelPatcher, context_opts: 
ContextOptionsGroup=None, sampler_name: str=None, scheduler: str=None, width=1440, height=200, video_length=32, steps=None, start_step=None, end_step=None, sigmas=None, force_full_denoise=False, denoise=None): + if context_opts is None: + context_opts = ContextOptionsGroup.default() + params = model.get_attachment("ADE_params") + if params is not None: + context_opts = params.context_options context_opts = context_opts.clone() vs = VisualizeSettings(width, video_length) all_imgs = [] @@ -642,7 +647,9 @@ def generate_context_visualization(context_opts: ContextOptionsGroup, model: Mod # check if context should even be active in this case context_active = True - if video_length < context_opts.context_length: + if context_opts.context_length is None: + context_active = False + elif video_length < context_opts.context_length: context_active = False elif video_length == context_opts.context_length and not context_opts.use_on_equal_length: context_active = False diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index 44eff90..6332471 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -110,14 +110,14 @@ def get_name_string(self, show_version=False): def get_sample_settings(self) -> SampleSettings: - return self.model.attachments.get(self.SAMPLE_SETTINGS, None) + return self.model.get_attachment(self.SAMPLE_SETTINGS) def set_sample_settings(self, sample_settings: SampleSettings): self.model.set_attachments(self.SAMPLE_SETTINGS, sample_settings) def get_params(self) -> 'InjectionParams': - return self.model.attachments.get(self.PARAMS) + return self.model.get_attachment(self.PARAMS) def set_params(self, params: 'InjectionParams'): self.model.set_attachments(self.PARAMS, params) diff --git a/animatediff/nodes_context.py b/animatediff/nodes_context.py index f7bc6d8..92babc4 100644 --- a/animatediff/nodes_context.py +++ b/animatediff/nodes_context.py @@ -362,11 +362,11 @@ def INPUT_TYPES(s): return { "required": { "model": ("MODEL",), - "context_opts": ("CONTEXT_OPTIONS",), "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), }, "optional": { + "context_opts": ("CONTEXT_OPTIONS",), "visual_width": ("INT", {"min": 32, "max": MAX_RESOLUTION, "default": 1440}), "latents_length": ("INT", {"min": 1, "max": BIGMAX, "default": 32}), "steps": ("INT", {"min": 0, "max": BIGMAX, "default": 20}), @@ -379,9 +379,9 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/visualize" FUNCTION = "visualize" - def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sampler_name: str, scheduler: str, - visual_width: 1280, latents_length=32, steps=20, start_step=0, end_step=20): - images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length, + def visualize(self, model: ModelPatcher, sampler_name: str, scheduler: str, context_opts: ContextOptionsGroup=None, + visual_width=1440, latents_length=32, steps=20, start_step=0, end_step=20): + images = generate_context_visualization(model=model, context_opts=context_opts, width=visual_width, video_length=latents_length, sampler_name=sampler_name, scheduler=scheduler, steps=steps, start_step=start_step, end_step=end_step) return (images,) @@ -393,11 +393,11 @@ def INPUT_TYPES(s): return { "required": { "model": ("MODEL",), - "context_opts": ("CONTEXT_OPTIONS",), "sampler_name": (comfy.samplers.KSampler.SAMPLERS, ), "scheduler": (comfy.samplers.KSampler.SCHEDULERS, ), 
}, "optional": { + "context_opts": ("CONTEXT_OPTIONS",), "visual_width": ("INT", {"min": 32, "max": MAX_RESOLUTION, "default": 1440}), "latents_length": ("INT", {"min": 1, "max": BIGMAX, "default": 32}), "steps": ("INT", {"min": 0, "max": BIGMAX, "default": 20}), @@ -409,9 +409,9 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/visualize" FUNCTION = "visualize" - def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sampler_name: str, scheduler: str, - visual_width: 1280, latents_length=32, steps=20, denoise=1.0): - images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length, + def visualize(self, model: ModelPatcher, sampler_name: str, scheduler: str, context_opts: ContextOptionsGroup=None, + visual_width=1440, latents_length=32, steps=20, denoise=1.0): + images = generate_context_visualization(model=model, context_opts=context_opts, width=visual_width, video_length=latents_length, sampler_name=sampler_name, scheduler=scheduler, steps=steps, denoise=denoise) return (images,) @@ -423,10 +423,10 @@ def INPUT_TYPES(s): return { "required": { "model": ("MODEL",), - "context_opts": ("CONTEXT_OPTIONS",), "sigmas": ("SIGMAS", ), }, "optional": { + "context_opts": ("CONTEXT_OPTIONS",), "visual_width": ("INT", {"min": 32, "max": MAX_RESOLUTION, "default": 1440}), "latents_length": ("INT", {"min": 1, "max": BIGMAX, "default": 32}), } @@ -436,8 +436,8 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/visualize" FUNCTION = "visualize" - def visualize(self, model: ModelPatcher, context_opts: ContextOptionsGroup, sigmas, - visual_width: 1280, latents_length=32): - images = generate_context_visualization(context_opts=context_opts, model=model, width=visual_width, video_length=latents_length, + def visualize(self, model: ModelPatcher, sigmas, context_opts: ContextOptionsGroup=None, + visual_width=1440, latents_length=32): + images = generate_context_visualization(model=model, context_opts=context_opts, width=visual_width, video_length=latents_length, sigmas=sigmas) return (images,) From b960317e42d9f992b8ebdf0a900e373f54c2ac58 Mon Sep 17 00:00:00 2001 From: Kosinkadink Date: Fri, 27 Sep 2024 12:15:47 +0900 Subject: [PATCH 14/43] Refactored outer_sample_wrapper to work with my changes in in-progress ComfyUI branch --- animatediff/sampling.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 4905a61..4fdf3eb 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -16,7 +16,7 @@ import comfy.utils from comfy.controlnet import ControlBase from comfy.model_base import BaseModel -from comfy.model_patcher import ModelPatcher +from comfy.model_patcher import ModelPatcher, WrapperExecutor import comfy.conds import comfy.ops @@ -405,7 +405,7 @@ def can_concat_cond_contextref_injection(c1, c2, *args, **kwargs): return can_concat_cond_contextref_injection -def outer_sample_wrapper(executor, guider: comfy.samplers.CFGGuider, *args, **kwargs): +def outer_sample_wrapper(executor: WrapperExecutor, *args, **kwargs): # NOTE: OUTER_SAMPLE wrapper patch in ModelPatcher latents = None cached_latents = None @@ -413,6 +413,7 @@ def outer_sample_wrapper(executor, guider: comfy.samplers.CFGGuider, *args, **kw function_injections = FunctionInjectionHolder() try: + guider: comfy.samplers.CFGGuider = executor.class_obj helper = ModelPatcherHelper(guider.model_patcher) args = list(args) # clone params 
from model @@ -497,7 +498,7 @@ def ad_callback(step, x0, x, total_steps): helper.pre_run() if ADGS.sample_settings.image_injection.is_empty(): - latents = executor(guider, *tuple(args), **kwargs) + latents = executor(*tuple(args), **kwargs) else: ADGS.sample_settings.image_injection.initialize_timesteps(helper.model.model) sigmas = args[3] @@ -514,7 +515,7 @@ def ad_callback(step, x0, x, total_steps): args[0] = new_noise args[1] = latents args[3] = sigmas_list[i] - latents = executor(guider, *tuple(args), **kwargs) + latents = executor(*tuple(args), **kwargs) if is_first: new_noise = torch.zeros_like(latents) # if injection expected, perform injection From 1698f9f4aae5e39b1eb5bb81d8d0cc938e290bbb Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 7 Oct 2024 11:00:24 -0500 Subject: [PATCH 15/43] Fixed device mismatch issue in perform_image_injection (Image Injection) code after code refactor --- animatediff/sampling.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 4fdf3eb..db26626 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -595,6 +595,7 @@ def perform_image_injection(model: BaseModel, latents: Tensor, to_inject: Noised encoded_x0 = vae_encode_raw_batched(to_inject.vae, decoded_images) # get difference between sampled latents and encoded_x0 + latents = latents.to(device=encoded_x0.device) encoded_x0 = latents - encoded_x0 # get mask, or default to full mask @@ -619,7 +620,7 @@ def perform_image_injection(model: BaseModel, latents: Tensor, to_inject: Noised strength = to_inject.strength_multival if type(strength) == Tensor: strength = extend_to_batch_size(prepare_mask_batch(strength, composited.shape), b) - return composited * strength + latents * (1.0 - strength) + return (composited * strength + latents * (1.0 - strength)).to(dtype=orig_dtype, device=orig_device) finally: comfy.model_management.load_models_gpu(cached_loaded_models) From 3e93f152b011a4453a6ccd955da319f52a72a6a6 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 8 Oct 2024 17:48:47 -0500 Subject: [PATCH 16/43] Removed no longer needed code from sampling.py, removed ContextRefInjector as upcoming ComfyUI/ACN changes will make its use obsolete --- animatediff/sampling.py | 103 +++------------------------------------- 1 file changed, 6 insertions(+), 97 deletions(-) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index db26626..2053e17 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -1,15 +1,11 @@ from typing import Callable -import collections import math import torch from torch import Tensor from torch.nn.functional import group_norm from einops import rearrange -from types import MethodType -import comfy.ldm.modules.attention as attention -from comfy.ldm.modules.diffusionmodules import openaimodel import comfy.model_management import comfy.samplers import comfy.sampler_helpers @@ -22,11 +18,11 @@ from .context import ContextFuseMethod, ContextSchedules, get_context_weights, get_context_windows from .context_extras import ContextRefMode -from .sample_settings import IterationOptions, SampleSettings, SeedNoiseGeneration, NoisedImageToInject -from .utils_model import ModelTypeSD, MachineState, vae_encode_raw_batched, vae_decode_raw_batched -from .utils_motion import composite_extend, get_combined_multival, prepare_mask_batch, extend_to_batch_size -from .model_injection import InjectionParams, ModelPatcherHelper, MotionModelGroup, MotionModelPatcher -from .motion_module_ad 
import AnimateDiffFormat, AnimateDiffInfo, AnimateDiffVersion, VanillaTemporalModule +from .sample_settings import SampleSettings, NoisedImageToInject +from .utils_model import MachineState, vae_encode_raw_batched, vae_decode_raw_batched +from .utils_motion import composite_extend, prepare_mask_batch, extend_to_batch_size +from .model_injection import InjectionParams, ModelPatcherHelper, MotionModelGroup +from .motion_module_ad import AnimateDiffFormat, AnimateDiffInfo, AnimateDiffVersion from .logger import logger @@ -145,7 +141,6 @@ def create_exposed_params(self): ################################################################################## #### Code Injection ################################################## - def unlimited_memory_required(*args, **kwargs): return 0 @@ -187,52 +182,10 @@ def diffusion_model_forward_groupnormed(*args, **kwargs): with inject_helper: return orig_diffusion_model_forward(*args, **kwargs) return diffusion_model_forward_groupnormed - - ###################################################################### ################################################################################## -def apply_params_to_motion_models_old(motion_models: MotionModelGroup, params: InjectionParams): - params = params.clone() - for context in params.context_options.contexts: - if context.context_schedule == ContextSchedules.VIEW_AS_CONTEXT: - context.context_length = params.full_length - # TODO: check (and message) should be different based on use_on_equal_length setting - if params.context_options.context_length: - pass - - allow_equal = params.context_options.use_on_equal_length - if params.context_options.context_length: - enough_latents = params.full_length >= params.context_options.context_length if allow_equal else params.full_length > params.context_options.context_length - else: - enough_latents = False - if params.context_options.context_length and enough_latents: - logger.info(f"Sliding context window activated - latents passed in ({params.full_length}) greater than context_length {params.context_options.context_length}.") - else: - logger.info(f"Regular AnimateDiff activated - latents passed in ({params.full_length}) less or equal to context_length {params.context_options.context_length}.") - params.reset_context() - if motion_models is not None: - # if no context_length, treat video length as intended AD frame window - if not params.context_options.context_length: - for motion_model in motion_models.models: - if not motion_model.model.is_length_valid_for_encoding_max_len(params.full_length): - raise ValueError(f"Without a context window, AnimateDiff model {motion_model.model.mm_info.mm_name} has upper limit of {motion_model.model.encoding_max_len} frames, but received {params.full_length} latents.") - motion_models.set_video_length(params.full_length, params.full_length) - # otherwise, treat context_length as intended AD frame window - else: - for motion_model in motion_models.models: - view_options = params.context_options.view_options - context_length = view_options.context_length if view_options else params.context_options.context_length - if not motion_model.model.is_length_valid_for_encoding_max_len(context_length): - raise ValueError(f"AnimateDiff model {motion_model.model.mm_info.mm_name} has upper limit of {motion_model.model.encoding_max_len} frames for a context window, but received context length of {params.context_options.context_length}.") - motion_models.set_video_length(params.context_options.context_length, params.full_length) - # inject model 
- module_str = "modules" if len(motion_models.models) > 1 else "module" - logger.info(f"Using motion {module_str} {motion_models.get_name_string(show_version=True)}.") - return params - - def apply_params_to_motion_models(helper: ModelPatcherHelper, params: InjectionParams): params = params.clone() for context in params.context_options.contexts: @@ -383,28 +336,6 @@ def __exit__(self, *args, **kwargs): self.previous_dwi_gn_cast_weights = None -class ContextRefInjector: - def __init__(self): - self.orig_can_concat_cond = None - - def inject(self): - self.orig_can_concat_cond = comfy.samplers.can_concat_cond - comfy.samplers.can_concat_cond = ContextRefInjector.can_concat_cond_contextref_factory(self.orig_can_concat_cond) - - def restore(self): - if self.orig_can_concat_cond is not None: - comfy.samplers.can_concat_cond = self.orig_can_concat_cond - - @staticmethod - def can_concat_cond_contextref_factory(orig_func: Callable): - def can_concat_cond_contextref_injection(c1, c2, *args, **kwargs): - #return orig_func(c1, c2, *args, **kwargs) - if c1 is c2: - return True - return False - return can_concat_cond_contextref_injection - - def outer_sample_wrapper(executor: WrapperExecutor, *args, **kwargs): # NOTE: OUTER_SAMPLE wrapper patch in ModelPatcher latents = None @@ -625,23 +556,6 @@ def perform_image_injection(model: BaseModel, latents: Tensor, to_inject: Noised comfy.model_management.load_models_gpu(cached_loaded_models) -def wrapped_cfg_sliding_calc_cond_batch_factory(orig_calc_cond_batch): - def wrapped_cfg_sliding_calc_cond_batch(model, conds, x_in, timestep, model_options): - # current call to calc_cond_batch should refer to sliding version - try: - current_calc_cond_batch = comfy.samplers.calc_cond_batch - # when inside sliding_calc_conds_batch, should return to original calc_cond_batch - comfy.samplers.calc_cond_batch = orig_calc_cond_batch - if not ADGS.is_using_sliding_context(): - return comfy.samplers.calc_cond_batch(model, conds, x_in, timestep, model_options) - else: - return sliding_calc_cond_batch(model, conds, x_in, timestep, model_options) - finally: - # make sure calc_cond_batch will become wrapped again - comfy.samplers.calc_cond_batch = current_calc_cond_batch - return wrapped_cfg_sliding_calc_cond_batch - - # initial sliding_calc_conds_batch inspired by ashen's initial hack for 16-frame sliding context: # https://github.com/comfyanonymous/ComfyUI/compare/master...ashen-sensored:ComfyUI:master def sliding_calc_cond_batch(executor: Callable, model, conds: list[list[dict]], x_in: Tensor, timestep, model_options): @@ -727,7 +641,6 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list CONTEXTREF_MACHINE_STATE = "contextref_machine_state" CONTEXTREF_CLEAN_FUNC = "contextref_clean_func" contextref_active = False - contextref_injector = None contextref_mode = None contextref_idxs_set = None first_context = True @@ -753,9 +666,6 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list # get mode_override if present, mode otherwise contextref_mode = refcn.get_contextref_mode_replace() or ADGS.params.context_options.extras.context_ref.mode contextref_idxs_set = contextref_mode.indexes.copy() - # use injector to ensure only 1 cond or uncond will be batched at a time - contextref_injector = ContextRefInjector() - contextref_injector.inject() curr_window_idx = -1 naivereuse_active = False @@ -848,7 +758,6 @@ def get_resized_cond(cond_in, full_idxs: list[int], context_length: int) -> list # clean contextref stuff with provided 
ACN function, if applicable if contextref_active: model_options["transformer_options"][CONTEXTREF_CLEAN_FUNC]() - contextref_injector.restore() # handle NaiveReuse if cached_naive_conds is not None: @@ -896,4 +805,4 @@ def get_conds_with_c_concat(conds: list[dict], c_concat: comfy.conds.CONDNoiseSh resized_actual_cond[key] = new_model_conds resized_cond.append(resized_actual_cond) new_conds.append(resized_cond) - return new_conds \ No newline at end of file + return new_conds From 0b01ea1b2786be234485396de0604bd677a2b3f8 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 24 Oct 2024 22:24:32 -0500 Subject: [PATCH 17/43] Refactored ADGS to be a variable stored within transformer_options instead of a global variable --- animatediff/sampling.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 2053e17..c7d28af 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -7,6 +7,7 @@ from einops import rearrange import comfy.model_management +import comfy.model_patcher import comfy.samplers import comfy.sampler_helpers import comfy.utils @@ -29,7 +30,7 @@ ################################################################################## ###################################################################### # Global variable to use to more conveniently hack variable access into samplers -class AnimateDiffHelper_GlobalState: +class AnimateDiffGlobalState: def __init__(self): self.model_patcher: ModelPatcher = None self.motion_models: MotionModelGroup = None @@ -133,8 +134,6 @@ def create_exposed_params(self): "context_length": self.params.context_options.context_length, "sub_idxs": self.params.sub_idxs, } - -ADGS = AnimateDiffHelper_GlobalState() ###################################################################### ################################################################################## @@ -169,6 +168,7 @@ def apply_model_ade_wrapper(self, *args, **kwargs): x: Tensor = args[0] cond_or_uncond = kwargs["transformer_options"]["cond_or_uncond"] ad_params = kwargs["transformer_options"]["ad_params"] + ADGS: AnimateDiffGlobalState = kwargs["transformer_options"]["ADGS"] if ADGS.motion_models is not None: for motion_model in ADGS.motion_models.models: motion_model.prepare_alcmi2v_features(x=x, cond_or_uncond=cond_or_uncond, ad_params=ad_params, latent_format=self.latent_format) @@ -346,6 +346,13 @@ def outer_sample_wrapper(executor: WrapperExecutor, *args, **kwargs): try: guider: comfy.samplers.CFGGuider = executor.class_obj helper = ModelPatcherHelper(guider.model_patcher) + + orig_model_options = guider.model_options + guider.model_options = comfy.model_patcher.create_model_options_clone(guider.model_options) + # create ADGS in transformer_options + ADGS = AnimateDiffGlobalState() + guider.model_options["transformer_options"]["ADGS"] = ADGS + args = list(args) # clone params from model params = helper.get_params().clone() @@ -353,7 +360,7 @@ def outer_sample_wrapper(executor: WrapperExecutor, *args, **kwargs): noise: Tensor = args[0] latents: Tensor = args[1] params.full_length = latents.size(0) - # reset global state - TODO: remove global state + # reset global state ADGS.reset() # apply custom noise, if needed @@ -452,13 +459,15 @@ def ad_callback(step, x0, x, total_steps): # if injection expected, perform injection if i < len(injection_list): to_inject = injection_list[i] - latents = perform_image_injection(helper.model.model, latents, to_inject) + latents = 
perform_image_injection(ADGS, helper.model.model, latents, to_inject) return latents finally: + guider.model_options = orig_model_options del noise del latents del cached_latents del cached_noise + del orig_model_options # reset global state ADGS.reset() # clean motion_models @@ -470,6 +479,7 @@ def ad_callback(step, x0, x, total_steps): def evolved_sampling_function(model, x: Tensor, timestep: Tensor, uncond, cond, cond_scale, model_options: dict={}, seed=None): + ADGS: AnimateDiffGlobalState = model_options["transformer_options"]["ADGS"] ADGS.initialize(model) ADGS.prepare_current_keyframes(x=x, timestep=timestep) try: @@ -504,7 +514,7 @@ def evolved_sampling_function(model, x: Tensor, timestep: Tensor, uncond, cond, ADGS.restore_special_model_features(model) -def perform_image_injection(model: BaseModel, latents: Tensor, to_inject: NoisedImageToInject) -> Tensor: +def perform_image_injection(ADGS: AnimateDiffGlobalState, model: BaseModel, latents: Tensor, to_inject: NoisedImageToInject) -> Tensor: # NOTE: the latents here have already been process_latent_out'ed # get currently used models so they can be properly reloaded after perfoming VAE Encoding if hasattr(comfy.model_management, "loaded_models"): @@ -559,6 +569,7 @@ def perform_image_injection(model: BaseModel, latents: Tensor, to_inject: Noised # initial sliding_calc_conds_batch inspired by ashen's initial hack for 16-frame sliding context: # https://github.com/comfyanonymous/ComfyUI/compare/master...ashen-sensored:ComfyUI:master def sliding_calc_cond_batch(executor: Callable, model, conds: list[list[dict]], x_in: Tensor, timestep, model_options): + ADGS: AnimateDiffGlobalState = model_options["transformer_options"]["ADGS"] if not ADGS.is_using_sliding_context(): return executor(model, conds, x_in, timestep, model_options) From 959fd3fd4500f6ba10fb2b043e969766cbbdbb7b Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 2 Nov 2024 22:21:45 -0500 Subject: [PATCH 18/43] Match changes in patch_hooks_improved_memory ComfyUI branch --- animatediff/model_injection.py | 3 ++- animatediff/sampling.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index 6332471..71fdeef 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -13,7 +13,8 @@ import comfy.lora import comfy.model_management import comfy.utils -from comfy.model_patcher import ModelPatcher, PatcherInjection, WrappersMP +from comfy.model_patcher import ModelPatcher +from comfy.patcher_extension import WrappersMP, PatcherInjection from comfy.model_base import BaseModel from comfy.sd import CLIP, VAE diff --git a/animatediff/sampling.py b/animatediff/sampling.py index c7d28af..7dcc4a6 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -13,7 +13,8 @@ import comfy.utils from comfy.controlnet import ControlBase from comfy.model_base import BaseModel -from comfy.model_patcher import ModelPatcher, WrapperExecutor +from comfy.model_patcher import ModelPatcher +from comfy.patcher_extension import WrapperExecutor import comfy.conds import comfy.ops From 0fb102e4c0003f87c8acb534aba51f5c98d4798a Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Mon, 11 Nov 2024 07:30:05 -0600 Subject: [PATCH 19/43] Modified FunctionInjectionHolder to work more smoothily with new ComfyUI memory management system --- animatediff/sampling.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/animatediff/sampling.py 
b/animatediff/sampling.py index 7dcc4a6..2730e16 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -232,16 +232,18 @@ def __init__(self): self.temp_uninjector: GroupnormUninjectHelper = GroupnormUninjectHelper() self.groupnorm_injector: GroupnormInjectHelper = GroupnormInjectHelper() - def inject_functions(self, helper: ModelPatcherHelper, params: InjectionParams): + def inject_functions(self, helper: ModelPatcherHelper, params: InjectionParams, model_options: dict): # Save Original Functions - order must match between here and restore_functions - self.orig_memory_required = helper.model.model.memory_required # allows for "unlimited area hack" to prevent halving of conds/unconds + self.orig_memory_required = None self.orig_groupnorm_forward = torch.nn.GroupNorm.forward # used to normalize latents to remove "flickering" of colors/brightness between frames self.orig_groupnorm_forward_comfy_cast_weights = comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights - self.orig_diffusion_model_forward = helper.model.model.diffusion_model.forward + self.orig_diffusion_model_forward = None self.orig_sampling_function = comfy.samplers.sampling_function # used to support sliding context windows in samplers - self.orig_apply_model = helper.model.model.apply_model + self.orig_apply_model = None # Inject Functions if params.unlimited_area_hack: + # allows for "unlimited area hack" to prevent halving of conds/unconds + self.orig_memory_required = helper.model.model.memory_required helper.model.model.memory_required = unlimited_memory_required if helper.get_motion_models(): # only apply groupnorm hack if PIA, v2 and not properly applied, or v1 @@ -252,16 +254,19 @@ def inject_functions(self, helper: ModelPatcherHelper, params: InjectionParams): self.inject_groupnorm_forward = groupnorm_mm_factory(params) self.inject_groupnorm_forward_comfy_cast_weights = groupnorm_mm_factory(params, manual_cast=True) self.groupnorm_injector = GroupnormInjectHelper(self) + self.orig_diffusion_model_forward = helper.model.model.diffusion_model.forward helper.model.model.diffusion_model.forward = diffusion_model_forward_groupnormed_factory(self.orig_diffusion_model_forward, self.groupnorm_injector) # if mps device (Apple Silicon), disable batched conds to avoid black images with groupnorm hack try: if helper.model.load_device.type == "mps": + self.orig_memory_required = helper.model.model.memory_required helper.model.model.memory_required = unlimited_memory_required except Exception: pass # if img_encoder or camera_encoder present, inject apply_model to handle correctly for motion_model in helper.get_motion_models(): if (motion_model.model.img_encoder is not None) or (motion_model.model.camera_encoder is not None): + self.orig_apply_model = helper.model.model.apply_model helper.model.model.apply_model = apply_model_factory(self.orig_apply_model).__get__(helper.model.model, type(helper.model.model)) break del info @@ -272,12 +277,15 @@ def inject_functions(self, helper: ModelPatcherHelper, params: InjectionParams): def restore_functions(self, helper: ModelPatcherHelper): # Restoration try: - helper.model.model.memory_required = self.orig_memory_required + if self.orig_memory_required is not None: + helper.model.model.memory_required = self.orig_memory_required torch.nn.GroupNorm.forward = self.orig_groupnorm_forward comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.orig_groupnorm_forward_comfy_cast_weights - helper.model.model.diffusion_model.forward = 
self.orig_diffusion_model_forward + if self.orig_diffusion_model_forward is not None: + helper.model.model.diffusion_model.forward = self.orig_diffusion_model_forward comfy.samplers.sampling_function = self.orig_sampling_function - helper.model.model.apply_model = self.orig_apply_model + if self.orig_apply_model is not None: + helper.model.model.apply_model = self.orig_apply_model except AttributeError: logger.error("Encountered AttributeError while attempting to restore functions - likely, an error occured while trying " + \ "to save original functions before injection, and a more specific error was thrown by ComfyUI.") @@ -372,7 +380,7 @@ def outer_sample_wrapper(executor: WrapperExecutor, *args, **kwargs): params = apply_params_to_motion_models(helper, params) # store and inject funtions - function_injections.inject_functions(helper, params) + function_injections.inject_functions(helper, params, guider.model_options) # prepare noise_extra_args for noise generation purposes noise_extra_args = {"disable_noise": disable_noise} From f32a382238979d713db96e5d08cd0775da4165cc Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 12 Nov 2024 09:06:48 -0600 Subject: [PATCH 20/43] Refactored nodes_conditioning.py to use new built-in ComfyUI hook features (will deprecate almost all these nodes soon after testing) --- animatediff/model_injection.py | 1 + animatediff/nodes_conditioning.py | 235 ++++++++++++++---------------- 2 files changed, 109 insertions(+), 127 deletions(-) diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index 71fdeef..8eceabe 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -511,6 +511,7 @@ def is_fancyvideo(self): return self.model.mm_info.mm_format == AnimateDiffFormat.FANCYVIDEO def cleanup(self): + super().cleanup() if self.model is not None: self.model.cleanup() # AnimateLCM-I2V diff --git a/animatediff/nodes_conditioning.py b/animatediff/nodes_conditioning.py index 6dfda75..3e2db0d 100644 --- a/animatediff/nodes_conditioning.py +++ b/animatediff/nodes_conditioning.py @@ -7,10 +7,12 @@ from comfy.model_patcher import ModelPatcher from comfy.sd import CLIP import comfy.sd +from comfy.hooks import HookGroup, HookKeyframeGroup, HookKeyframe +import comfy_extras.nodes_hooks +import comfy.hooks import comfy.utils -from .conditioning import (COND_CONST, TimestepsCond, set_mask_conds, set_mask_and_combine_conds, set_unmasked_and_combine_conds, - LoraHook, LoraHookGroup, LoraHookKeyframe, LoraHookKeyframeGroup) +from .conditioning import (COND_CONST) from .utils_model import BIGMAX, InterpolationMethod from .logger import logger @@ -30,8 +32,8 @@ def INPUT_TYPES(s): }, "optional": { "opt_mask": ("MASK", ), - "opt_lora_hook": ("LORA_HOOK",), - "opt_timesteps": ("TIMESTEPS_COND",), + "opt_lora_hook": ("HOOKS",), + "opt_timesteps": ("TIMESTEPS_RANGE",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -43,10 +45,10 @@ def INPUT_TYPES(s): def append_and_hook(self, positive_ADD, negative_ADD, strength: float, set_cond_area: str, - opt_mask: Tensor=None, opt_lora_hook: LoraHookGroup=None, opt_timesteps: TimestepsCond=None): - final_positive, final_negative = set_mask_conds(conds=[positive_ADD, negative_ADD], + opt_mask: Tensor=None, opt_lora_hook: HookGroup=None, opt_timesteps: tuple=None): + final_positive, final_negative = comfy.hooks.set_mask_conds(conds=[positive_ADD, negative_ADD], strength=strength, set_cond_area=set_cond_area, - opt_mask=opt_mask, opt_lora_hook=opt_lora_hook, opt_timesteps=opt_timesteps) + 
opt_mask=opt_mask, opt_hooks=opt_lora_hook, opt_timestep_range=opt_timesteps) return (final_positive, final_negative) @@ -61,8 +63,8 @@ def INPUT_TYPES(s): }, "optional": { "opt_mask": ("MASK", ), - "opt_lora_hook": ("LORA_HOOK",), - "opt_timesteps": ("TIMESTEPS_COND",), + "opt_lora_hook": ("HOOKS",), + "opt_timesteps": ("TIMESTEPS_RANGE",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -73,10 +75,10 @@ def INPUT_TYPES(s): def append_and_hook(self, cond_ADD, strength: float, set_cond_area: str, - opt_mask: Tensor=None, opt_lora_hook: LoraHookGroup=None, opt_timesteps: TimestepsCond=None): - (final_conditioning,) = set_mask_conds(conds=[cond_ADD], + opt_mask: Tensor=None, opt_lora_hook: HookGroup=None, opt_timesteps: tuple=None): + (final_conditioning,) = comfy.hooks.set_mask_conds(conds=[cond_ADD], strength=strength, set_cond_area=set_cond_area, - opt_mask=opt_mask, opt_lora_hook=opt_lora_hook, opt_timesteps=opt_timesteps) + opt_mask=opt_mask, opt_hooks=opt_lora_hook, opt_timestep_range=opt_timesteps) return (final_conditioning,) @@ -94,8 +96,8 @@ def INPUT_TYPES(s): }, "optional": { "opt_mask": ("MASK", ), - "opt_lora_hook": ("LORA_HOOK",), - "opt_timesteps": ("TIMESTEPS_COND",), + "opt_lora_hook": ("HOOKS",), + "opt_timesteps": ("TIMESTEPS_RANGE",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -107,10 +109,10 @@ def INPUT_TYPES(s): def append_and_combine(self, positive, negative, positive_ADD, negative_ADD, strength: float, set_cond_area: str, - opt_mask: Tensor=None, opt_lora_hook: LoraHookGroup=None, opt_timesteps: TimestepsCond=None): - final_positive, final_negative = set_mask_and_combine_conds(conds=[positive, negative], new_conds=[positive_ADD, negative_ADD], + opt_mask: Tensor=None, opt_lora_hook: HookGroup=None, opt_timesteps: tuple=None): + final_positive, final_negative = comfy.hooks.set_mask_and_combine_conds(conds=[positive, negative], new_conds=[positive_ADD, negative_ADD], strength=strength, set_cond_area=set_cond_area, - opt_mask=opt_mask, opt_lora_hook=opt_lora_hook, opt_timesteps=opt_timesteps) + opt_mask=opt_mask, opt_hooks=opt_lora_hook, opt_timestep_range=opt_timesteps) return (final_positive, final_negative,) @@ -126,8 +128,8 @@ def INPUT_TYPES(s): }, "optional": { "opt_mask": ("MASK", ), - "opt_lora_hook": ("LORA_HOOK",), - "opt_timesteps": ("TIMESTEPS_COND",), + "opt_lora_hook": ("HOOKS",), + "opt_timesteps": ("TIMESTEPS_RANGE",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -138,10 +140,10 @@ def INPUT_TYPES(s): def append_and_combine(self, cond, cond_ADD, strength: float, set_cond_area: str, - opt_mask: Tensor=None, opt_lora_hook: LoraHookGroup=None, opt_timesteps: TimestepsCond=None): - (final_conditioning,) = set_mask_and_combine_conds(conds=[cond], new_conds=[cond_ADD], + opt_mask: Tensor=None, opt_lora_hook: HookGroup=None, opt_timesteps: tuple=None): + (final_conditioning,) = comfy.hooks.set_mask_and_combine_conds(conds=[cond], new_conds=[cond_ADD], strength=strength, set_cond_area=set_cond_area, - opt_mask=opt_mask, opt_lora_hook=opt_lora_hook, opt_timesteps=opt_timesteps) + opt_mask=opt_mask, opt_hooks=opt_lora_hook, opt_timestep_range=opt_timesteps) return (final_conditioning,) @@ -156,7 +158,7 @@ def INPUT_TYPES(s): "negative_DEFAULT": ("CONDITIONING",), }, "optional": { - "opt_lora_hook": ("LORA_HOOK",), + "opt_lora_hook": ("HOOKS",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -167,9 +169,9 @@ def INPUT_TYPES(s): FUNCTION = "append_and_combine" def append_and_combine(self, positive, negative, positive_DEFAULT, 
negative_DEFAULT, - opt_lora_hook: LoraHookGroup=None): - final_positive, final_negative = set_unmasked_and_combine_conds(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT], - opt_lora_hook=opt_lora_hook) + opt_lora_hook: HookGroup=None): + final_positive, final_negative = comfy.hooks.set_default_and_combine_conds(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT], + opt_hooks=opt_lora_hook) return (final_positive, final_negative,) @@ -182,7 +184,7 @@ def INPUT_TYPES(s): "cond_DEFAULT": ("CONDITIONING",), }, "optional": { - "opt_lora_hook": ("LORA_HOOK",), + "opt_lora_hook": ("HOOKS",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -192,9 +194,9 @@ def INPUT_TYPES(s): FUNCTION = "append_and_combine" def append_and_combine(self, cond, cond_DEFAULT, - opt_lora_hook: LoraHookGroup=None): - (final_conditioning,) = set_unmasked_and_combine_conds(conds=[cond], new_conds=[cond_DEFAULT], - opt_lora_hook=opt_lora_hook) + opt_lora_hook: HookGroup=None): + (final_conditioning,) = comfy.hooks.set_default_and_combine_conds(conds=[cond], new_conds=[cond_DEFAULT], + opt_hooks=opt_lora_hook) return (final_conditioning,) @@ -216,7 +218,7 @@ def INPUT_TYPES(s): FUNCTION = "combine" def combine(self, positive_A, negative_A, positive_B, negative_B): - final_positive, final_negative = set_mask_and_combine_conds(conds=[positive_A, negative_A], new_conds=[positive_B, negative_B],) + final_positive, final_negative = comfy.hooks.set_mask_and_combine_conds(conds=[positive_A, negative_A], new_conds=[positive_B, negative_B],) return (final_positive, final_negative,) @@ -235,7 +237,7 @@ def INPUT_TYPES(s): FUNCTION = "combine" def combine(self, cond_A, cond_B): - (final_conditioning,) = set_mask_and_combine_conds(conds=[cond_A], new_conds=[cond_B],) + (final_conditioning,) = comfy.hooks.set_mask_and_combine_conds(conds=[cond_A], new_conds=[cond_B],) return (final_conditioning,) ############################################### ############################################### @@ -259,12 +261,12 @@ def INPUT_TYPES(s): } } - RETURN_TYPES = ("TIMESTEPS_COND",) + RETURN_TYPES = ("TIMESTEPS_RANGE",) CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning" FUNCTION = "create_schedule" def create_schedule(self, start_percent: float, end_percent: float): - return (TimestepsCond(start_percent=start_percent, end_percent=end_percent),) + return ((start_percent, end_percent),) class SetLoraHookKeyframes: @@ -272,19 +274,19 @@ class SetLoraHookKeyframes: def INPUT_TYPES(s): return { "required": { - "lora_hook": ("LORA_HOOK",), - "hook_kf": ("LORA_HOOK_KEYFRAMES",), + "lora_hook": ("HOOKS",), + "hook_kf": ("HOOK_KEYFRAMES",), }, "optional": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } - RETURN_TYPES = ("LORA_HOOK",) + RETURN_TYPES = ("HOOKS",) CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning" FUNCTION = "set_hook_keyframes" - def set_hook_keyframes(self, lora_hook: LoraHookGroup, hook_kf: LoraHookKeyframeGroup): + def set_hook_keyframes(self, lora_hook: HookGroup, hook_kf: HookKeyframeGroup): new_lora_hook = lora_hook.clone() new_lora_hook.set_keyframes_on_hooks(hook_kf=hook_kf) return (new_lora_hook,) @@ -300,23 +302,23 @@ def INPUT_TYPES(s): "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), }, "optional": { - "prev_hook_kf": ("LORA_HOOK_KEYFRAMES",), + "prev_hook_kf": ("HOOK_KEYFRAMES",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } - RETURN_TYPES = ("LORA_HOOK_KEYFRAMES",) + RETURN_TYPES = ("HOOK_KEYFRAMES",) RETURN_NAMES = ("HOOK_KF",) CATEGORY = 
"Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/schedule lora hooks" FUNCTION = "create_hook_keyframe" def create_hook_keyframe(self, strength_model: float, start_percent: float, guarantee_steps: float, - prev_hook_kf: LoraHookKeyframeGroup=None): + prev_hook_kf: HookKeyframeGroup=None): if prev_hook_kf: prev_hook_kf = prev_hook_kf.clone() else: - prev_hook_kf = LoraHookKeyframeGroup() - keyframe = LoraHookKeyframe(strength=strength_model, start_percent=start_percent, guarantee_steps=guarantee_steps) + prev_hook_kf = HookKeyframeGroup() + keyframe = HookKeyframe(strength=strength_model, start_percent=start_percent, guarantee_steps=guarantee_steps) prev_hook_kf.add(keyframe) return (prev_hook_kf,) @@ -335,12 +337,12 @@ def INPUT_TYPES(s): "print_keyframes": ("BOOLEAN", {"default": False}), }, "optional": { - "prev_hook_kf": ("LORA_HOOK_KEYFRAMES",), + "prev_hook_kf": ("HOOK_KEYFRAMES",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } - RETURN_TYPES = ("LORA_HOOK_KEYFRAMES",) + RETURN_TYPES = ("HOOK_KEYFRAMES",) RETURN_NAMES = ("HOOK_KF",) CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/schedule lora hooks" FUNCTION = "create_hook_keyframes" @@ -348,11 +350,11 @@ def INPUT_TYPES(s): def create_hook_keyframes(self, start_percent: float, end_percent: float, strength_start: float, strength_end: float, interpolation: str, intervals: int, - prev_hook_kf: LoraHookKeyframeGroup=None, print_keyframes=False): + prev_hook_kf: HookKeyframeGroup=None, print_keyframes=False): if prev_hook_kf: prev_hook_kf = prev_hook_kf.clone() else: - prev_hook_kf = LoraHookKeyframeGroup() + prev_hook_kf = HookKeyframeGroup() percents = InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=intervals, method=InterpolationMethod.LINEAR) strengths = InterpolationMethod.get_weights(num_from=strength_start, num_to=strength_end, length=intervals, method=interpolation) @@ -362,9 +364,9 @@ def create_hook_keyframes(self, if is_first: guarantee_steps = 1 is_first = False - prev_hook_kf.add(LoraHookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps)) + prev_hook_kf.add(HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps)) if print_keyframes: - logger.info(f"LoraHookKeyframe - start_percent:{percent} = {strength}") + logger.info(f"HookKeyframe - start_percent:{percent} = {strength}") return (prev_hook_kf,) @@ -379,22 +381,22 @@ def INPUT_TYPES(s): "print_keyframes": ("BOOLEAN", {"default": False}), }, "optional": { - "prev_hook_kf": ("LORA_HOOK_KEYFRAMES",), + "prev_hook_kf": ("HOOK_KEYFRAMES",), } } - RETURN_TYPES = ("LORA_HOOK_KEYFRAMES",) + RETURN_TYPES = ("HOOK_KEYFRAMES",) RETURN_NAMES = ("HOOK_KF",) CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/schedule lora hooks" FUNCTION = "create_hook_keyframes" def create_hook_keyframes(self, strengths_float: Union[float, list[float]], start_percent: float, end_percent: float, - prev_hook_kf: LoraHookKeyframeGroup=None, print_keyframes=False): + prev_hook_kf: HookKeyframeGroup=None, print_keyframes=False): if prev_hook_kf: prev_hook_kf = prev_hook_kf.clone() else: - prev_hook_kf = LoraHookKeyframeGroup() + prev_hook_kf = HookKeyframeGroup() if type(strengths_float) in (float, int): strengths_float = [float(strengths_float)] elif isinstance(strengths_float, Iterable): @@ -409,9 +411,9 @@ def create_hook_keyframes(self, strengths_float: Union[float, list[float]], if is_first: guarantee_steps = 1 is_first = False - prev_hook_kf.add(LoraHookKeyframe(strength=strength, start_percent=percent, 
guarantee_steps=guarantee_steps)) + prev_hook_kf.add(HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps)) if print_keyframes: - logger.info(f"LoraHookKeyframe - start_percent:{percent} = {strength}") + logger.info(f"HookKeyframe - start_percent:{percent} = {strength}") return (prev_hook_kf,) ############################################### ############################################### @@ -438,13 +440,13 @@ def INPUT_TYPES(s): } } - RETURN_TYPES = ("MODEL", "CLIP", "LORA_HOOK") + RETURN_TYPES = ("MODEL", "CLIP", "HOOKS") CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/register lora hooks" FUNCTION = "load_lora" def load_lora(self, model: Union[ModelPatcher], clip: CLIP, lora_name: str, strength_model: float, strength_clip: float): if strength_model == 0 and strength_clip == 0: - return (model, clip) + return (model, clip, None) lora_path = folder_paths.get_full_path("loras", lora_name) lora = None @@ -459,13 +461,10 @@ def load_lora(self, model: Union[ModelPatcher], clip: CLIP, lora_name: str, stre if lora is None: lora = comfy.utils.load_torch_file(lora_path, safe_load=True) self.loaded_lora = (lora_path, lora) - - lora_hook = LoraHook(lora_name=lora_name) - lora_hook_group = LoraHookGroup() - lora_hook_group.add(lora_hook) - model_lora, clip_lora = load_hooked_lora_for_models(model=model, clip=clip, lora=lora, lora_hook=lora_hook, - strength_model=strength_model, strength_clip=strength_clip) - return (model_lora, clip_lora, lora_hook_group) + + model_lora, clip_lora, hooks = comfy.hooks.load_hook_lora_for_models(model=model, clip=clip, lora=lora, + strength_model=strength_model, strength_clip=strength_clip) + return (model_lora, clip_lora, hooks) class MaskableLoraLoaderModelOnly(MaskableLoraLoader): @@ -479,17 +478,17 @@ def INPUT_TYPES(s): } } - RETURN_TYPES = ("MODEL", "LORA_HOOK") + RETURN_TYPES = ("MODEL", "HOOKS") CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/register lora hooks" FUNCTION = "load_lora_model_only" def load_lora_model_only(self, model: ModelPatcher, lora_name: str, strength_model: float): - model_lora, clip_lora, lora_hook = self.load_lora(model=model, clip=None, lora_name=lora_name, - strength_model=strength_model, strength_clip=0) - return (model_lora, lora_hook) + model_lora, _, hooks = self.load_lora(model=model, clip=None, lora_name=lora_name, + strength_model=strength_model, strength_clip=0) + return (model_lora, hooks) -class MaskableSDModelLoader: +class MaskableSDModelLoader(comfy_extras.nodes_hooks.CreateHookModelAsLora): @classmethod def INPUT_TYPES(s): return { @@ -502,24 +501,13 @@ def INPUT_TYPES(s): } } - RETURN_TYPES = ("MODEL", "CLIP", "LORA_HOOK") + RETURN_TYPES = ("MODEL", "CLIP", "HOOKS") CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/register lora hooks" FUNCTION = "load_model_as_lora" def load_model_as_lora(self, model: ModelPatcher, clip: CLIP, ckpt_name: str, strength_model: float, strength_clip: float): - ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) - out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) - model_loaded = out[0] - clip_loaded = out[1] - - lora_hook = LoraHook(lora_name=ckpt_name) - lora_hook_group = LoraHookGroup() - lora_hook_group.add(lora_hook) - model_lora, clip_lora = load_model_as_hooked_lora_for_models(model=model, clip=clip, - model_loaded=model_loaded, clip_loaded=clip_loaded, - lora_hook=lora_hook, - strength_model=strength_model, 
strength_clip=strength_clip) - return (model_lora, clip_lora, lora_hook_group) + returned = self.create_hook(ckpt_name=ckpt_name, strength_model=strength_model, strength_clip=strength_clip) + return (model.clone(), clip.clone(), returned[0]) class MaskableSDModelLoaderModelOnly(MaskableSDModelLoader): @@ -533,14 +521,14 @@ def INPUT_TYPES(s): } } - RETURN_TYPES = ("MODEL", "LORA_HOOK") + RETURN_TYPES = ("MODEL", "HOOKS") CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/register lora hooks" FUNCTION = "load_model_as_lora_model_only" def load_model_as_lora_model_only(self, model: ModelPatcher, ckpt_name: str, strength_model: float): - model_lora, clip_lora, lora_hook = self.load_model_as_lora(model=model, clip=None, ckpt_name=ckpt_name, - strength_model=strength_model, strength_clip=0) - return (model_lora, lora_hook) + model_lora, _, hooks = self.load_model_as_lora(model=model, clip=None, ckpt_name=ckpt_name, + strength_model=strength_model, strength_clip=0) + return (model_lora, hooks) ############################################### ############################################### ############################################### @@ -556,7 +544,7 @@ def INPUT_TYPES(s): return { "required": { "conditioning": ("CONDITIONING",), - "lora_hook": ("LORA_HOOK",), + "lora_hook": ("HOOKS",), }, "optional": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), @@ -567,13 +555,8 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/single cond ops" FUNCTION = "attach_lora_hook" - def attach_lora_hook(self, conditioning, lora_hook: LoraHookGroup): - c = [] - for t in conditioning: - n = [t[0], t[1].copy()] - n[1]["lora_hook"] = lora_hook - c.append(n) - return (c, ) + def attach_lora_hook(self, conditioning, lora_hook: HookGroup): + return (comfy.hooks.set_hooks_for_conditioning(conditioning, lora_hook),) class SetClipLoraHook: @@ -582,7 +565,7 @@ def INPUT_TYPES(s): return { "required": { "clip": ("CLIP",), - "lora_hook": ("LORA_HOOK",), + "lora_hook": ("HOOKS",), }, "optional": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), @@ -594,10 +577,8 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning" FUNCTION = "apply_lora_hook" - def apply_lora_hook(self, clip: CLIP, lora_hook: LoraHookGroup): - new_clip = CLIPWithHooks(clip) - new_clip.set_desired_hooks(lora_hooks=lora_hook) - return (new_clip, ) + def apply_lora_hook(self, clip: CLIP, lora_hook: HookGroup): + return comfy_extras.nodes_hooks.SetClipHooks.apply_hooks(self, clip, False, lora_hook) class CombineLoraHooks: @@ -607,19 +588,19 @@ def INPUT_TYPES(s): "required": { }, "optional": { - "lora_hook_A": ("LORA_HOOK",), - "lora_hook_B": ("LORA_HOOK",), + "lora_hook_A": ("HOOKS",), + "lora_hook_B": ("HOOKS",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } - RETURN_TYPES = ("LORA_HOOK",) + RETURN_TYPES = ("HOOKS",) CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/combine lora hooks" FUNCTION = "combine_lora_hooks" - def combine_lora_hooks(self, lora_hook_A: LoraHookGroup=None, lora_hook_B: LoraHookGroup=None): + def combine_lora_hooks(self, lora_hook_A: HookGroup=None, lora_hook_B: HookGroup=None): candidates = [lora_hook_A, lora_hook_B] - return (LoraHookGroup.combine_all_lora_hooks(candidates),) + return (HookGroup.combine_all_hooks(candidates),) class CombineLoraHookFourOptional: @@ -629,23 +610,23 @@ def INPUT_TYPES(s): "required": { }, "optional": { - "lora_hook_A": ("LORA_HOOK",), - "lora_hook_B": ("LORA_HOOK",), - "lora_hook_C": ("LORA_HOOK",), - "lora_hook_D": ("LORA_HOOK",), + "lora_hook_A": ("HOOKS",), + 
"lora_hook_B": ("HOOKS",), + "lora_hook_C": ("HOOKS",), + "lora_hook_D": ("HOOKS",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } - RETURN_TYPES = ("LORA_HOOK",) + RETURN_TYPES = ("HOOKS",) CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/combine lora hooks" FUNCTION = "combine_lora_hooks" def combine_lora_hooks(self, - lora_hook_A: LoraHookGroup=None, lora_hook_B: LoraHookGroup=None, - lora_hook_C: LoraHookGroup=None, lora_hook_D: LoraHookGroup=None,): + lora_hook_A: HookGroup=None, lora_hook_B: HookGroup=None, + lora_hook_C: HookGroup=None, lora_hook_D: HookGroup=None,): candidates = [lora_hook_A, lora_hook_B, lora_hook_C, lora_hook_D] - return (LoraHookGroup.combine_all_lora_hooks(candidates),) + return (HookGroup.combine_all_hooks(candidates),) class CombineLoraHookEightOptional: @@ -655,30 +636,30 @@ def INPUT_TYPES(s): "required": { }, "optional": { - "lora_hook_A": ("LORA_HOOK",), - "lora_hook_B": ("LORA_HOOK",), - "lora_hook_C": ("LORA_HOOK",), - "lora_hook_D": ("LORA_HOOK",), - "lora_hook_E": ("LORA_HOOK",), - "lora_hook_F": ("LORA_HOOK",), - "lora_hook_G": ("LORA_HOOK",), - "lora_hook_H": ("LORA_HOOK",), + "lora_hook_A": ("HOOKS",), + "lora_hook_B": ("HOOKS",), + "lora_hook_C": ("HOOKS",), + "lora_hook_D": ("HOOKS",), + "lora_hook_E": ("HOOKS",), + "lora_hook_F": ("HOOKS",), + "lora_hook_G": ("HOOKS",), + "lora_hook_H": ("HOOKS",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } - RETURN_TYPES = ("LORA_HOOK",) + RETURN_TYPES = ("HOOKS",) CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/combine lora hooks" FUNCTION = "combine_lora_hooks" def combine_lora_hooks(self, - lora_hook_A: LoraHookGroup=None, lora_hook_B: LoraHookGroup=None, - lora_hook_C: LoraHookGroup=None, lora_hook_D: LoraHookGroup=None, - lora_hook_E: LoraHookGroup=None, lora_hook_F: LoraHookGroup=None, - lora_hook_G: LoraHookGroup=None, lora_hook_H: LoraHookGroup=None): + lora_hook_A: HookGroup=None, lora_hook_B: HookGroup=None, + lora_hook_C: HookGroup=None, lora_hook_D: HookGroup=None, + lora_hook_E: HookGroup=None, lora_hook_F: HookGroup=None, + lora_hook_G: HookGroup=None, lora_hook_H: HookGroup=None): candidates = [lora_hook_A, lora_hook_B, lora_hook_C, lora_hook_D, lora_hook_E, lora_hook_F, lora_hook_G, lora_hook_H] - return (LoraHookGroup.combine_all_lora_hooks(candidates),) + return (HookGroup.combine_all_hooks(candidates),) # NOTE: if at some point I add more Javascript stuff to this repo, there should be a combine node # that dynamically increases the hooks available to plug in on the node From bac7a818bb7e7513a69830fdd631b58548892e6a Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 12 Nov 2024 09:28:49 -0600 Subject: [PATCH 21/43] Deprecated most of nodes_conditioning nodes since they exist in vanilla ComgyUI --- animatediff/nodes.py | 79 ++++++------ animatediff/nodes_conditioning.py | 202 ++++++++++++++++++++---------- animatediff/nodes_deprecated.py | 14 +-- 3 files changed, 181 insertions(+), 114 deletions(-) diff --git a/animatediff/nodes.py b/animatediff/nodes.py index 17cf97c..f61aa28 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -11,15 +11,16 @@ from .nodes_pia import (ApplyAnimateDiffPIAModel, LoadAnimateDiffAndInjectPIANode, InputPIA_MultivalNode, InputPIA_PaperPresetsNode, PIA_ADKeyframeNode) from .nodes_fancyvideo import (ApplyAnimateDiffFancyVideo,) from .nodes_multival import MultivalDynamicNode, MultivalScaledMaskNode, MultivalDynamicFloatInputNode, MultivalDynamicFloatsNode, MultivalConvertToMaskNode -from .nodes_conditioning import (MaskableLoraLoader, 
MaskableLoraLoaderModelOnly, MaskableSDModelLoader, MaskableSDModelLoaderModelOnly, - SetModelLoraHook, SetClipLoraHook, - CombineLoraHooks, CombineLoraHookFourOptional, CombineLoraHookEightOptional, - PairedConditioningSetMaskHooked, ConditioningSetMaskHooked, - PairedConditioningSetMaskAndCombineHooked, ConditioningSetMaskAndCombineHooked, - PairedConditioningSetUnmaskedAndCombineHooked, ConditioningSetUnmaskedAndCombineHooked, - PairedConditioningCombine, ConditioningCombine, - ConditioningTimestepsNode, SetLoraHookKeyframes, - CreateLoraHookKeyframe, CreateLoraHookKeyframeInterpolation, CreateLoraHookKeyframeFromStrengthList) +from .nodes_conditioning import (CreateLoraHookKeyframeInterpolation, + MaskableLoraLoaderDEPR, MaskableLoraLoaderModelOnlyDEPR, MaskableSDModelLoaderDEPR, MaskableSDModelLoaderModelOnlyDEPR, + SetModelLoraHookDEPR, SetClipLoraHookDEPR, + CombineLoraHooksDEPR, CombineLoraHookFourOptionalDEPR, CombineLoraHookEightOptionalDEPR, + PairedConditioningSetMaskHookedDEPR, ConditioningSetMaskHookedDEPR, + PairedConditioningSetMaskAndCombineHookedDEPR, ConditioningSetMaskAndCombineHookedDEPR, + PairedConditioningSetUnmaskedAndCombineHookedDEPR, ConditioningSetUnmaskedAndCombineHookedDEPR, + PairedConditioningCombineDEPR, ConditioningCombineDEPR, + ConditioningTimestepsNodeDEPR, SetLoraHookKeyframesDEPR, + CreateLoraHookKeyframeDEPR, CreateLoraHookKeyframeFromStrengthListDEPR) from .nodes_sample import (FreeInitOptionsNode, NoiseLayerAddWeightedNode, SampleSettingsNode, NoiseLayerAddNode, NoiseLayerReplaceNode, IterationOptionsNode, CustomCFGNode, CustomCFGSimpleNode, CustomCFGKeyframeNode, CustomCFGKeyframeSimpleNode, CustomCFGKeyframeInterpolationNode, CustomCFGKeyframeFromListNode, CFGExtrasPAGNode, CFGExtrasPAGSimpleNode, CFGExtrasRescaleCFGNode, CFGExtrasRescaleCFGSimpleNode, @@ -42,8 +43,8 @@ PerBlock_SD15_LowLevelNode, PerBlock_SD15_MidLevelNode, PerBlock_SD15_FromFloatsNode, PerBlock_SDXL_LowLevelNode, PerBlock_SDXL_MidLevelNode, PerBlock_SDXL_FromFloatsNode) from .nodes_extras import AnimateDiffUnload, EmptyLatentImageLarge, CheckpointLoaderSimpleWithNoiseSelect, PerturbedAttentionGuidanceMultival, RescaleCFGMultival -from .nodes_deprecated import (AnimateDiffLoader_Deprecated, AnimateDiffLoaderAdvanced_Deprecated, AnimateDiffCombine_Deprecated, - AnimateDiffModelSettings, AnimateDiffModelSettingsSimple, AnimateDiffModelSettingsAdvanced, AnimateDiffModelSettingsAdvancedAttnStrengths) +from .nodes_deprecated import (AnimateDiffLoaderDEPR, AnimateDiffLoaderAdvancedDEPR, AnimateDiffCombineDEPR, + AnimateDiffModelSettingsDEPR, AnimateDiffModelSettingsSimpleDEPR, AnimateDiffModelSettingsAdvancedDEPR, AnimateDiffModelSettingsAdvancedAttnStrengthsDEPR) from .nodes_lora import AnimateDiffLoraLoader from .logger import logger @@ -97,28 +98,28 @@ "ADE_IterationOptsDefault": IterationOptionsNode, "ADE_IterationOptsFreeInit": FreeInitOptionsNode, # Conditioning - "ADE_RegisterLoraHook": MaskableLoraLoader, - "ADE_RegisterLoraHookModelOnly": MaskableLoraLoaderModelOnly, - "ADE_RegisterModelAsLoraHook": MaskableSDModelLoader, - "ADE_RegisterModelAsLoraHookModelOnly": MaskableSDModelLoaderModelOnly, - "ADE_CombineLoraHooks": CombineLoraHooks, - "ADE_CombineLoraHooksFour": CombineLoraHookFourOptional, - "ADE_CombineLoraHooksEight": CombineLoraHookEightOptional, - "ADE_SetLoraHookKeyframe": SetLoraHookKeyframes, - "ADE_AttachLoraHookToCLIP": SetClipLoraHook, - "ADE_LoraHookKeyframe": CreateLoraHookKeyframe, + "ADE_RegisterLoraHook": MaskableLoraLoaderDEPR, + 
"ADE_RegisterLoraHookModelOnly": MaskableLoraLoaderModelOnlyDEPR, + "ADE_RegisterModelAsLoraHook": MaskableSDModelLoaderDEPR, + "ADE_RegisterModelAsLoraHookModelOnly": MaskableSDModelLoaderModelOnlyDEPR, + "ADE_CombineLoraHooks": CombineLoraHooksDEPR, + "ADE_CombineLoraHooksFour": CombineLoraHookFourOptionalDEPR, + "ADE_CombineLoraHooksEight": CombineLoraHookEightOptionalDEPR, + "ADE_SetLoraHookKeyframe": SetLoraHookKeyframesDEPR, + "ADE_AttachLoraHookToCLIP": SetClipLoraHookDEPR, + "ADE_LoraHookKeyframe": CreateLoraHookKeyframeDEPR, "ADE_LoraHookKeyframeInterpolation": CreateLoraHookKeyframeInterpolation, - "ADE_LoraHookKeyframeFromStrengthList": CreateLoraHookKeyframeFromStrengthList, - "ADE_AttachLoraHookToConditioning": SetModelLoraHook, - "ADE_PairedConditioningSetMask": PairedConditioningSetMaskHooked, - "ADE_ConditioningSetMask": ConditioningSetMaskHooked, - "ADE_PairedConditioningSetMaskAndCombine": PairedConditioningSetMaskAndCombineHooked, - "ADE_ConditioningSetMaskAndCombine": ConditioningSetMaskAndCombineHooked, - "ADE_PairedConditioningSetUnmaskedAndCombine": PairedConditioningSetUnmaskedAndCombineHooked, - "ADE_ConditioningSetUnmaskedAndCombine": ConditioningSetUnmaskedAndCombineHooked, - "ADE_PairedConditioningCombine": PairedConditioningCombine, - "ADE_ConditioningCombine": ConditioningCombine, - "ADE_TimestepsConditioning": ConditioningTimestepsNode, + "ADE_LoraHookKeyframeFromStrengthList": CreateLoraHookKeyframeFromStrengthListDEPR, + "ADE_AttachLoraHookToConditioning": SetModelLoraHookDEPR, + "ADE_PairedConditioningSetMask": PairedConditioningSetMaskHookedDEPR, + "ADE_ConditioningSetMask": ConditioningSetMaskHookedDEPR, + "ADE_PairedConditioningSetMaskAndCombine": PairedConditioningSetMaskAndCombineHookedDEPR, + "ADE_ConditioningSetMaskAndCombine": ConditioningSetMaskAndCombineHookedDEPR, + "ADE_PairedConditioningSetUnmaskedAndCombine": PairedConditioningSetUnmaskedAndCombineHookedDEPR, + "ADE_ConditioningSetUnmaskedAndCombine": ConditioningSetUnmaskedAndCombineHookedDEPR, + "ADE_PairedConditioningCombine": PairedConditioningCombineDEPR, + "ADE_ConditioningCombine": ConditioningCombineDEPR, + "ADE_TimestepsConditioning": ConditioningTimestepsNodeDEPR, # Noise Layer Nodes "ADE_NoiseLayerAdd": NoiseLayerAddNode, "ADE_NoiseLayerAddWeighted": NoiseLayerAddWeightedNode, @@ -211,13 +212,13 @@ # FancyVideo ApplyAnimateDiffFancyVideo.NodeID: ApplyAnimateDiffFancyVideo, # Deprecated Nodes - "AnimateDiffLoaderV1": AnimateDiffLoader_Deprecated, - "ADE_AnimateDiffLoaderV1Advanced": AnimateDiffLoaderAdvanced_Deprecated, - "ADE_AnimateDiffCombine": AnimateDiffCombine_Deprecated, - "ADE_AnimateDiffModelSettings_Release": AnimateDiffModelSettings, - "ADE_AnimateDiffModelSettingsSimple": AnimateDiffModelSettingsSimple, - "ADE_AnimateDiffModelSettings": AnimateDiffModelSettingsAdvanced, - "ADE_AnimateDiffModelSettingsAdvancedAttnStrengths": AnimateDiffModelSettingsAdvancedAttnStrengths, + "AnimateDiffLoaderV1": AnimateDiffLoaderDEPR, + "ADE_AnimateDiffLoaderV1Advanced": AnimateDiffLoaderAdvancedDEPR, + "ADE_AnimateDiffCombine": AnimateDiffCombineDEPR, + "ADE_AnimateDiffModelSettings_Release": AnimateDiffModelSettingsDEPR, + "ADE_AnimateDiffModelSettingsSimple": AnimateDiffModelSettingsSimpleDEPR, + "ADE_AnimateDiffModelSettings": AnimateDiffModelSettingsAdvancedDEPR, + "ADE_AnimateDiffModelSettingsAdvancedAttnStrengths": AnimateDiffModelSettingsAdvancedAttnStrengthsDEPR, } NODE_DISPLAY_NAME_MAPPINGS = { # Unencapsulated diff --git a/animatediff/nodes_conditioning.py 
b/animatediff/nodes_conditioning.py index 3e2db0d..0d7c36e 100644 --- a/animatediff/nodes_conditioning.py +++ b/animatediff/nodes_conditioning.py @@ -17,10 +17,69 @@ from .logger import logger +class CreateLoraHookKeyframeInterpolation: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), + "strength_start": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), + "strength_end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), + "interpolation": (InterpolationMethod._LIST, ), + "intervals": ("INT", {"default": 5, "min": 2, "max": 100, "step": 1}), + "print_keyframes": ("BOOLEAN", {"default": False}), + }, + "optional": { + "prev_hook_kf": ("HOOK_KEYFRAMES",), + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } + } + + RETURN_TYPES = ("HOOK_KEYFRAMES",) + RETURN_NAMES = ("HOOK_KF",) + CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/schedule lora hooks" + FUNCTION = "create_hook_keyframes" + + def create_hook_keyframes(self, + start_percent: float, end_percent: float, + strength_start: float, strength_end: float, interpolation: str, intervals: int, + prev_hook_kf: HookKeyframeGroup=None, print_keyframes=False): + if prev_hook_kf: + prev_hook_kf = prev_hook_kf.clone() + else: + prev_hook_kf = HookKeyframeGroup() + percents = InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=intervals, method=InterpolationMethod.LINEAR) + strengths = InterpolationMethod.get_weights(num_from=strength_start, num_to=strength_end, length=intervals, method=interpolation) + + is_first = True + for percent, strength in zip(percents, strengths): + guarantee_steps = 0 + if is_first: + guarantee_steps = 1 + is_first = False + prev_hook_kf.add(HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps)) + if print_keyframes: + logger.info(f"HookKeyframe - start_percent:{percent} = {strength}") + return (prev_hook_kf,) + + + +################################################################### +# EVERYTHING BELOW HERE IS DEPRECATED; +# Can be replaced with vanilla ComfyUI nodes +#------------------------------------------------------------------ +#------------------------------------------------------------------ +#------------------------------------------------------------------ +#------------------------------------------------------------------ +#------------------------------------------------------------------ + + + ############################################### ### Mask, Combine, and Hook Conditioning ############################################### -class PairedConditioningSetMaskHooked: +class PairedConditioningSetMaskHookedDEPR: @classmethod def INPUT_TYPES(s): return { @@ -35,6 +94,7 @@ def INPUT_TYPES(s): "opt_lora_hook": ("HOOKS",), "opt_timesteps": ("TIMESTEPS_RANGE",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), + "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), } } @@ -42,6 +102,7 @@ def INPUT_TYPES(s): RETURN_NAMES = ("positive", "negative") CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning" FUNCTION = "append_and_hook" + DEPRECATED = True def append_and_hook(self, positive_ADD, negative_ADD, strength: float, set_cond_area: str, @@ -52,7 +113,7 @@ def append_and_hook(self, positive_ADD, negative_ADD, return (final_positive, final_negative) -class ConditioningSetMaskHooked: 
+class ConditioningSetMaskHookedDEPR: @classmethod def INPUT_TYPES(s): return { @@ -66,12 +127,14 @@ def INPUT_TYPES(s): "opt_lora_hook": ("HOOKS",), "opt_timesteps": ("TIMESTEPS_RANGE",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), + "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), } } RETURN_TYPES = ("CONDITIONING",) CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/single cond ops" FUNCTION = "append_and_hook" + DEPRECATED = True def append_and_hook(self, cond_ADD, strength: float, set_cond_area: str, @@ -82,7 +145,7 @@ def append_and_hook(self, cond_ADD, return (final_conditioning,) -class PairedConditioningSetMaskAndCombineHooked: +class PairedConditioningSetMaskAndCombineHookedDEPR: @classmethod def INPUT_TYPES(s): return { @@ -99,6 +162,7 @@ def INPUT_TYPES(s): "opt_lora_hook": ("HOOKS",), "opt_timesteps": ("TIMESTEPS_RANGE",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), + "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), } } @@ -106,6 +170,7 @@ def INPUT_TYPES(s): RETURN_NAMES = ("positive", "negative") CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning" FUNCTION = "append_and_combine" + DEPRECATED = True def append_and_combine(self, positive, negative, positive_ADD, negative_ADD, strength: float, set_cond_area: str, @@ -116,7 +181,7 @@ def append_and_combine(self, positive, negative, positive_ADD, negative_ADD, return (final_positive, final_negative,) -class ConditioningSetMaskAndCombineHooked: +class ConditioningSetMaskAndCombineHookedDEPR: @classmethod def INPUT_TYPES(s): return { @@ -131,12 +196,14 @@ def INPUT_TYPES(s): "opt_lora_hook": ("HOOKS",), "opt_timesteps": ("TIMESTEPS_RANGE",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), + "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), } } RETURN_TYPES = ("CONDITIONING",) CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/single cond ops" FUNCTION = "append_and_combine" + DEPRECATED = True def append_and_combine(self, cond, cond_ADD, strength: float, set_cond_area: str, @@ -147,7 +214,7 @@ def append_and_combine(self, cond, cond_ADD, return (final_conditioning,) -class PairedConditioningSetUnmaskedAndCombineHooked: +class PairedConditioningSetUnmaskedAndCombineHookedDEPR: @classmethod def INPUT_TYPES(s): return { @@ -160,6 +227,7 @@ def INPUT_TYPES(s): "optional": { "opt_lora_hook": ("HOOKS",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), + "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), } } @@ -167,6 +235,7 @@ def INPUT_TYPES(s): RETURN_NAMES = ("positive", "negative") CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning" FUNCTION = "append_and_combine" + DEPRECATED = True def append_and_combine(self, positive, negative, positive_DEFAULT, negative_DEFAULT, opt_lora_hook: HookGroup=None): @@ -175,7 +244,7 @@ def append_and_combine(self, positive, negative, positive_DEFAULT, negative_DEFA return (final_positive, final_negative,) -class ConditioningSetUnmaskedAndCombineHooked: +class ConditioningSetUnmaskedAndCombineHookedDEPR: @classmethod def INPUT_TYPES(s): return { @@ -186,12 +255,14 @@ def INPUT_TYPES(s): "optional": { "opt_lora_hook": ("HOOKS",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), + "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), } } RETURN_TYPES = ("CONDITIONING",) CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/single cond ops" FUNCTION = "append_and_combine" + DEPRECATED = True 
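# The deprecation pattern used across this commit is identical for every node: the class is
# renamed with a DEPR suffix, a DEPRECATED = True class attribute is added, and an optional
# "deprecation_warning" ADEWARN input surfaces a notice in the UI. A minimal sketch of that
# shape, assuming a hypothetical ExampleNodeDEPR (the class name and passthrough behavior are
# illustrative only; ADEWARN and DEPRECATED are taken from the diff itself):
class ExampleNodeDEPR:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "cond": ("CONDITIONING",),
            },
            "optional": {
                # widget-only input that renders the deprecation notice
                "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}),
            },
        }

    RETURN_TYPES = ("CONDITIONING",)
    CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning"
    FUNCTION = "passthrough"
    DEPRECATED = True  # flags the node as deprecated for frontends/registries that honor this attribute

    def passthrough(self, cond, deprecation_warning=None):
        # no behavior change; deprecated nodes keep working, they only gain the warning metadata
        return (cond,)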
def append_and_combine(self, cond, cond_DEFAULT, opt_lora_hook: HookGroup=None): @@ -200,7 +271,7 @@ def append_and_combine(self, cond, cond_DEFAULT, return (final_conditioning,) -class PairedConditioningCombine: +class PairedConditioningCombineDEPR: @classmethod def INPUT_TYPES(s): return { @@ -210,19 +281,23 @@ def INPUT_TYPES(s): "positive_B": ("CONDITIONING",), "negative_B": ("CONDITIONING",), }, + "optional": { + "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), + } } RETURN_TYPES = ("CONDITIONING", "CONDITIONING") RETURN_NAMES = ("positive", "negative") CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning" FUNCTION = "combine" + DEPRECATED = True def combine(self, positive_A, negative_A, positive_B, negative_B): final_positive, final_negative = comfy.hooks.set_mask_and_combine_conds(conds=[positive_A, negative_A], new_conds=[positive_B, negative_B],) return (final_positive, final_negative,) -class ConditioningCombine: +class ConditioningCombineDEPR: @classmethod def INPUT_TYPES(s): return { @@ -230,11 +305,15 @@ def INPUT_TYPES(s): "cond_A": ("CONDITIONING",), "cond_B": ("CONDITIONING",), }, + "optional": { + "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), + } } RETURN_TYPES = ("CONDITIONING",) CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/single cond ops" FUNCTION = "combine" + DEPRECATED = True def combine(self, cond_A, cond_B): (final_conditioning,) = comfy.hooks.set_mask_and_combine_conds(conds=[cond_A], new_conds=[cond_B],) @@ -248,7 +327,7 @@ def combine(self, cond_A, cond_B): ############################################### ### Scheduling ############################################### -class ConditioningTimestepsNode: +class ConditioningTimestepsNodeDEPR: @classmethod def INPUT_TYPES(s): return { @@ -258,18 +337,20 @@ def INPUT_TYPES(s): }, "optional": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), + "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), } } RETURN_TYPES = ("TIMESTEPS_RANGE",) CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning" FUNCTION = "create_schedule" + DEPRECATED = True def create_schedule(self, start_percent: float, end_percent: float): return ((start_percent, end_percent),) -class SetLoraHookKeyframes: +class SetLoraHookKeyframesDEPR: @classmethod def INPUT_TYPES(s): return { @@ -279,12 +360,14 @@ def INPUT_TYPES(s): }, "optional": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), + "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), } } RETURN_TYPES = ("HOOKS",) CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning" FUNCTION = "set_hook_keyframes" + DEPRECATED = True def set_hook_keyframes(self, lora_hook: HookGroup, hook_kf: HookKeyframeGroup): new_lora_hook = lora_hook.clone() @@ -292,7 +375,7 @@ def set_hook_keyframes(self, lora_hook: HookGroup, hook_kf: HookKeyframeGroup): return (new_lora_hook,) -class CreateLoraHookKeyframe: +class CreateLoraHookKeyframeDEPR: @classmethod def INPUT_TYPES(s): return { @@ -304,6 +387,7 @@ def INPUT_TYPES(s): "optional": { "prev_hook_kf": ("HOOK_KEYFRAMES",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), + "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), } } @@ -311,6 +395,7 @@ def INPUT_TYPES(s): RETURN_NAMES = ("HOOK_KF",) CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/schedule lora hooks" FUNCTION = "create_hook_keyframe" + DEPRECATED = True def create_hook_keyframe(self, strength_model: 
float, start_percent: float, guarantee_steps: float, prev_hook_kf: HookKeyframeGroup=None): @@ -321,56 +406,9 @@ def create_hook_keyframe(self, strength_model: float, start_percent: float, guar keyframe = HookKeyframe(strength=strength_model, start_percent=start_percent, guarantee_steps=guarantee_steps) prev_hook_kf.add(keyframe) return (prev_hook_kf,) - - -class CreateLoraHookKeyframeInterpolation: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}), - "strength_start": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), - "strength_end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ), - "interpolation": (InterpolationMethod._LIST, ), - "intervals": ("INT", {"default": 5, "min": 2, "max": 100, "step": 1}), - "print_keyframes": ("BOOLEAN", {"default": False}), - }, - "optional": { - "prev_hook_kf": ("HOOK_KEYFRAMES",), - "autosize": ("ADEAUTOSIZE", {"padding": 0}), - } - } - RETURN_TYPES = ("HOOK_KEYFRAMES",) - RETURN_NAMES = ("HOOK_KF",) - CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/schedule lora hooks" - FUNCTION = "create_hook_keyframes" - def create_hook_keyframes(self, - start_percent: float, end_percent: float, - strength_start: float, strength_end: float, interpolation: str, intervals: int, - prev_hook_kf: HookKeyframeGroup=None, print_keyframes=False): - if prev_hook_kf: - prev_hook_kf = prev_hook_kf.clone() - else: - prev_hook_kf = HookKeyframeGroup() - percents = InterpolationMethod.get_weights(num_from=start_percent, num_to=end_percent, length=intervals, method=InterpolationMethod.LINEAR) - strengths = InterpolationMethod.get_weights(num_from=strength_start, num_to=strength_end, length=intervals, method=interpolation) - - is_first = True - for percent, strength in zip(percents, strengths): - guarantee_steps = 0 - if is_first: - guarantee_steps = 1 - is_first = False - prev_hook_kf.add(HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps)) - if print_keyframes: - logger.info(f"HookKeyframe - start_percent:{percent} = {strength}") - return (prev_hook_kf,) - - -class CreateLoraHookKeyframeFromStrengthList: +class CreateLoraHookKeyframeFromStrengthListDEPR: @classmethod def INPUT_TYPES(s): return { @@ -382,6 +420,7 @@ def INPUT_TYPES(s): }, "optional": { "prev_hook_kf": ("HOOK_KEYFRAMES",), + "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), } } @@ -389,6 +428,7 @@ def INPUT_TYPES(s): RETURN_NAMES = ("HOOK_KF",) CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/schedule lora hooks" FUNCTION = "create_hook_keyframes" + DEPRECATED = True def create_hook_keyframes(self, strengths_float: Union[float, list[float]], start_percent: float, end_percent: float, @@ -424,7 +464,7 @@ def create_hook_keyframes(self, strengths_float: Union[float, list[float]], ### Register LoRA Hooks ############################################### # based on ComfyUI's nodes.py LoraLoader -class MaskableLoraLoader: +class MaskableLoraLoaderDEPR: def __init__(self): self.loaded_lora = None @@ -437,12 +477,16 @@ def INPUT_TYPES(s): "lora_name": (folder_paths.get_filename_list("loras"), ), "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + }, + "optional": { + "deprecation_warning": 
("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), } } RETURN_TYPES = ("MODEL", "CLIP", "HOOKS") CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/register lora hooks" FUNCTION = "load_lora" + DEPRECATED = True def load_lora(self, model: Union[ModelPatcher], clip: CLIP, lora_name: str, strength_model: float, strength_clip: float): if strength_model == 0 and strength_clip == 0: @@ -467,7 +511,7 @@ def load_lora(self, model: Union[ModelPatcher], clip: CLIP, lora_name: str, stre return (model_lora, clip_lora, hooks) -class MaskableLoraLoaderModelOnly(MaskableLoraLoader): +class MaskableLoraLoaderModelOnlyDEPR(MaskableLoraLoaderDEPR): @classmethod def INPUT_TYPES(s): return { @@ -475,12 +519,16 @@ def INPUT_TYPES(s): "model": ("MODEL",), "lora_name": (folder_paths.get_filename_list("loras"), ), "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + }, + "optional": { + "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), } } RETURN_TYPES = ("MODEL", "HOOKS") CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/register lora hooks" FUNCTION = "load_lora_model_only" + DEPRECATED = True def load_lora_model_only(self, model: ModelPatcher, lora_name: str, strength_model: float): model_lora, _, hooks = self.load_lora(model=model, clip=None, lora_name=lora_name, @@ -488,7 +536,7 @@ def load_lora_model_only(self, model: ModelPatcher, lora_name: str, strength_mod return (model_lora, hooks) -class MaskableSDModelLoader(comfy_extras.nodes_hooks.CreateHookModelAsLora): +class MaskableSDModelLoaderDEPR(comfy_extras.nodes_hooks.CreateHookModelAsLora): @classmethod def INPUT_TYPES(s): return { @@ -498,19 +546,23 @@ def INPUT_TYPES(s): "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), "strength_clip": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + }, + "optional": { + "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), } } RETURN_TYPES = ("MODEL", "CLIP", "HOOKS") CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/register lora hooks" FUNCTION = "load_model_as_lora" + DEPRECATED = True def load_model_as_lora(self, model: ModelPatcher, clip: CLIP, ckpt_name: str, strength_model: float, strength_clip: float): returned = self.create_hook(ckpt_name=ckpt_name, strength_model=strength_model, strength_clip=strength_clip) return (model.clone(), clip.clone(), returned[0]) -class MaskableSDModelLoaderModelOnly(MaskableSDModelLoader): +class MaskableSDModelLoaderModelOnlyDEPR(MaskableSDModelLoaderDEPR): @classmethod def INPUT_TYPES(s): return { @@ -518,12 +570,16 @@ def INPUT_TYPES(s): "model": ("MODEL",), "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), "strength_model": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}), + }, + "optional": { + "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), } } RETURN_TYPES = ("MODEL", "HOOKS") CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/register lora hooks" FUNCTION = "load_model_as_lora_model_only" + DEPRECATED = True def load_model_as_lora_model_only(self, model: ModelPatcher, ckpt_name: str, strength_model: float): model_lora, _, hooks = self.load_model_as_lora(model=model, clip=None, ckpt_name=ckpt_name, @@ -538,7 +594,7 @@ def load_model_as_lora_model_only(self, model: ModelPatcher, ckpt_name: str, str 
############################################### ### Set LoRA Hooks ############################################### -class SetModelLoraHook: +class SetModelLoraHookDEPR: @classmethod def INPUT_TYPES(s): return { @@ -548,18 +604,20 @@ def INPUT_TYPES(s): }, "optional": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), + "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), } } RETURN_TYPES = ("CONDITIONING",) CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/single cond ops" FUNCTION = "attach_lora_hook" + DEPRECATED = True def attach_lora_hook(self, conditioning, lora_hook: HookGroup): return (comfy.hooks.set_hooks_for_conditioning(conditioning, lora_hook),) -class SetClipLoraHook: +class SetClipLoraHookDEPR: @classmethod def INPUT_TYPES(s): return { @@ -569,6 +627,7 @@ def INPUT_TYPES(s): }, "optional": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), + "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), } } @@ -576,12 +635,13 @@ def INPUT_TYPES(s): RETURN_NAMES = ("hook_CLIP",) CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning" FUNCTION = "apply_lora_hook" + DEPRECATED = True def apply_lora_hook(self, clip: CLIP, lora_hook: HookGroup): return comfy_extras.nodes_hooks.SetClipHooks.apply_hooks(self, clip, False, lora_hook) -class CombineLoraHooks: +class CombineLoraHooksDEPR: @classmethod def INPUT_TYPES(s): return { @@ -591,19 +651,21 @@ def INPUT_TYPES(s): "lora_hook_A": ("HOOKS",), "lora_hook_B": ("HOOKS",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), + "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), } } RETURN_TYPES = ("HOOKS",) CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/combine lora hooks" FUNCTION = "combine_lora_hooks" + DEPRECATED = True def combine_lora_hooks(self, lora_hook_A: HookGroup=None, lora_hook_B: HookGroup=None): candidates = [lora_hook_A, lora_hook_B] return (HookGroup.combine_all_hooks(candidates),) -class CombineLoraHookFourOptional: +class CombineLoraHookFourOptionalDEPR: @classmethod def INPUT_TYPES(s): return { @@ -615,12 +677,14 @@ def INPUT_TYPES(s): "lora_hook_C": ("HOOKS",), "lora_hook_D": ("HOOKS",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), + "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), } } RETURN_TYPES = ("HOOKS",) CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/combine lora hooks" FUNCTION = "combine_lora_hooks" + DEPRECATED = True def combine_lora_hooks(self, lora_hook_A: HookGroup=None, lora_hook_B: HookGroup=None, @@ -629,7 +693,7 @@ def combine_lora_hooks(self, return (HookGroup.combine_all_hooks(candidates),) -class CombineLoraHookEightOptional: +class CombineLoraHookEightOptionalDEPR: @classmethod def INPUT_TYPES(s): return { @@ -645,12 +709,14 @@ def INPUT_TYPES(s): "lora_hook_G": ("HOOKS",), "lora_hook_H": ("HOOKS",), "autosize": ("ADEAUTOSIZE", {"padding": 0}), + "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), } } RETURN_TYPES = ("HOOKS",) CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/conditioning/combine lora hooks" FUNCTION = "combine_lora_hooks" + DEPRECATED = True def combine_lora_hooks(self, lora_hook_A: HookGroup=None, lora_hook_B: HookGroup=None, diff --git a/animatediff/nodes_deprecated.py b/animatediff/nodes_deprecated.py index a26e138..50fe5c9 100644 --- a/animatediff/nodes_deprecated.py +++ b/animatediff/nodes_deprecated.py @@ -20,7 +20,7 @@ from .sampling import outer_sample_wrapper, 
sliding_calc_cond_batch -class AnimateDiffLoader_Deprecated: +class AnimateDiffLoaderDEPR: @classmethod def INPUT_TYPES(s): return { @@ -76,7 +76,7 @@ def load_mm_and_inject_params( return (model, latents) -class AnimateDiffLoaderAdvanced_Deprecated: +class AnimateDiffLoaderAdvancedDEPR: @classmethod def INPUT_TYPES(s): return { @@ -150,7 +150,7 @@ def load_mm_and_inject_params(self, return (model, latents) -class AnimateDiffCombine_Deprecated: +class AnimateDiffCombineDEPR: ffmpeg_warning_already_shown = False @classmethod def INPUT_TYPES(s): @@ -292,7 +292,7 @@ def generate_gif( -class AnimateDiffModelSettings: +class AnimateDiffModelSettingsDEPR: @classmethod def INPUT_TYPES(s): return { @@ -321,7 +321,7 @@ def get_motion_model_settings(self, mask_motion_scale: torch.Tensor=None, min_mo return (motion_model_settings,) -class AnimateDiffModelSettingsSimple: +class AnimateDiffModelSettingsSimpleDEPR: @classmethod def INPUT_TYPES(s): return { @@ -354,7 +354,7 @@ def get_motion_model_settings(self, motion_pe_stretch: int, return (motion_model_settings,) -class AnimateDiffModelSettingsAdvanced: +class AnimateDiffModelSettingsAdvancedDEPR: @classmethod def INPUT_TYPES(s): return { @@ -405,7 +405,7 @@ def get_motion_model_settings(self, pe_strength: float, attn_strength: float, ot return (motion_model_settings,) -class AnimateDiffModelSettingsAdvancedAttnStrengths: +class AnimateDiffModelSettingsAdvancedAttnStrengthsDEPR: @classmethod def INPUT_TYPES(s): return { From 31955647a1b6270d18bc8729b4ac906273411ca4 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 12 Nov 2024 09:48:28 -0600 Subject: [PATCH 22/43] Moved autosize param from optional to hidden for all nodes --- animatediff/nodes_ad_settings.py | 20 ++++++++++ animatediff/nodes_animatelcmi2v.py | 2 + animatediff/nodes_cameractrl.py | 4 ++ animatediff/nodes_conditioning.py | 58 ++++++++++++++++++++++------- animatediff/nodes_context.py | 2 + animatediff/nodes_context_extras.py | 32 ++++++++++++++-- animatediff/nodes_fancyvideo.py | 2 + animatediff/nodes_gen2.py | 8 ++++ animatediff/nodes_lora.py | 2 + animatediff/nodes_multival.py | 12 +++++- animatediff/nodes_per_block.py | 2 + animatediff/nodes_sample.py | 32 +++++++++++++++- animatediff/nodes_scheduling.py | 12 ++++-- animatediff/nodes_sigma_schedule.py | 10 ++--- 14 files changed, 169 insertions(+), 29 deletions(-) diff --git a/animatediff/nodes_ad_settings.py b/animatediff/nodes_ad_settings.py index 3eb0de1..fed5f5d 100644 --- a/animatediff/nodes_ad_settings.py +++ b/animatediff/nodes_ad_settings.py @@ -9,6 +9,8 @@ def INPUT_TYPES(s): "optional": { "pe_adjust": ("PE_ADJUST",), "weight_adjust": ("WEIGHT_ADJUST",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -35,6 +37,8 @@ def INPUT_TYPES(s): }, "optional": { "prev_pe_adjust": ("PE_ADJUST",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -67,6 +71,8 @@ def INPUT_TYPES(s): }, "optional": { "prev_pe_adjust": ("PE_ADJUST",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -95,6 +101,8 @@ def INPUT_TYPES(s): }, "optional": { "prev_pe_adjust": ("PE_ADJUST",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -123,6 +131,8 @@ def INPUT_TYPES(s): }, "optional": { "prev_weight_adjust": ("WEIGHT_ADJUST",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -153,6 +163,8 @@ def INPUT_TYPES(s): }, "optional": { "prev_weight_adjust": ("WEIGHT_ADJUST",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ 
-185,6 +197,8 @@ def INPUT_TYPES(s): }, "optional": { "prev_weight_adjust": ("WEIGHT_ADJUST",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -219,6 +233,8 @@ def INPUT_TYPES(s): }, "optional": { "prev_weight_adjust": ("WEIGHT_ADJUST",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -258,6 +274,8 @@ def INPUT_TYPES(s): }, "optional": { "prev_weight_adjust": ("WEIGHT_ADJUST",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -305,6 +323,8 @@ def INPUT_TYPES(s): }, "optional": { "prev_weight_adjust": ("WEIGHT_ADJUST",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } diff --git a/animatediff/nodes_animatelcmi2v.py b/animatediff/nodes_animatelcmi2v.py index 9d63e12..22a91e4 100644 --- a/animatediff/nodes_animatelcmi2v.py +++ b/animatediff/nodes_animatelcmi2v.py @@ -35,6 +35,8 @@ def INPUT_TYPES(s): "ad_keyframes": ("AD_KEYFRAMES",), "prev_m_models": ("M_MODELS",), "per_block": ("PER_BLOCK",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } diff --git a/animatediff/nodes_cameractrl.py b/animatediff/nodes_cameractrl.py index 9e9a1ab..cd64b0d 100644 --- a/animatediff/nodes_cameractrl.py +++ b/animatediff/nodes_cameractrl.py @@ -273,6 +273,8 @@ def INPUT_TYPES(s): "cameractrl_multival": ("MULTIVAL",), "inherit_missing": ("BOOLEAN", {"default": True}, ), "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -373,6 +375,8 @@ def INPUT_TYPES(cls): }, "optional": { "prev_poses": ("CAMERACTRL_POSES",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } diff --git a/animatediff/nodes_conditioning.py b/animatediff/nodes_conditioning.py index 0d7c36e..c70cd86 100644 --- a/animatediff/nodes_conditioning.py +++ b/animatediff/nodes_conditioning.py @@ -32,6 +32,8 @@ def INPUT_TYPES(s): }, "optional": { "prev_hook_kf": ("HOOK_KEYFRAMES",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -93,8 +95,10 @@ def INPUT_TYPES(s): "opt_mask": ("MASK", ), "opt_lora_hook": ("HOOKS",), "opt_timesteps": ("TIMESTEPS_RANGE",), - "autosize": ("ADEAUTOSIZE", {"padding": 0}), "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), + }, + "hidden": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -126,8 +130,10 @@ def INPUT_TYPES(s): "opt_mask": ("MASK", ), "opt_lora_hook": ("HOOKS",), "opt_timesteps": ("TIMESTEPS_RANGE",), - "autosize": ("ADEAUTOSIZE", {"padding": 0}), "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), + }, + "hidden": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -161,8 +167,10 @@ def INPUT_TYPES(s): "opt_mask": ("MASK", ), "opt_lora_hook": ("HOOKS",), "opt_timesteps": ("TIMESTEPS_RANGE",), - "autosize": ("ADEAUTOSIZE", {"padding": 0}), "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), + }, + "hidden": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -195,8 +203,10 @@ def INPUT_TYPES(s): "opt_mask": ("MASK", ), "opt_lora_hook": ("HOOKS",), "opt_timesteps": ("TIMESTEPS_RANGE",), - "autosize": ("ADEAUTOSIZE", {"padding": 0}), "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), + }, + "hidden": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -226,8 +236,10 @@ def INPUT_TYPES(s): }, "optional": { "opt_lora_hook": ("HOOKS",), - "autosize": ("ADEAUTOSIZE", {"padding": 0}), 
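# Every INPUT_TYPES dict touched by this commit gets the same mechanical change: the
# "autosize" entry moves out of "optional" and into a new "hidden" section, so the ADEAUTOSIZE
# widget data is no longer exposed as a connectable optional input. A minimal sketch of the
# resulting shape, assuming a hypothetical node (only "autosize"/"ADEAUTOSIZE" and the
# "hidden" key are taken from the diff; the other inputs are illustrative):
class ExampleHiddenAutosizeNode:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {},
            "optional": {
                "mask_optional": ("MASK",),  # ordinary optional inputs stay under "optional"
            },
            "hidden": {
                "autosize": ("ADEAUTOSIZE", {"padding": 0}),  # moved here from "optional"
            },
        }
    # RETURN_TYPES / CATEGORY / FUNCTION are unchanged by this commit and omitted from the sketch.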
"deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), + }, + "hidden": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -254,8 +266,10 @@ def INPUT_TYPES(s): }, "optional": { "opt_lora_hook": ("HOOKS",), - "autosize": ("ADEAUTOSIZE", {"padding": 0}), "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), + }, + "hidden": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -336,8 +350,10 @@ def INPUT_TYPES(s): "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) }, "optional": { - "autosize": ("ADEAUTOSIZE", {"padding": 0}), "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), + }, + "hidden": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -359,8 +375,10 @@ def INPUT_TYPES(s): "hook_kf": ("HOOK_KEYFRAMES",), }, "optional": { - "autosize": ("ADEAUTOSIZE", {"padding": 0}), "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), + }, + "hidden": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -386,8 +404,10 @@ def INPUT_TYPES(s): }, "optional": { "prev_hook_kf": ("HOOK_KEYFRAMES",), - "autosize": ("ADEAUTOSIZE", {"padding": 0}), "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), + }, + "hidden": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -603,8 +623,10 @@ def INPUT_TYPES(s): "lora_hook": ("HOOKS",), }, "optional": { - "autosize": ("ADEAUTOSIZE", {"padding": 0}), "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), + }, + "hidden": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -626,8 +648,10 @@ def INPUT_TYPES(s): "lora_hook": ("HOOKS",), }, "optional": { - "autosize": ("ADEAUTOSIZE", {"padding": 0}), "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), + }, + "hidden": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -650,8 +674,10 @@ def INPUT_TYPES(s): "optional": { "lora_hook_A": ("HOOKS",), "lora_hook_B": ("HOOKS",), - "autosize": ("ADEAUTOSIZE", {"padding": 0}), "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), + }, + "hidden": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -676,8 +702,10 @@ def INPUT_TYPES(s): "lora_hook_B": ("HOOKS",), "lora_hook_C": ("HOOKS",), "lora_hook_D": ("HOOKS",), - "autosize": ("ADEAUTOSIZE", {"padding": 0}), "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), + }, + "hidden": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -708,8 +736,10 @@ def INPUT_TYPES(s): "lora_hook_F": ("HOOKS",), "lora_hook_G": ("HOOKS",), "lora_hook_H": ("HOOKS",), - "autosize": ("ADEAUTOSIZE", {"padding": 0}), "deprecation_warning": ("ADEWARN", {"text": "Deprecated - use native ComfyUI nodes instead."}), + }, + "hidden": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } diff --git a/animatediff/nodes_context.py b/animatediff/nodes_context.py index 92babc4..255b36a 100644 --- a/animatediff/nodes_context.py +++ b/animatediff/nodes_context.py @@ -236,6 +236,8 @@ def INPUT_TYPES(s): "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), "prev_context": ("CONTEXT_OPTIONS",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } diff --git a/animatediff/nodes_context_extras.py 
b/animatediff/nodes_context_extras.py index 10b3b73..00562de 100644 --- a/animatediff/nodes_context_extras.py +++ b/animatediff/nodes_context_extras.py @@ -19,7 +19,7 @@ def INPUT_TYPES(s): "context_opts": ("CONTEXT_OPTIONS",), "context_extras": ("CONTEXT_EXTRAS",), }, - "optional": { + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -50,6 +50,8 @@ def INPUT_TYPES(s): "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001}), "weighted_mean": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.001}), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -83,6 +85,8 @@ def INPUT_TYPES(s): "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), "inherit_missing": ("BOOLEAN", {"default": True}, ), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -120,6 +124,8 @@ def INPUT_TYPES(s): "optional": { "prev_kf": ("NAIVEREUSE_KEYFRAME",), "mult_multival": ("MULTIVAL",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -168,6 +174,8 @@ def INPUT_TYPES(s): "optional": { "prev_kf": ("NAIVEREUSE_KEYFRAME",), "mult_multival": ("MULTIVAL",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -223,6 +231,8 @@ def INPUT_TYPES(s): "contextref_kf": ("CONTEXTREF_KEYFRAME",), "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 1.0, "step": 0.001}), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -265,6 +275,8 @@ def INPUT_TYPES(s): "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), "inherit_missing": ("BOOLEAN", {"default": True}, ), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -305,6 +317,8 @@ def INPUT_TYPES(s): "mult_multival": ("MULTIVAL",), "mode_replace": ("CONTEXTREF_MODE",), "tune_replace": ("CONTEXTREF_TUNE",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -354,6 +368,8 @@ def INPUT_TYPES(s): "mult_multival": ("MULTIVAL",), "mode_replace": ("CONTEXTREF_MODE",), "tune_replace": ("CONTEXTREF_TUNE",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -397,9 +413,9 @@ def INPUT_TYPES(s): return { "required": { }, - "optional": { + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), - }, + } } RETURN_TYPES = ("CONTEXTREF_MODE",) @@ -419,6 +435,8 @@ def INPUT_TYPES(s): }, "optional": { "sliding_width": ("INT", {"default": 2, "min": 2, "max": BIGMAX, "step": 1}), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -441,8 +459,10 @@ def INPUT_TYPES(s): "optional": { "switch_on_idxs": ("STRING", {"default": ""}), "always_include_0": ("BOOLEAN", {"default": True},), - "autosize": ("ADEAUTOSIZE", {"padding": 0}), }, + "hidden": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } } RETURN_TYPES = ("CONTEXTREF_MODE",) @@ -470,6 +490,8 @@ def INPUT_TYPES(s): "adain_style_fidelity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "adain_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "adain_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -496,6 
+518,8 @@ def INPUT_TYPES(s): "attn_style_fidelity": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "attn_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), "attn_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } diff --git a/animatediff/nodes_fancyvideo.py b/animatediff/nodes_fancyvideo.py index 451b698..6d8189f 100644 --- a/animatediff/nodes_fancyvideo.py +++ b/animatediff/nodes_fancyvideo.py @@ -35,6 +35,8 @@ def INPUT_TYPES(s): "ad_keyframes": ("AD_KEYFRAMES",), "prev_m_models": ("M_MODELS",), "per_block": ("PER_BLOCK",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } diff --git a/animatediff/nodes_gen2.py b/animatediff/nodes_gen2.py index b97865e..b2e8b6a 100644 --- a/animatediff/nodes_gen2.py +++ b/animatediff/nodes_gen2.py @@ -100,6 +100,8 @@ def INPUT_TYPES(s): "ad_keyframes": ("AD_KEYFRAMES",), "prev_m_models": ("M_MODELS",), "per_block": ("PER_BLOCK",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -151,6 +153,8 @@ def INPUT_TYPES(s): "effect_multival": ("MULTIVAL",), "ad_keyframes": ("AD_KEYFRAMES",), "per_block": ("PER_BLOCK",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -178,6 +182,8 @@ def INPUT_TYPES(s): }, "optional": { "ad_settings": ("AD_SETTINGS",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 50}), } } @@ -206,6 +212,8 @@ def INPUT_TYPES(s): "effect_multival": ("MULTIVAL",), "inherit_missing": ("BOOLEAN", {"default": True}, ), "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } diff --git a/animatediff/nodes_lora.py b/animatediff/nodes_lora.py index 1f1524c..b74e34a 100644 --- a/animatediff/nodes_lora.py +++ b/animatediff/nodes_lora.py @@ -19,6 +19,8 @@ def INPUT_TYPES(s): }, "optional": { "prev_motion_lora": ("MOTION_LORA",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 30}), } } diff --git a/animatediff/nodes_multival.py b/animatediff/nodes_multival.py index d2f4a52..5a2934a 100644 --- a/animatediff/nodes_multival.py +++ b/animatediff/nodes_multival.py @@ -22,6 +22,8 @@ def INPUT_TYPES(s): }, "optional": { "mask_optional": ("MASK",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -45,6 +47,8 @@ def INPUT_TYPES(s): }, "optional": { "scaling": (ScaleType.LIST,), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -91,6 +95,8 @@ def INPUT_TYPES(s): }, "optional": { "mask_optional": ("MASK",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -112,6 +118,8 @@ def INPUT_TYPES(s): }, "optional": { "mask_optional": ("MASK",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -131,7 +139,7 @@ def INPUT_TYPES(s): "required": { "float_val": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001},), }, - "optional": { + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -151,7 +159,7 @@ def INPUT_TYPES(s): "required": { "multival": ("MULTIVAL",), }, - "optional": { + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } diff --git a/animatediff/nodes_per_block.py b/animatediff/nodes_per_block.py index b06621f..ed077e7 100644 --- a/animatediff/nodes_per_block.py +++ b/animatediff/nodes_per_block.py @@ -38,6 +38,8 @@ def INPUT_TYPES(s): "optional": { "effect": ("MULTIVAL",), "scale": ("MULTIVAL",), + }, + "hidden": { "autosize": 
("ADEAUTOSIZE", {"padding": 0}), } } diff --git a/animatediff/nodes_sample.py b/animatediff/nodes_sample.py index f274106..6d798dc 100644 --- a/animatediff/nodes_sample.py +++ b/animatediff/nodes_sample.py @@ -34,6 +34,8 @@ def INPUT_TYPES(s): "sigma_schedule": ("SIGMA_SCHEDULE",), "image_inject": ("IMAGE_INJECT",), "noise_calib": ("NOISE_CALIBRATION",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -67,6 +69,8 @@ def INPUT_TYPES(s): "prev_noise_layers": ("NOISE_LAYERS",), "mask_optional": ("MASK",), "seed_override": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "forceInput": True}), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -103,6 +107,8 @@ def INPUT_TYPES(s): "prev_noise_layers": ("NOISE_LAYERS",), "mask_optional": ("MASK",), "seed_override": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "forceInput": True}), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -142,6 +148,8 @@ def INPUT_TYPES(s): "prev_noise_layers": ("NOISE_LAYERS",), "mask_optional": ("MASK",), "seed_override": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "forceInput": True}), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -204,6 +212,8 @@ def INPUT_TYPES(s): "optional": { "iter_batch_offset": ("INT", {"default": 0, "min": 0, "max": BIGMAX}), "iter_seed_offset": ("INT", {"default": 1, "min": BIGMIN, "max": BIGMAX}), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -230,7 +240,7 @@ def INPUT_TYPES(s): "calib_iterations": ("INT", {"default": 1, "min": 1, "step": 1}), "thresh_freq": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.001}), }, - "optional": { + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -254,6 +264,8 @@ def INPUT_TYPES(s): }, "optional": { "cfg_extras": ("CFG_EXTRAS",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -278,6 +290,8 @@ def INPUT_TYPES(s): }, "optional": { "cfg_extras": ("CFG_EXTRAS",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -302,6 +316,8 @@ def INPUT_TYPES(s): "optional": { "prev_custom_cfg": ("CUSTOM_CFG",), "cfg_extras": ("CFG_EXTRAS",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -332,6 +348,8 @@ def INPUT_TYPES(s): "optional": { "prev_custom_cfg": ("CUSTOM_CFG",), "cfg_extras": ("CFG_EXTRAS",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 10}), } } @@ -362,6 +380,8 @@ def INPUT_TYPES(s): "optional": { "prev_custom_cfg": ("CUSTOM_CFG",), "cfg_extras": ("CFG_EXTRAS",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -449,6 +469,8 @@ def INPUT_TYPES(s): }, "optional": { "prev_extras": ("CFG_EXTRAS",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -481,6 +503,8 @@ def INPUT_TYPES(s): }, "optional": { "prev_extras": ("CFG_EXTRAS",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -534,6 +558,8 @@ def INPUT_TYPES(s): }, "optional": { "prev_extras": ("CFG_EXTRAS",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 10}), } } @@ -564,6 +590,8 @@ def INPUT_TYPES(s): "img_inject_opts": ("IMAGE_INJECT_OPTIONS", ), "strength_multival": ("MULTIVAL", ), "prev_image_inject": ("IMAGE_INJECT", ), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -594,6 +622,8 @@ def INPUT_TYPES(s): "optional": { "composite_x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), "composite_y": ("INT", {"default": 0, "min": 
0, "max": MAX_RESOLUTION, "step": 1}), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } diff --git a/animatediff/nodes_scheduling.py b/animatediff/nodes_scheduling.py index 01f1775..96cd96c 100644 --- a/animatediff/nodes_scheduling.py +++ b/animatediff/nodes_scheduling.py @@ -149,8 +149,10 @@ def INPUT_TYPES(s): }, "optional": { "print_schedule": ("BOOLEAN", {"default": False}), - "autosize": ("ADEAUTOSIZE", {"padding": 0}), }, + "hidden": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } } RETURN_TYPES = ("FLOAT", "FLOATS", "INT", "INTS") @@ -186,8 +188,10 @@ def INPUT_TYPES(s): "optional": { "print_schedule": ("BOOLEAN", {"default": False}), "max_length": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}), - "autosize": ("ADEAUTOSIZE", {"padding": 0}), }, + "hidden": { + "autosize": ("ADEAUTOSIZE", {"padding": 0}), + } } RETURN_TYPES = ("FLOAT", "FLOATS", "INT", "INTS") @@ -223,6 +227,8 @@ def INPUT_TYPES(s): }, "optional": { "prev_replace": ("VALUES_REPLACE",), + }, + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -259,7 +265,7 @@ def INPUT_TYPES(s): "required": { "FLOAT": ("FLOAT", {"default": 39, "forceInput": True}), }, - "optional": { + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } diff --git a/animatediff/nodes_sigma_schedule.py b/animatediff/nodes_sigma_schedule.py index 369c3f1..5eafe90 100644 --- a/animatediff/nodes_sigma_schedule.py +++ b/animatediff/nodes_sigma_schedule.py @@ -46,7 +46,7 @@ def INPUT_TYPES(s): "lcm_original_timesteps": ("INT", {"default": 50, "min": 1, "max": 1000}), "zsnr": ("BOOLEAN", {"default": False}), }, - "optional": { + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -83,7 +83,7 @@ def INPUT_TYPES(s): "schedule_B": ("SIGMA_SCHEDULE",), "weight_A": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.001}), }, - "optional": { + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -111,7 +111,7 @@ def INPUT_TYPES(s): "weight_A_End": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.001}), "interpolation": (InterpolationMethod._LIST,), }, - "optional": { + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -142,7 +142,7 @@ def INPUT_TYPES(s): "schedule_End": ("SIGMA_SCHEDULE",), "idx_split_percent": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.001}) }, - "optional": { + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } @@ -171,7 +171,7 @@ def INPUT_TYPES(s): "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), }, - "optional": { + "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), } } From 006ac043efada973ecd74ea0cb2a1a6b64a8fd4e Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 12 Nov 2024 09:55:42 -0600 Subject: [PATCH 23/43] Commented out NoiseCalibration and FancyVideo from user's view so can be safely merged into main, will need to get working properly later --- animatediff/nodes.py | 10 ++++++---- animatediff/nodes_sample.py | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/animatediff/nodes.py b/animatediff/nodes.py index f61aa28..9e8f1ed 100644 --- a/animatediff/nodes.py +++ b/animatediff/nodes.py @@ -98,6 +98,8 @@ "ADE_IterationOptsDefault": IterationOptionsNode, "ADE_IterationOptsFreeInit": FreeInitOptionsNode, # Conditioning + "ADE_LoraHookKeyframeInterpolation": CreateLoraHookKeyframeInterpolation, + # Conditioning (DEPRECATED) "ADE_RegisterLoraHook": MaskableLoraLoaderDEPR, 
"ADE_RegisterLoraHookModelOnly": MaskableLoraLoaderModelOnlyDEPR, "ADE_RegisterModelAsLoraHook": MaskableSDModelLoaderDEPR, @@ -108,7 +110,6 @@ "ADE_SetLoraHookKeyframe": SetLoraHookKeyframesDEPR, "ADE_AttachLoraHookToCLIP": SetClipLoraHookDEPR, "ADE_LoraHookKeyframe": CreateLoraHookKeyframeDEPR, - "ADE_LoraHookKeyframeInterpolation": CreateLoraHookKeyframeInterpolation, "ADE_LoraHookKeyframeFromStrengthList": CreateLoraHookKeyframeFromStrengthListDEPR, "ADE_AttachLoraHookToConditioning": SetModelLoraHookDEPR, "ADE_PairedConditioningSetMask": PairedConditioningSetMaskHookedDEPR, @@ -154,7 +155,7 @@ "ADE_SigmaScheduleToSigmas": SigmaScheduleToSigmasNode, "ADE_NoisedImageInjection": NoisedImageInjectionNode, "ADE_NoisedImageInjectOptions": NoisedImageInjectOptionsNode, - "ADE_NoiseCalibration": NoiseCalibrationNode, + #"ADE_NoiseCalibration": NoiseCalibrationNode, # Scheduling PromptSchedulingNode.NodeID: PromptSchedulingNode, PromptSchedulingLatentsNode.NodeID: PromptSchedulingLatentsNode, @@ -210,7 +211,7 @@ "ADE_PIA_AnimateDiffKeyframe": PIA_ADKeyframeNode, "ADE_InjectPIAIntoAnimateDiffModel": LoadAnimateDiffAndInjectPIANode, # FancyVideo - ApplyAnimateDiffFancyVideo.NodeID: ApplyAnimateDiffFancyVideo, + #ApplyAnimateDiffFancyVideo.NodeID: ApplyAnimateDiffFancyVideo, # Deprecated Nodes "AnimateDiffLoaderV1": AnimateDiffLoaderDEPR, "ADE_AnimateDiffLoaderV1Advanced": AnimateDiffLoaderAdvancedDEPR, @@ -268,6 +269,8 @@ "ADE_IterationOptsDefault": "Default Iteration Options πŸŽ­πŸ…πŸ…“", "ADE_IterationOptsFreeInit": "FreeInit Iteration Options πŸŽ­πŸ…πŸ…“", # Conditioning + "ADE_LoraHookKeyframeInterpolation": "LoRA Hook Keyframes Interp. πŸŽ­πŸ…πŸ…“", + # Conditioning (DEPRECATED) "ADE_RegisterLoraHook": "Register LoRA Hook πŸŽ­πŸ…πŸ…“", "ADE_RegisterLoraHookModelOnly": "Register LoRA Hook (Model Only) πŸŽ­πŸ…πŸ…“", "ADE_RegisterModelAsLoraHook": "Register Model as LoRA Hook πŸŽ­πŸ…πŸ…“", @@ -278,7 +281,6 @@ "ADE_SetLoraHookKeyframe": "Set LoRA Hook Keyframes πŸŽ­πŸ…πŸ…“", "ADE_AttachLoraHookToCLIP": "Set CLIP LoRA Hook πŸŽ­πŸ…πŸ…“", "ADE_LoraHookKeyframe": "LoRA Hook Keyframe πŸŽ­πŸ…πŸ…“", - "ADE_LoraHookKeyframeInterpolation": "LoRA Hook Keyframes Interp. 
πŸŽ­πŸ…πŸ…“", "ADE_LoraHookKeyframeFromStrengthList": "LoRA Hook Keyframes From List πŸŽ­πŸ…πŸ…“", "ADE_AttachLoraHookToConditioning": "Set Model LoRA Hook πŸŽ­πŸ…πŸ…“", "ADE_PairedConditioningSetMask": "Set Props on Conds πŸŽ­πŸ…πŸ…“", diff --git a/animatediff/nodes_sample.py b/animatediff/nodes_sample.py index 6d798dc..631ae00 100644 --- a/animatediff/nodes_sample.py +++ b/animatediff/nodes_sample.py @@ -33,7 +33,7 @@ def INPUT_TYPES(s): "custom_cfg": ("CUSTOM_CFG",), "sigma_schedule": ("SIGMA_SCHEDULE",), "image_inject": ("IMAGE_INJECT",), - "noise_calib": ("NOISE_CALIBRATION",), + #"noise_calib": ("NOISE_CALIBRATION",), # TODO: bring back once NoiseCalibration is working }, "hidden": { "autosize": ("ADEAUTOSIZE", {"padding": 0}), From c088d80a983195893cd835e420f201eab0f2e870 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 12 Nov 2024 13:58:20 -0600 Subject: [PATCH 24/43] Made Prompt Scheduling work with schedule_clip, fixed progress bar advancing too quick --- animatediff/scheduling.py | 49 +++++++++++++++++++++++++++++++++++---- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/animatediff/scheduling.py b/animatediff/scheduling.py index b223d34..be5e0ab 100644 --- a/animatediff/scheduling.py +++ b/animatediff/scheduling.py @@ -110,6 +110,7 @@ class PromptOptions: append_text: str = '' values_replace: dict[str, list[float]] = None print_schedule: bool = False + add_dict: dict[str] = None def evaluate_prompt_schedule(text: str, length: int, clip: CLIP, options: PromptOptions): @@ -234,11 +235,16 @@ def handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP, if len(value) < length: values_replace[key] = extend_list_to_batch_size(value, length) - pairs_lengths = len(pairs) + scheduled_keyframes = [] + if clip.use_clip_schedule: + clip = clip.clone() + scheduled_keyframes = clip.patcher.forced_hooks.get_hooks_for_clip_schedule() + + pairs_lengths = len(pairs) * max(1, len(scheduled_keyframes)) pbar_total = length + pairs_lengths pbar = ProgressBar(pbar_total) # for now, use FizzNodes approach of calculating max size of tokens beforehand; - # this doubles total encoding time, as this will be done again. + # this can up to double total encoding time, as this will be done again. 
# TODO: do this dynamically to save encoding time max_size = 0 for pair in pairs: @@ -247,6 +253,38 @@ def handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP, max_size = max(max_size, cond.shape[1]) pbar.update(1) + # if do not need to schedule clip with hooks, do nothing special + if not clip.use_clip_schedule: + return _handle_prompt_interpolation(pairs, length, clip, options, values_replace, max_size, pbar) + # otherwise, need to account for keyframes on forced_hooks + full_output = [] + for i, scheduled_opts in enumerate(scheduled_keyframes): + clip.patcher.forced_hooks.reset() + clip.patcher.unpatch_hooks() + + t_range = scheduled_opts[0] + hooks_keyframes = scheduled_opts[1] + for hook, keyframe in hooks_keyframes: + hook.hook_keyframe._current_keyframe = keyframe + try: + # don't print_schedule on non-first iteration + orig_print_schedule = options.print_schedule + if orig_print_schedule and i != 0: + options.print_schedule = False + schedule_output = _handle_prompt_interpolation(pairs, length, clip, options, values_replace, max_size, pbar) + finally: + options.print_schedule = orig_print_schedule + for cond, pooled_dict in schedule_output: + pooled_dict: dict[str] + # add clip_start_percent and clip_end_percent in pooled + pooled_dict["clip_start_percent"] = t_range[0] + pooled_dict["clip_end_percent"] = t_range[1] + full_output.extend(schedule_output) + return full_output + + +def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP, options: PromptOptions, + values_replace: dict[str, list[float]], max_size: int, pbar: ProgressBar): real_holders: list[CondHolder] = [None] * length real_cond = [None] * length real_pooled = [None] * length @@ -373,9 +411,9 @@ def handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP, real_cond[i] = prev_holder.cond real_pooled[i] = prev_holder.pooled real_holders[i] = prev_holder + pbar.update(1) else: prev_holder = real_holders[i] - pbar.update(1) final_cond = torch.cat(real_cond, dim=0) final_pooled = torch.cat(real_pooled, dim=0) @@ -388,7 +426,10 @@ def handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP, else: logger.info(f'{i} = ({1.-holder.interp_weight:.2f})"{holder.prompt}" -> ({holder.interp_weight:.2f})"{holder.interp_prompt}"') # cond is a list[list[Tensor, dict[str: Any]]] format - return [[final_cond, {"pooled_output": final_pooled}]] + final_pooled_dict = {"pooled_output": final_pooled} + if options.add_dict is not None: + final_pooled_dict.update(options.add_dict) + return [[final_cond, final_pooled_dict]] def pad_cond(cond: Tensor, target_length: int): From f2f9c90449ef5263c42868ee76e34ebaa6f7a446 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 12 Nov 2024 16:13:47 -0600 Subject: [PATCH 25/43] Added support for HelloMeme AnimateDiff model loading via simple key conversion --- animatediff/motion_module_ad.py | 80 ++++++++++++++++++++++++++------- 1 file changed, 64 insertions(+), 16 deletions(-) diff --git a/animatediff/motion_module_ad.py b/animatediff/motion_module_ad.py index dc1dfec..da861b8 100644 --- a/animatediff/motion_module_ad.py +++ b/animatediff/motion_module_ad.py @@ -136,6 +136,11 @@ def is_animatelcm(mm_state_dict: dict[str, Tensor]) -> bool: return False return True +def is_hellomeme(mm_state_dict: dict[str, Tensor]) -> bool: + for key in mm_state_dict.keys(): + if "pos_embed" in key: + return True + return False def has_conv_in(mm_state_dict: dict[str, Tensor]) -> bool: # check if conv_in.weight and .bias are 
present @@ -192,6 +197,14 @@ def find_hotshot_module_num(key: str) -> Union[int, None]: return None +_regex_hellomeme_module_num = re.compile(r'motion_modules\.(\d+)\.') +def find_hellomeme_module_num(key: str) -> Union[int, None]: + found = _regex_hellomeme_module_num.search(key) + if found: + return int(found.group(1)) + return None + + def has_img_encoder(mm_state_dict: dict[str, Tensor]): for key in mm_state_dict.keys(): if key.startswith("img_encoder."): @@ -239,6 +252,8 @@ def normalize_ad_state_dict(mm_state_dict: dict[str, Tensor], mm_name: str) -> T raise ValueError(f"'{mm_name}' is not a valid SD1.5 nor SDXL motion module - contained {down_block_max} downblocks.") # determine the model's format mm_format = AnimateDiffFormat.ANIMATEDIFF + if is_hellomeme(mm_state_dict): + convert_hellomeme_state_dict(mm_state_dict) if is_hotshotxl(mm_state_dict): mm_format = AnimateDiffFormat.HOTSHOTXL if is_animatelcm(mm_state_dict): @@ -260,6 +275,7 @@ def normalize_ad_state_dict(mm_state_dict: dict[str, Tensor], mm_name: str) -> T if mm_format == AnimateDiffFormat.FANCYVIDEO and key in FancyVideoKeys: continue del mm_state_dict[key] + # determine the model's version mm_version = AnimateDiffVersion.V1 if has_mid_block(mm_state_dict): @@ -269,26 +285,58 @@ def normalize_ad_state_dict(mm_state_dict: dict[str, Tensor], mm_name: str) -> T info = AnimateDiffInfo(sd_type=sd_type, mm_format=mm_format, mm_version=mm_version, mm_name=mm_name) # convert to AnimateDiff format, if needed if mm_format == AnimateDiffFormat.HOTSHOTXL: - # HotshotXL is AD-based architecture applied to SDXL instead of SD1.5 - # By renaming the keys, no code needs to be adapted at all - # - # reformat temporal_attentions: - # HSXL: temporal_attentions.#. - # AD: motion_modules.#.temporal_transformer. - # HSXL: pos_encoder.positional_encoding - # AD: pos_encoder.pe - for key in list(mm_state_dict.keys()): - module_num = find_hotshot_module_num(key) - if module_num is not None: - new_key = key.replace(f"temporal_attentions.{module_num}", - f"motion_modules.{module_num}.temporal_transformer", 1) - new_key = new_key.replace("pos_encoder.positional_encoding", "pos_encoder.pe") - mm_state_dict[new_key] = mm_state_dict[key] - del mm_state_dict[key] + convert_hotshot_state_dict(mm_state_dict) # return adjusted mm_state_dict and info return mm_state_dict, info +def convert_hotshot_state_dict(mm_state_dict: dict[str, Tensor]): + # HotshotXL is AD-based architecture applied to SDXL instead of SD1.5 + # By renaming the keys, no code needs to be adapted at all + ################################ + # reformat temporal_attentions: + # HSXL: temporal_attentions.#. + # AD: motion_modules.#.temporal_transformer. + # HSXL: pos_encoder.positional_encoding + # AD: pos_encoder.pe + for key in list(mm_state_dict.keys()): + module_num = find_hotshot_module_num(key) + if module_num is not None: + new_key = key.replace(f"temporal_attentions.{module_num}", + f"motion_modules.{module_num}.temporal_transformer", 1) + new_key = new_key.replace("pos_encoder.positional_encoding", "pos_encoder.pe") + mm_state_dict[new_key] = mm_state_dict[key] + del mm_state_dict[key] + + +def convert_hellomeme_state_dict(mm_state_dict: dict[str, Tensor]): + # HelloMeme is AD-based architecture + for key in list(mm_state_dict.keys()): + module_num = find_hellomeme_module_num(key) + if module_num is not None: + # first, add temporal_transformer everywhere as suffix after motion_modules.#. 
+ new_key = key.replace(f"motion_modules.{module_num}", + f"motion_modules.{module_num}.temporal_transformer") + if "pos_embed" in new_key: + new_key1 = new_key.replace("pos_embed.pe", "attention_blocks.0.pos_encoder.pe") + new_key2 = new_key.replace("pos_embed.pe", "attention_blocks.1.pos_encoder.pe") + mm_state_dict[new_key1] = mm_state_dict[key].clone() + mm_state_dict[new_key2] = mm_state_dict[key].clone() + else: + if "attn1" in new_key: + new_key = new_key.replace("attn1.", "attention_blocks.0.") + elif "attn2" in new_key: + new_key = new_key.replace("attn2.", "attention_blocks.1.") + elif "norm1" in new_key: + new_key = new_key.replace("norm1.", "norms.0.") + elif "norm2" in new_key: + new_key = new_key.replace("norm2.", "norms.1.") + elif "norm3" in new_key: + new_key = new_key.replace("norm3.", "ff_norm.") + mm_state_dict[new_key] = mm_state_dict[key] + del mm_state_dict[key] + + class BlockType: UP = "up" DOWN = "down" From 7ded0a7c7421e85002e685ead01b2a9f11aec426 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 12 Nov 2024 16:33:49 -0600 Subject: [PATCH 26/43] version bump --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6ac5bdf..9454971 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "comfyui-animatediff-evolved" description = "Improved AnimateDiff integration for ComfyUI." -version = "1.2.3" +version = "1.3.0" license = { file = "LICENSE" } dependencies = [] From 48db232eb659ed34c470e502d7e30e4bd979c18d Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 12 Nov 2024 18:14:47 -0600 Subject: [PATCH 27/43] Added DinkLink scaffolding --- __init__.py | 3 +++ animatediff/dinklink.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 animatediff/dinklink.py diff --git a/__init__.py b/__init__.py index b939327..ea6806c 100644 --- a/__init__.py +++ b/__init__.py @@ -3,6 +3,7 @@ from .animatediff.utils_model import get_available_motion_models, Folders from .animatediff.nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS from .animatediff import documentation +from .animatediff.dinklink import init_dinklink if len(get_available_motion_models()) == 0: logger.error(f"No motion models found. Please download one and place in: {folder_paths.get_folder_paths(Folders.ANIMATEDIFF_MODELS)}") @@ -10,3 +11,5 @@ WEB_DIRECTORY = "./web" __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"] documentation.format_descriptions(NODE_CLASS_MAPPINGS) + +init_dinklink() diff --git a/animatediff/dinklink.py b/animatediff/dinklink.py new file mode 100644 index 0000000..5120201 --- /dev/null +++ b/animatediff/dinklink.py @@ -0,0 +1,29 @@ +#################################################################################################### +# DinkLink is my method of sharing classes/functions between my nodes. +# +# My DinkLink-compatible nodes will inject comfy.hooks with a __DINKLINK attr +# that stores a dictionary, where any of my node packs can store their stuff. +# +# It is not intended to be accessed by node packs that I don't develop, so things may change +# at any time. +# +# DinkLink also serves as a proof-of-concept for a future ComfyUI implementation of +# purposely exposing node pack classes/functions with other node packs. 
+#################################################################################################### +from __future__ import annotations +import comfy.hooks + +DINKLINK = "__DINKLINK" + +def init_dinklink(): + if not hasattr(comfy.hooks, DINKLINK): + setattr(comfy.hooks, DINKLINK, {}) + prepare_dinklink() + + +def get_dinklink() -> dict[str, dict[str]]: + getattr(comfy.hooks, DINKLINK) + + +def prepare_dinklink(): + pass From 815eede84511dcb333bab94bd5acbf506717aae5 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 14 Nov 2024 03:05:37 -0600 Subject: [PATCH 28/43] Make context_extras an optional param on Set Context Extras node --- animatediff/nodes_context_extras.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/animatediff/nodes_context_extras.py b/animatediff/nodes_context_extras.py index 00562de..2ab3ae8 100644 --- a/animatediff/nodes_context_extras.py +++ b/animatediff/nodes_context_extras.py @@ -17,6 +17,8 @@ def INPUT_TYPES(s): return { "required": { "context_opts": ("CONTEXT_OPTIONS",), + }, + "optional": { "context_extras": ("CONTEXT_EXTRAS",), }, "hidden": { @@ -29,9 +31,10 @@ def INPUT_TYPES(s): CATEGORY = "Animate Diff πŸŽ­πŸ…πŸ…“/context opts/context extras" FUNCTION = "set_context_extras" - def set_context_extras(self, context_opts: ContextOptionsGroup, context_extras: ContextExtrasGroup): + def set_context_extras(self, context_opts: ContextOptionsGroup, context_extras: ContextExtrasGroup=None): context_opts = context_opts.clone() - context_opts.extras = context_extras.clone() + if context_extras is not None: + context_opts.extras = context_extras.clone() return (context_opts,) From 48079783a28e9d283591d66b3c2140e3b99a5c40 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 14 Nov 2024 12:10:26 -0600 Subject: [PATCH 29/43] Use DinkLink to get ACN wrapper for ContextRef and ControlNet conversion purposes --- animatediff/dinklink.py | 20 +++++++++++++++++++- animatediff/model_injection.py | 15 ++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/animatediff/dinklink.py b/animatediff/dinklink.py index 5120201..d27a1c2 100644 --- a/animatediff/dinklink.py +++ b/animatediff/dinklink.py @@ -22,8 +22,26 @@ def init_dinklink(): def get_dinklink() -> dict[str, dict[str]]: - getattr(comfy.hooks, DINKLINK) + return getattr(comfy.hooks, DINKLINK) + + +class DinkLinkConst: + VERSION = "version" + ACN = "ACN" + ACN_CREATE_OUTER_SAMPLE_WRAPPER = "create_outer_sample_wrapper" def prepare_dinklink(): pass + + +def get_acn_outer_sample_wrapper(throw_exception=True): + d = get_dinklink() + try: + link_acn = d[DinkLinkConst.ACN] + return link_acn[DinkLinkConst.ACN_CREATE_OUTER_SAMPLE_WRAPPER] + except KeyError: + if throw_exception: + raise Exception("Advanced-ControlNet nodes need to be installed to make use of ContextRef; " + \ + "they are either not installed or are of an insufficient version.") + return None diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index 8eceabe..1a6c555 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -31,6 +31,7 @@ from .motion_lora import MotionLoraInfo, MotionLoraList from .utils_model import get_motion_lora_path, get_motion_model_path, get_sd_model_type, vae_encode_raw_batched from .sample_settings import SampleSettings, SeedNoiseGeneration +from .dinklink import get_acn_outer_sample_wrapper class ModelPatcherHelper: @@ -122,7 +123,19 @@ def get_params(self) -> 'InjectionParams': def set_params(self, params: 'InjectionParams'): 
self.model.set_attachments(self.PARAMS, params) - + if params.context_options.context_length is not None: + self.set_ACN_outer_sample_wrapper(throw_exception=False) + elif params.context_options.extras.context_ref is not None: + self.set_ACN_outer_sample_wrapper(throw_exception=True) + + def set_ACN_outer_sample_wrapper(self, throw_exception=True): + # get wrapper to register from Advanced-ControlNet via DinkLink shared dict + wrapper_info = get_acn_outer_sample_wrapper(throw_exception) + if wrapper_info is None: + return + wrapper_type, key, wrapper = wrapper_info + if len(self.model.get_wrappers(wrapper_type, key)) == 0: + self.model.add_wrapper_with_key(wrapper_type, key, wrapper) def set_outer_sample_wrapper(self, wrapper: Callable): self.model.remove_wrappers_with_key(WrappersMP.OUTER_SAMPLE, self.ADE) From f64d67466f984a11c5213353d0d54e7f0eca485e Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 14 Nov 2024 21:59:58 -0600 Subject: [PATCH 30/43] Add proper parent reference to MotionModelPatcher.clone() to work properly with new memory management --- animatediff/model_injection.py | 1 + 1 file changed, 1 insertion(+) diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index 1a6c555..a8d541e 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -569,6 +569,7 @@ def clone(self): self.backup = n.backup if hasattr(n, "object_patches_backup"): self.object_patches_backup = n.object_patches_backup + n.parent = self # extra cloned params n.timestep_percent_range = self.timestep_percent_range n.timestep_range = self.timestep_range From d70fd110b8444fe2bca2f6b8008bd9d72d8b8b88 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 16 Nov 2024 16:40:49 -0600 Subject: [PATCH 31/43] Modified comfy.hooks calls to match ComfyUI changes --- animatediff/nodes_conditioning.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/animatediff/nodes_conditioning.py b/animatediff/nodes_conditioning.py index c70cd86..232d614 100644 --- a/animatediff/nodes_conditioning.py +++ b/animatediff/nodes_conditioning.py @@ -111,7 +111,7 @@ def INPUT_TYPES(s): def append_and_hook(self, positive_ADD, negative_ADD, strength: float, set_cond_area: str, opt_mask: Tensor=None, opt_lora_hook: HookGroup=None, opt_timesteps: tuple=None): - final_positive, final_negative = comfy.hooks.set_mask_conds(conds=[positive_ADD, negative_ADD], + final_positive, final_negative = comfy.hooks.set_conds_props(conds=[positive_ADD, negative_ADD], strength=strength, set_cond_area=set_cond_area, opt_mask=opt_mask, opt_hooks=opt_lora_hook, opt_timestep_range=opt_timesteps) return (final_positive, final_negative) @@ -145,7 +145,7 @@ def INPUT_TYPES(s): def append_and_hook(self, cond_ADD, strength: float, set_cond_area: str, opt_mask: Tensor=None, opt_lora_hook: HookGroup=None, opt_timesteps: tuple=None): - (final_conditioning,) = comfy.hooks.set_mask_conds(conds=[cond_ADD], + (final_conditioning,) = comfy.hooks.set_conds_props(conds=[cond_ADD], strength=strength, set_cond_area=set_cond_area, opt_mask=opt_mask, opt_hooks=opt_lora_hook, opt_timestep_range=opt_timesteps) return (final_conditioning,) @@ -183,7 +183,7 @@ def INPUT_TYPES(s): def append_and_combine(self, positive, negative, positive_ADD, negative_ADD, strength: float, set_cond_area: str, opt_mask: Tensor=None, opt_lora_hook: HookGroup=None, opt_timesteps: tuple=None): - final_positive, final_negative = comfy.hooks.set_mask_and_combine_conds(conds=[positive, negative], 
new_conds=[positive_ADD, negative_ADD], + final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive, negative], new_conds=[positive_ADD, negative_ADD], strength=strength, set_cond_area=set_cond_area, opt_mask=opt_mask, opt_hooks=opt_lora_hook, opt_timestep_range=opt_timesteps) return (final_positive, final_negative,) @@ -218,7 +218,7 @@ def INPUT_TYPES(s): def append_and_combine(self, cond, cond_ADD, strength: float, set_cond_area: str, opt_mask: Tensor=None, opt_lora_hook: HookGroup=None, opt_timesteps: tuple=None): - (final_conditioning,) = comfy.hooks.set_mask_and_combine_conds(conds=[cond], new_conds=[cond_ADD], + (final_conditioning,) = comfy.hooks.set_conds_props_and_combine(conds=[cond], new_conds=[cond_ADD], strength=strength, set_cond_area=set_cond_area, opt_mask=opt_mask, opt_hooks=opt_lora_hook, opt_timestep_range=opt_timesteps) return (final_conditioning,) @@ -251,7 +251,7 @@ def INPUT_TYPES(s): def append_and_combine(self, positive, negative, positive_DEFAULT, negative_DEFAULT, opt_lora_hook: HookGroup=None): - final_positive, final_negative = comfy.hooks.set_default_and_combine_conds(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT], + final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT], opt_hooks=opt_lora_hook) return (final_positive, final_negative,) @@ -280,7 +280,7 @@ def INPUT_TYPES(s): def append_and_combine(self, cond, cond_DEFAULT, opt_lora_hook: HookGroup=None): - (final_conditioning,) = comfy.hooks.set_default_and_combine_conds(conds=[cond], new_conds=[cond_DEFAULT], + (final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT], opt_hooks=opt_lora_hook) return (final_conditioning,) @@ -307,7 +307,7 @@ def INPUT_TYPES(s): DEPRECATED = True def combine(self, positive_A, negative_A, positive_B, negative_B): - final_positive, final_negative = comfy.hooks.set_mask_and_combine_conds(conds=[positive_A, negative_A], new_conds=[positive_B, negative_B],) + final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive_A, negative_A], new_conds=[positive_B, negative_B],) return (final_positive, final_negative,) @@ -330,7 +330,7 @@ def INPUT_TYPES(s): DEPRECATED = True def combine(self, cond_A, cond_B): - (final_conditioning,) = comfy.hooks.set_mask_and_combine_conds(conds=[cond_A], new_conds=[cond_B],) + (final_conditioning,) = comfy.hooks.set_conds_props_and_combine(conds=[cond_A], new_conds=[cond_B],) return (final_conditioning,) ############################################### ############################################### From 68ccced6f77a2500fbf0fc2da2bfdbd8bb1a6649 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sat, 16 Nov 2024 17:38:43 -0600 Subject: [PATCH 32/43] Added support for hooks from CLIP to be applied to resulting conds to respect ComfyUI expected behavior --- animatediff/scheduling.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/animatediff/scheduling.py b/animatediff/scheduling.py index be5e0ab..4f06484 100644 --- a/animatediff/scheduling.py +++ b/animatediff/scheduling.py @@ -429,6 +429,8 @@ def _handle_prompt_interpolation(pairs: list[InputPair], length: int, clip: CLIP final_pooled_dict = {"pooled_output": final_pooled} if options.add_dict is not None: final_pooled_dict.update(options.add_dict) + # add hooks, if needed + clip.add_hooks_to_dict(final_pooled_dict) return [[final_cond, final_pooled_dict]] From 
844b4e4488eaafbd253fc780df88a9761659d04b Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 17 Nov 2024 08:35:45 -0600 Subject: [PATCH 33/43] Make diffusion_model_groupnormed hack use built-in diffusion_model wrapper instead to fix weird issue with ContextRef + ImageInjection --- animatediff/sampling.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 2730e16..7b99e2a 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -8,13 +8,14 @@ import comfy.model_management import comfy.model_patcher +import comfy.patcher_extension import comfy.samplers import comfy.sampler_helpers import comfy.utils from comfy.controlnet import ControlBase from comfy.model_base import BaseModel from comfy.model_patcher import ModelPatcher -from comfy.patcher_extension import WrapperExecutor +from comfy.patcher_extension import WrapperExecutor, WrappersMP import comfy.conds import comfy.ops @@ -178,11 +179,18 @@ def apply_model_ade_wrapper(self, *args, **kwargs): return orig_apply_model(*args, **kwargs) return apply_model_ade_wrapper -def diffusion_model_forward_groupnormed_factory(orig_diffusion_model_forward: Callable, inject_helper: 'GroupnormInjectHelper'): - def diffusion_model_forward_groupnormed(*args, **kwargs): +def create_diffusion_model_groupnormed_wrapper(model_options: dict, inject_helper: 'GroupnormInjectHelper'): + comfy.patcher_extension.add_wrapper_with_key(WrappersMP.DIFFUSION_MODEL, + "ADE_groupnormed_diffusion_model", + _diffusion_model_groupnormed_wrapper_factory(inject_helper), + model_options, is_model_options=True) + + +def _diffusion_model_groupnormed_wrapper_factory(inject_helper: 'GroupnormInjectHelper'): + def _diffusion_model_groupnormed_wrapper(executor, *args, **kwargs): with inject_helper: - return orig_diffusion_model_forward(*args, **kwargs) - return diffusion_model_forward_groupnormed + return executor(*args, **kwargs) + return _diffusion_model_groupnormed_wrapper ###################################################################### ################################################################################## @@ -237,7 +245,6 @@ def inject_functions(self, helper: ModelPatcherHelper, params: InjectionParams, self.orig_memory_required = None self.orig_groupnorm_forward = torch.nn.GroupNorm.forward # used to normalize latents to remove "flickering" of colors/brightness between frames self.orig_groupnorm_forward_comfy_cast_weights = comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights - self.orig_diffusion_model_forward = None self.orig_sampling_function = comfy.samplers.sampling_function # used to support sliding context windows in samplers self.orig_apply_model = None # Inject Functions @@ -254,8 +261,7 @@ def inject_functions(self, helper: ModelPatcherHelper, params: InjectionParams, self.inject_groupnorm_forward = groupnorm_mm_factory(params) self.inject_groupnorm_forward_comfy_cast_weights = groupnorm_mm_factory(params, manual_cast=True) self.groupnorm_injector = GroupnormInjectHelper(self) - self.orig_diffusion_model_forward = helper.model.model.diffusion_model.forward - helper.model.model.diffusion_model.forward = diffusion_model_forward_groupnormed_factory(self.orig_diffusion_model_forward, self.groupnorm_injector) + create_diffusion_model_groupnormed_wrapper(model_options, self.groupnorm_injector) # if mps device (Apple Silicon), disable batched conds to avoid black images with groupnorm hack try: if helper.model.load_device.type == 
"mps": @@ -281,8 +287,6 @@ def restore_functions(self, helper: ModelPatcherHelper): helper.model.model.memory_required = self.orig_memory_required torch.nn.GroupNorm.forward = self.orig_groupnorm_forward comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.orig_groupnorm_forward_comfy_cast_weights - if self.orig_diffusion_model_forward is not None: - helper.model.model.diffusion_model.forward = self.orig_diffusion_model_forward comfy.samplers.sampling_function = self.orig_sampling_function if self.orig_apply_model is not None: helper.model.model.apply_model = self.orig_apply_model From 054418df0e35e85d962277dd7e04874f24790de3 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 17 Nov 2024 08:59:41 -0600 Subject: [PATCH 34/43] Replace apply_model_factory with special_model_apply_model_wrapper that uses built-in ComfyUI wrapper for apply_model --- animatediff/sampling.py | 41 +++++++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 7b99e2a..5f307db 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -165,19 +165,32 @@ def groupnorm_mm_forward(self, input: Tensor) -> Tensor: return input return groupnorm_mm_forward -def apply_model_factory(orig_apply_model: Callable): - def apply_model_ade_wrapper(self, *args, **kwargs): - x: Tensor = args[0] - cond_or_uncond = kwargs["transformer_options"]["cond_or_uncond"] - ad_params = kwargs["transformer_options"]["ad_params"] - ADGS: AnimateDiffGlobalState = kwargs["transformer_options"]["ADGS"] - if ADGS.motion_models is not None: +def create_special_model_apply_model_wrapper(model_options: dict): + comfy.patcher_extension.add_wrapper_with_key(WrappersMP.APPLY_MODEL, + "ADE_special_model_apply_model", + _apply_model_wrapper, + model_options, is_model_options=True) + +def _apply_model_wrapper(executor, *args, **kwargs): + # args (from BaseModel._apply_model): + # 0: x + # 1: t + # 2: c_concat + # 3: c_crossattn + # 4: control + # 5: transformer_options + x: Tensor = args[0] + transformer_options = args[5] + cond_or_uncond = transformer_options["cond_or_uncond"] + ad_params = transformer_options["ad_params"] + ADGS: AnimateDiffGlobalState = transformer_options["ADGS"] + if ADGS.motion_models is not None: for motion_model in ADGS.motion_models.models: - motion_model.prepare_alcmi2v_features(x=x, cond_or_uncond=cond_or_uncond, ad_params=ad_params, latent_format=self.latent_format) + motion_model.prepare_alcmi2v_features(x=x, cond_or_uncond=cond_or_uncond, ad_params=ad_params, latent_format=executor.class_obj.latent_format) motion_model.prepare_camera_features(x=x, cond_or_uncond=cond_or_uncond, ad_params=ad_params) - del x - return orig_apply_model(*args, **kwargs) - return apply_model_ade_wrapper + del x + return executor(*args, **kwargs) + def create_diffusion_model_groupnormed_wrapper(model_options: dict, inject_helper: 'GroupnormInjectHelper'): comfy.patcher_extension.add_wrapper_with_key(WrappersMP.DIFFUSION_MODEL, @@ -246,7 +259,6 @@ def inject_functions(self, helper: ModelPatcherHelper, params: InjectionParams, self.orig_groupnorm_forward = torch.nn.GroupNorm.forward # used to normalize latents to remove "flickering" of colors/brightness between frames self.orig_groupnorm_forward_comfy_cast_weights = comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights self.orig_sampling_function = comfy.samplers.sampling_function # used to support sliding context windows in samplers - 
self.orig_apply_model = None # Inject Functions if params.unlimited_area_hack: # allows for "unlimited area hack" to prevent halving of conds/unconds @@ -272,8 +284,7 @@ def inject_functions(self, helper: ModelPatcherHelper, params: InjectionParams, # if img_encoder or camera_encoder present, inject apply_model to handle correctly for motion_model in helper.get_motion_models(): if (motion_model.model.img_encoder is not None) or (motion_model.model.camera_encoder is not None): - self.orig_apply_model = helper.model.model.apply_model - helper.model.model.apply_model = apply_model_factory(self.orig_apply_model).__get__(helper.model.model, type(helper.model.model)) + create_special_model_apply_model_wrapper(model_options) break del info comfy.samplers.sampling_function = evolved_sampling_function @@ -288,8 +299,6 @@ def restore_functions(self, helper: ModelPatcherHelper): torch.nn.GroupNorm.forward = self.orig_groupnorm_forward comfy.ops.disable_weight_init.GroupNorm.forward_comfy_cast_weights = self.orig_groupnorm_forward_comfy_cast_weights comfy.samplers.sampling_function = self.orig_sampling_function - if self.orig_apply_model is not None: - helper.model.model.apply_model = self.orig_apply_model except AttributeError: logger.error("Encountered AttributeError while attempting to restore functions - likely, an error occured while trying " + \ "to save original functions before injection, and a more specific error was thrown by ComfyUI.") From 066bb23b432686ab5614f618e71fcf378616d532 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 17 Nov 2024 09:01:40 -0600 Subject: [PATCH 35/43] Clean up perform_image_injection, as backwards compatible code no longer needed --- animatediff/sampling.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 5f307db..1e84e78 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -539,10 +539,7 @@ def evolved_sampling_function(model, x: Tensor, timestep: Tensor, uncond, cond, def perform_image_injection(ADGS: AnimateDiffGlobalState, model: BaseModel, latents: Tensor, to_inject: NoisedImageToInject) -> Tensor: # NOTE: the latents here have already been process_latent_out'ed # get currently used models so they can be properly reloaded after perfoming VAE Encoding - if hasattr(comfy.model_management, "loaded_models"): - cached_loaded_models = comfy.model_management.loaded_models(only_currently_used=True) - else: - cached_loaded_models: list[ModelPatcher] = [x.model for x in comfy.model_management.current_loaded_models] + cached_loaded_models = comfy.model_management.loaded_models(only_currently_used=True) try: orig_device = latents.device orig_dtype = latents.dtype From d06a40877a631ed0a5b0a01c369e248a170e7c41 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 17 Nov 2024 09:08:16 -0600 Subject: [PATCH 36/43] Adjusted spacing --- animatediff/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index 1e84e78..e5995a4 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -165,6 +165,7 @@ def groupnorm_mm_forward(self, input: Tensor) -> Tensor: return input return groupnorm_mm_forward + def create_special_model_apply_model_wrapper(model_options: dict): comfy.patcher_extension.add_wrapper_with_key(WrappersMP.APPLY_MODEL, "ADE_special_model_apply_model", @@ -198,7 +199,6 @@ def create_diffusion_model_groupnormed_wrapper(model_options: dict, inject_helpe 
_diffusion_model_groupnormed_wrapper_factory(inject_helper), model_options, is_model_options=True) - def _diffusion_model_groupnormed_wrapper_factory(inject_helper: 'GroupnormInjectHelper'): def _diffusion_model_groupnormed_wrapper(executor, *args, **kwargs): with inject_helper: From 32a06f1e15ae29588d0204a27ee8bfb13161426a Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 21 Nov 2024 15:50:47 -0600 Subject: [PATCH 37/43] Match changes in ComfyUI PR that remove opt_ prefix on optional comfy.hook function params --- animatediff/nodes_conditioning.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/animatediff/nodes_conditioning.py b/animatediff/nodes_conditioning.py index 232d614..d05b612 100644 --- a/animatediff/nodes_conditioning.py +++ b/animatediff/nodes_conditioning.py @@ -113,7 +113,7 @@ def append_and_hook(self, positive_ADD, negative_ADD, opt_mask: Tensor=None, opt_lora_hook: HookGroup=None, opt_timesteps: tuple=None): final_positive, final_negative = comfy.hooks.set_conds_props(conds=[positive_ADD, negative_ADD], strength=strength, set_cond_area=set_cond_area, - opt_mask=opt_mask, opt_hooks=opt_lora_hook, opt_timestep_range=opt_timesteps) + mask=opt_mask, hooks=opt_lora_hook, timesteps_range=opt_timesteps) return (final_positive, final_negative) @@ -147,7 +147,7 @@ def append_and_hook(self, cond_ADD, opt_mask: Tensor=None, opt_lora_hook: HookGroup=None, opt_timesteps: tuple=None): (final_conditioning,) = comfy.hooks.set_conds_props(conds=[cond_ADD], strength=strength, set_cond_area=set_cond_area, - opt_mask=opt_mask, opt_hooks=opt_lora_hook, opt_timestep_range=opt_timesteps) + mask=opt_mask, hooks=opt_lora_hook, timesteps_range=opt_timesteps) return (final_conditioning,) @@ -185,7 +185,7 @@ def append_and_combine(self, positive, negative, positive_ADD, negative_ADD, opt_mask: Tensor=None, opt_lora_hook: HookGroup=None, opt_timesteps: tuple=None): final_positive, final_negative = comfy.hooks.set_conds_props_and_combine(conds=[positive, negative], new_conds=[positive_ADD, negative_ADD], strength=strength, set_cond_area=set_cond_area, - opt_mask=opt_mask, opt_hooks=opt_lora_hook, opt_timestep_range=opt_timesteps) + mask=opt_mask, hooks=opt_lora_hook, timesteps_range=opt_timesteps) return (final_positive, final_negative,) @@ -220,7 +220,7 @@ def append_and_combine(self, cond, cond_ADD, opt_mask: Tensor=None, opt_lora_hook: HookGroup=None, opt_timesteps: tuple=None): (final_conditioning,) = comfy.hooks.set_conds_props_and_combine(conds=[cond], new_conds=[cond_ADD], strength=strength, set_cond_area=set_cond_area, - opt_mask=opt_mask, opt_hooks=opt_lora_hook, opt_timestep_range=opt_timesteps) + mask=opt_mask, hooks=opt_lora_hook, timesteps_range=opt_timesteps) return (final_conditioning,) @@ -252,7 +252,7 @@ def INPUT_TYPES(s): def append_and_combine(self, positive, negative, positive_DEFAULT, negative_DEFAULT, opt_lora_hook: HookGroup=None): final_positive, final_negative = comfy.hooks.set_default_conds_and_combine(conds=[positive, negative], new_conds=[positive_DEFAULT, negative_DEFAULT], - opt_hooks=opt_lora_hook) + hooks=opt_lora_hook) return (final_positive, final_negative,) @@ -281,7 +281,7 @@ def INPUT_TYPES(s): def append_and_combine(self, cond, cond_DEFAULT, opt_lora_hook: HookGroup=None): (final_conditioning,) = comfy.hooks.set_default_conds_and_combine(conds=[cond], new_conds=[cond_DEFAULT], - opt_hooks=opt_lora_hook) + hooks=opt_lora_hook) return (final_conditioning,) From 0a2d1291c9081d7f42d765d6bba2c25e8cddfeb7 Mon Sep 17 00:00:00 2001 
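The two sampling.py patches above trade hand-patched forward functions for ComfyUI's wrapper registry. A minimal sketch of that mechanism, using the same add_wrapper_with_key call seen in those hunks; the wrapper key and its body here are hypothetical, not part of any patch:

import comfy.patcher_extension
from comfy.patcher_extension import WrappersMP

def _timing_wrapper(executor, *args, **kwargs):
    # 'executor' stands in for the wrapped diffusion_model call (plus any inner wrappers);
    # delegating to it preserves default behavior while letting code run around it
    print("before diffusion_model forward")
    try:
        return executor(*args, **kwargs)
    finally:
        print("after diffusion_model forward")

def register_timing_wrapper(model_options: dict):
    comfy.patcher_extension.add_wrapper_with_key(
        WrappersMP.DIFFUSION_MODEL,
        "EXAMPLE_timing_wrapper",
        _timing_wrapper,
        model_options, is_model_options=True)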
From: Jedrzej Kosinski Date: Wed, 27 Nov 2024 23:31:55 -0600 Subject: [PATCH 38/43] Refactored MotionModelPatcher into MotionModelAttachment so that no custom ModelPatcher model is required; no more mismatches between ComfyUI and ADE ModelPatcher features. Also made Scale Ref Image and VAE Encode node use batched vae encoding --- animatediff/model_injection.py | 216 +++++++++++++++-------------- animatediff/nodes_animatelcmi2v.py | 18 +-- animatediff/nodes_cameractrl.py | 7 +- animatediff/nodes_fancyvideo.py | 7 +- animatediff/nodes_gen2.py | 13 +- animatediff/nodes_pia.py | 9 +- animatediff/utils_model.py | 1 - 7 files changed, 137 insertions(+), 134 deletions(-) diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index a8d541e..4a6cad2 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -14,7 +14,7 @@ import comfy.model_management import comfy.utils from comfy.model_patcher import ModelPatcher -from comfy.patcher_extension import WrappersMP, PatcherInjection +from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection from comfy.model_base import BaseModel from comfy.sd import CLIP, VAE @@ -34,6 +34,12 @@ from .dinklink import get_acn_outer_sample_wrapper +class MotionModelPatcher(ModelPatcher): + '''Class used only for type hints.''' + def __init__(self): + self.model: AnimateDiffModel + + class ModelPatcherHelper: SAMPLE_SETTINGS = "ADE_sample_settings" PARAMS = "ADE_params" @@ -55,13 +61,10 @@ def set_all_properties(self, outer_sampler_wrapper: Callable, calc_cond_batch_wr self.remove_motion_models() self.remove_forward_timestep_embed_patch() - def get_adgs(self): - pass - - def get_motion_models(self) -> list['MotionModelPatcher']: + def get_motion_models(self) -> list[MotionModelPatcher]: return self.model.additional_models.get(self.ADE, []) - def set_motion_models(self, motion_models: list['MotionModelPatcher']): + def set_motion_models(self, motion_models: list[MotionModelPatcher]): self.model.set_additional_models(self.ADE, motion_models) self.model.set_injections(self.ADE, [PatcherInjection(inject=inject_motion_models, eject=eject_motion_models)]) @@ -152,7 +155,7 @@ def remove_wrappers(self): def pre_run(self): # TODO: could implement this as a ModelPatcher ON_PRE_RUN callback for motion_model in self.get_motion_models(): - motion_model.pre_run(self.model) + motion_model.pre_run() self.get_sample_settings().pre_run(self.model) @@ -177,11 +180,61 @@ def forward_timestep_embed_patch_ade(layer, x, emb, context, transformer_options return layer(x, context) -class MotionModelPatcher(ModelPatcher): - # Mostly here so that type hints work in IDEs - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.model: AnimateDiffModel = self.model +def create_MotionModelPatcher(model, load_device, offload_device) -> MotionModelPatcher: + patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device) + ade = ModelPatcherHelper.ADE + patcher.add_callback_with_key(CallbacksMP.ON_LOAD, ade, _mm_patch_lowvram_extras_callback) + patcher.add_callback_with_key(CallbacksMP.ON_LOAD, ade, _mm_handle_float8_pe_tensors_callback) + patcher.add_callback_with_key(CallbacksMP.ON_PRE_RUN, ade, _mm_pre_run_callback) + patcher.add_callback_with_key(CallbacksMP.ON_CLEANUP, ade, _mm_clean_callback) + patcher.set_attachments(ade, MotionModelAttachment()) + return patcher + + +def _mm_patch_lowvram_extras_callback(self: MotionModelPatcher, device_to, lowvram_model_memory, *args, **kwargs): + if 
lowvram_model_memory > 0: + # figure out the tensors (likely pe's) that should be cast to device besides just the named_modules + remaining_tensors = list(self.model.state_dict().keys()) + named_modules = [] + for n, _ in self.model.named_modules(): + named_modules.append(n) + named_modules.append(f"{n}.weight") + named_modules.append(f"{n}.bias") + for name in named_modules: + if name in remaining_tensors: + remaining_tensors.remove(name) + + for key in remaining_tensors: + self.patch_weight_to_device(key, device_to) + if device_to is not None: + comfy.utils.set_attr(self.model, key, comfy.utils.get_attr(self.model, key).to(device_to)) + +def _mm_handle_float8_pe_tensors_callback(self: MotionModelPatcher, *args, **kwargs): + remaining_tensors = list(self.model.state_dict().keys()) + pe_tensors = [x for x in remaining_tensors if '.pe' in x] + is_first = True + for key in pe_tensors: + if is_first: + is_first = False + if comfy.utils.get_attr(self.model, key).dtype not in [torch.float8_e5m2, torch.float8_e4m3fn]: + break + comfy.utils.set_attr(self.model, key, comfy.utils.get_attr(self.model, key).half()) + +def _mm_pre_run_callback(self: MotionModelPatcher, *args, **kwargs): + attachment = get_mm_attachment(self) + attachment.pre_run(self) + +def _mm_clean_callback(self: MotionModelPatcher, *args, **kwargs): + attachment = get_mm_attachment(self) + attachment.cleanup(self) + + +def get_mm_attachment(patcher: MotionModelPatcher) -> 'MotionModelAttachment': + return patcher.get_attachment(ModelPatcherHelper.ADE) + + +class MotionModelAttachment: + def __init__(self): self.timestep_percent_range = (0.0, 1.0) self.timestep_range: tuple[float, float] = None self.keyframes: ADKeyframeGroup = ADKeyframeGroup() @@ -239,49 +292,14 @@ def __init__(self, *args, **kwargs): self.prev_sub_idxs = None self.prev_batched_number = None - def load(self, device_to=None, lowvram_model_memory=0, *args, **kwargs): - to_return = super().load(device_to=device_to, lowvram_model_memory=lowvram_model_memory, *args, **kwargs) - if lowvram_model_memory > 0: - self._patch_lowvram_extras(device_to=device_to) - self._handle_float8_pe_tensors() - return to_return - - def _patch_lowvram_extras(self, device_to=None): - # figure out the tensors (likely pe's) that should be cast to device besides just the named_modules - remaining_tensors = list(self.model.state_dict().keys()) - named_modules = [] - for n, _ in self.model.named_modules(): - named_modules.append(n) - named_modules.append(f"{n}.weight") - named_modules.append(f"{n}.bias") - for name in named_modules: - if name in remaining_tensors: - remaining_tensors.remove(name) - - for key in remaining_tensors: - self.patch_weight_to_device(key, device_to) - if device_to is not None: - comfy.utils.set_attr(self.model, key, comfy.utils.get_attr(self.model, key).to(device_to)) - - def _handle_float8_pe_tensors(self): - remaining_tensors = list(self.model.state_dict().keys()) - pe_tensors = [x for x in remaining_tensors if '.pe' in x] - is_first = True - for key in pe_tensors: - if is_first: - is_first = False - if comfy.utils.get_attr(self.model, key).dtype not in [torch.float8_e5m2, torch.float8_e4m3fn]: - break - comfy.utils.set_attr(self.model, key, comfy.utils.get_attr(self.model, key).half()) - - def pre_run(self, model: ModelPatcher): - self.cleanup() - self.model.set_scale(self.scale_multival, self.per_block_list) - self.model.set_effect(self.effect_multival, self.per_block_list) - self.model.set_cameractrl_effect(self.cameractrl_multival) - if self.model.img_encoder is 
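The refactor in this patch moves per-motion-model state off a ModelPatcher subclass and into a keyed attachment plus callbacks. A stripped-down sketch of that pattern, assuming the ModelPatcher attachment/callback API already used in these hunks; the attachment class, key, and callback below are hypothetical:

from comfy.model_patcher import ModelPatcher
from comfy.patcher_extension import CallbacksMP

EXAMPLE_KEY = "EXAMPLE_attachment"

class ExampleAttachment:
    def __init__(self, note: str = "hello"):
        self.note = note

    def on_model_patcher_clone(self):
        # invoked when the owning ModelPatcher is cloned, so the state follows the clone
        return ExampleAttachment(self.note)

def _example_pre_run(patcher: ModelPatcher, *args, **kwargs):
    # keyed callbacks receive the patcher first, mirroring the _mm_*_callback functions above
    attachment: ExampleAttachment = patcher.get_attachment(EXAMPLE_KEY)
    print(f"pre-run, attachment note: {attachment.note}")

def attach_example(patcher: ModelPatcher) -> ModelPatcher:
    patcher.set_attachments(EXAMPLE_KEY, ExampleAttachment())
    patcher.add_callback_with_key(CallbacksMP.ON_PRE_RUN, EXAMPLE_KEY, _example_pre_run)
    return patcher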
not None: - self.model.img_encoder.set_ref_drift(self.orig_ref_drift) - self.model.img_encoder.set_insertion_weights(self.orig_insertion_weights) + def pre_run(self, patcher: MotionModelPatcher): + self.cleanup(patcher) + patcher.model.set_scale(self.scale_multival, self.per_block_list) + patcher.model.set_effect(self.effect_multival, self.per_block_list) + patcher.model.set_cameractrl_effect(self.cameractrl_multival) + if patcher.model.img_encoder is not None: + patcher.model.img_encoder.set_ref_drift(self.orig_ref_drift) + patcher.model.img_encoder.set_insertion_weights(self.orig_insertion_weights) def initialize_timesteps(self, model: BaseModel): self.timestep_range = (model.model_sampling.percent_to_sigma(self.timestep_percent_range[0]), @@ -290,7 +308,7 @@ def initialize_timesteps(self, model: BaseModel): for keyframe in self.keyframes.keyframes: keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent) - def prepare_current_keyframe(self, x: Tensor, t: Tensor): + def prepare_current_keyframe(self, patcher: MotionModelPatcher, x: Tensor, t: Tensor): curr_t: float = t[0] # if curr_t was previous_t, then do nothing (already accounted for this step) if curr_t == self.previous_t: @@ -340,26 +358,26 @@ def prepare_current_keyframe(self, x: Tensor, t: Tensor): self.combined_pia_mask = get_combined_input(self.pia_input, self.current_pia_input, x) self.combined_pia_effect = get_combined_input_effect_multival(self.pia_input, self.current_pia_input) # apply scale and effect - self.model.set_scale(self.combined_scale, self.per_block_list) - self.model.set_effect(self.combined_effect, self.per_block_list) # TODO: set combined_per_block_list - self.model.set_cameractrl_effect(self.combined_cameractrl_effect) + patcher.model.set_scale(self.combined_scale, self.per_block_list) + patcher.model.set_effect(self.combined_effect, self.per_block_list) # TODO: set combined_per_block_list + patcher.model.set_cameractrl_effect(self.combined_cameractrl_effect) # apply effect - if not within range, set effect to 0, effectively turning model off if curr_t > self.timestep_range[0] or curr_t < self.timestep_range[1]: - self.model.set_effect(0.0) + patcher.model.set_effect(0.0) self.was_within_range = False else: # if was not in range last step, apply effect to toggle AD status if not self.was_within_range: - self.model.set_effect(self.combined_effect, self.per_block_list) + patcher.model.set_effect(self.combined_effect, self.per_block_list) self.was_within_range = True # update steps current keyframe is used self.current_used_steps += 1 # update previous_t self.previous_t = curr_t - def prepare_alcmi2v_features(self, x: Tensor, cond_or_uncond: list[int], ad_params: dict[str], latent_format): + def prepare_alcmi2v_features(self, patcher: MotionModelPatcher, x: Tensor, cond_or_uncond: list[int], ad_params: dict[str], latent_format): # if no img_encoder, done - if self.model.img_encoder is None: + if patcher.model.img_encoder is None: return batched_number = len(cond_or_uncond) full_length = ad_params["full_length"] @@ -372,20 +390,20 @@ def prepare_alcmi2v_features(self, x: Tensor, cond_or_uncond: list[int], ad_para img_latents = comfy.utils.common_upscale(self.orig_img_latents[sub_idxs], x.shape[3], x.shape[2], 'nearest-exact', 'center').to(x.dtype).to(x.device) else: img_latents = comfy.utils.common_upscale(self.orig_img_latents, x.shape[3], x.shape[2], 'nearest-exact', 'center').to(x.dtype).to(x.device) - img_latents = latent_format.process_in(img_latents) + img_latents: Tensor = 
latent_format.process_in(img_latents) # make sure img_latents matches goal_length if goal_length != img_latents.shape[0]: img_latents = ade_broadcast_image_to(img_latents, goal_length, batched_number) - img_features = self.model.img_encoder(img_latents, goal_length, batched_number) - self.model.set_img_features(img_features=img_features, apply_ref_when_disabled=self.orig_apply_ref_when_disabled) + img_features = patcher.model.img_encoder(img_latents, goal_length, batched_number) + patcher.model.set_img_features(img_features=img_features, apply_ref_when_disabled=self.orig_apply_ref_when_disabled) # cache values for next step self.img_latents_shape = img_latents.shape self.prev_sub_idxs = sub_idxs self.prev_batched_number = batched_number - def prepare_camera_features(self, x: Tensor, cond_or_uncond: list[int], ad_params: dict[str]): + def prepare_camera_features(self, patcher: MotionModelPatcher, x: Tensor, cond_or_uncond: list[int], ad_params: dict[str]): # if no camera_encoder, done - if self.model.camera_encoder is None: + if patcher.model.camera_encoder is None: return batched_number = len(cond_or_uncond) full_length = ad_params["full_length"] @@ -410,8 +428,8 @@ def prepare_camera_features(self, x: Tensor, cond_or_uncond: list[int], ad_param # create encoded embeddings b, c, h, w = x.shape plucker_embedding = prepare_pose_embedding(camera_poses, image_width=w*8, image_height=h*8).to(dtype=x.dtype, device=x.device) - camera_embedding = self.model.camera_encoder(plucker_embedding, video_length=goal_length, batched_number=batched_number) - self.model.set_camera_features(camera_features=camera_embedding) + camera_embedding = patcher.model.camera_encoder(plucker_embedding, video_length=goal_length, batched_number=batched_number) + patcher.model.set_camera_features(camera_features=camera_embedding) self.camera_features_shape = len(camera_embedding) self.prev_sub_idxs = sub_idxs self.prev_batched_number = batched_number @@ -517,16 +535,15 @@ def get_fancy_c_concat(self, model: BaseModel, x: Tensor) -> Tensor: finally: comfy.model_management.load_models_gpu(cached_loaded_models) - def is_pia(self): - return self.model.mm_info.mm_format == AnimateDiffFormat.PIA and self.orig_pia_images is not None + def is_pia(self, patcher: MotionModelPatcher): + return patcher.model.mm_info.mm_format == AnimateDiffFormat.PIA and self.orig_pia_images is not None - def is_fancyvideo(self): - return self.model.mm_info.mm_format == AnimateDiffFormat.FANCYVIDEO + def is_fancyvideo(self, patcher: MotionModelPatcher): + return patcher.model.mm_info.mm_format == AnimateDiffFormat.FANCYVIDEO - def cleanup(self): - super().cleanup() - if self.model is not None: - self.model.cleanup() + def cleanup(self, patcher: MotionModelPatcher): + if patcher.model is not None: + patcher.model.cleanup() # AnimateLCM-I2V del self.img_features self.img_features = None @@ -552,24 +569,8 @@ def cleanup(self): self.prev_sub_idxs = None self.prev_batched_number = None - def clone(self): - # normal ModelPatcher clone actions - n = MotionModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update) - n.patches = {} - for k in self.patches: - n.patches[k] = self.patches[k][:] - if hasattr(n, "patches_uuid"): - self.patches_uuid = n.patches_uuid - - n.object_patches = self.object_patches.copy() - n.model_options = copy.deepcopy(self.model_options) - if hasattr(n, "model_keys"): - n.model_keys = self.model_keys - if hasattr(n, "backup"): - self.backup = n.backup - if hasattr(n, 
"object_patches_backup"): - self.object_patches_backup = n.object_patches_backup - n.parent = self + def on_model_patcher_clone(self): + n = MotionModelAttachment() # extra cloned params n.timestep_percent_range = self.timestep_percent_range n.timestep_range = self.timestep_range @@ -635,11 +636,12 @@ def set_video_length(self, video_length: int, full_length: int): def initialize_timesteps(self, model: BaseModel): for motion_model in self.models: - motion_model.initialize_timesteps(model) + attachment = get_mm_attachment(motion_model) + attachment.initialize_timesteps(model) def pre_run(self, model: ModelPatcher): for motion_model in self.models: - motion_model.pre_run(model) + motion_model.pre_run() def cleanup(self): for motion_model in self.models: @@ -647,12 +649,14 @@ def cleanup(self): def prepare_current_keyframe(self, x: Tensor, t: Tensor): for motion_model in self.models: - motion_model.prepare_current_keyframe(x=x, t=t) + attachment = get_mm_attachment(motion_model) + attachment.prepare_current_keyframe(motion_model, x=x, t=t) def get_special_models(self): pia_motion_models: list[MotionModelPatcher] = [] for motion_model in self.models: - if motion_model.is_pia() or motion_model.is_fancyvideo(): + attachment = get_mm_attachment(motion_model) + if attachment.is_pia(motion_model) or attachment.is_fancyvideo(motion_model): pia_motion_models.append(motion_model) return pia_motion_models @@ -759,7 +763,7 @@ def load_motion_module_gen1(model_name: str, model: ModelPatcher, motion_lora: M load_result = ad_wrapper.load_state_dict(mm_state_dict, strict=False) verify_load_result(load_result=load_result, mm_info=mm_info) # wrap motion_module into a ModelPatcher, to allow motion lora patches - motion_model = MotionModelPatcher(model=ad_wrapper, load_device=model.load_device, offload_device=model.offload_device) + motion_model = create_MotionModelPatcher(model=ad_wrapper, load_device=model.load_device, offload_device=model.offload_device) # load motion_lora, if present if motion_lora is not None: for lora in motion_lora.loras: @@ -783,8 +787,8 @@ def load_motion_module_gen2(model_name: str, motion_model_settings: AnimateDiffS load_result = ad_wrapper.load_state_dict(mm_state_dict, strict=False) verify_load_result(load_result=load_result, mm_info=mm_info) # wrap motion_module into a ModelPatcher, to allow motion lora patches - motion_model = MotionModelPatcher(model=ad_wrapper, load_device=comfy.model_management.get_torch_device(), - offload_device=comfy.model_management.unet_offload_device()) + motion_model = create_MotionModelPatcher(model=ad_wrapper, load_device=comfy.model_management.get_torch_device(), + offload_device=comfy.model_management.unet_offload_device()) return motion_model @@ -823,7 +827,7 @@ def create_fresh_motion_module(motion_model: MotionModelPatcher) -> MotionModelP ad_wrapper.to(comfy.model_management.unet_dtype()) ad_wrapper.to(comfy.model_management.unet_offload_device()) ad_wrapper.load_state_dict(motion_model.model.state_dict()) - return MotionModelPatcher(model=ad_wrapper, load_device=comfy.model_management.get_torch_device(), + return create_MotionModelPatcher(model=ad_wrapper, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) @@ -832,8 +836,8 @@ def create_fresh_encoder_only_model(motion_model: MotionModelPatcher) -> MotionM ad_wrapper.to(comfy.model_management.unet_dtype()) ad_wrapper.to(comfy.model_management.unet_offload_device()) ad_wrapper.load_state_dict(motion_model.model.state_dict(), 
strict=False) - return MotionModelPatcher(model=ad_wrapper, load_device=comfy.model_management.get_torch_device(), - offload_device=comfy.model_management.unet_offload_device()) + return create_MotionModelPatcher(model=ad_wrapper, load_device=comfy.model_management.get_torch_device(), + offload_device=comfy.model_management.unet_offload_device()) def inject_img_encoder_into_model(motion_model: MotionModelPatcher, w_encoder: MotionModelPatcher): diff --git a/animatediff/nodes_animatelcmi2v.py b/animatediff/nodes_animatelcmi2v.py index 22a91e4..ee62d51 100644 --- a/animatediff/nodes_animatelcmi2v.py +++ b/animatediff/nodes_animatelcmi2v.py @@ -7,10 +7,10 @@ from .ad_settings import AnimateDiffSettings from .logger import logger -from .utils_model import ScaleMethods, CropMethods, get_available_motion_models +from .utils_model import ScaleMethods, CropMethods, get_available_motion_models, vae_encode_raw_batched from .utils_motion import ADKeyframeGroup from .motion_lora import MotionLoraList -from .model_injection import (MotionModelGroup, MotionModelPatcher, create_fresh_encoder_only_model, +from .model_injection import (MotionModelGroup, MotionModelPatcher, get_mm_attachment, create_fresh_encoder_only_model, load_motion_module_gen2, inject_img_encoder_into_model) from .motion_module_ad import AnimateDiffFormat from .nodes_gen2 import ApplyAnimateDiffModelNode @@ -58,9 +58,10 @@ def apply_motion_model(self, motion_model: MotionModelPatcher, ref_latent: dict, # confirm that model contains img_encoder if curr_model.model.img_encoder is None: raise Exception(f"Motion model '{curr_model.model.mm_info.mm_name}' does not contain an img_encoder; cannot be used with Apply AnimateLCM-I2V Model node.") - curr_model.orig_img_latents = ref_latent["samples"] - curr_model.orig_ref_drift = ref_drift - curr_model.orig_apply_ref_when_disabled = apply_ref_when_disabled + attachment = get_mm_attachment(curr_model) + attachment.orig_img_latents = ref_latent["samples"] + attachment.orig_ref_drift = ref_drift + attachment.orig_apply_ref_when_disabled = apply_ref_when_disabled return new_m_models @@ -148,9 +149,4 @@ def preprocess_images(self, image: torch.Tensor, vae: VAE, latent_size: torch.Te image = comfy.utils.common_upscale(samples=image, width=w*8, height=h*8, upscale_method=scale_method, crop=crop) image = image.movedim(1,-1) # now that images are the expected size, VAEEncode them - try: # account for old ComfyUI versions (TODO: remove this when other changes require ComfyUI update) - if not hasattr(vae, "vae_encode_crop_pixels"): - image = VAEEncode.vae_encode_crop_pixels(image) - except Exception: - pass - return ({"samples": vae.encode(image[:,:,:,:3])},) + return ({"samples": vae_encode_raw_batched(vae, image)},) diff --git a/animatediff/nodes_cameractrl.py b/animatediff/nodes_cameractrl.py index cd64b0d..7b7bb9a 100644 --- a/animatediff/nodes_cameractrl.py +++ b/animatediff/nodes_cameractrl.py @@ -16,7 +16,7 @@ from .utils_model import get_available_motion_models, calculate_file_hash, strip_path, BIGMAX from .utils_motion import ADKeyframeGroup from .motion_lora import MotionLoraList -from .model_injection import (MotionModelGroup, MotionModelPatcher, load_motion_module_gen2, inject_camera_encoder_into_model) +from .model_injection import (MotionModelGroup, MotionModelPatcher, get_mm_attachment, load_motion_module_gen2, inject_camera_encoder_into_model) from .nodes_gen2 import ApplyAnimateDiffModelNode, ADKeyframeNode @@ -230,8 +230,9 @@ def apply_motion_model(self, motion_model: MotionModelPatcher, 
cameractrl_poses: if curr_model.model.camera_encoder is None: raise Exception(f"Motion model '{curr_model.model.mm_info.mm_name}' does not contain a camera_encoder; cannot be used with Apply AnimateDiff-CameraCtrl Model node.") camera_entries = [CameraEntry(entry) for entry in cameractrl_poses] - curr_model.orig_camera_entries = camera_entries - curr_model.cameractrl_multival = cameractrl_multival + attachment = get_mm_attachment(curr_model) + attachment.orig_camera_entries = camera_entries + attachment.cameractrl_multival = cameractrl_multival return new_m_models diff --git a/animatediff/nodes_fancyvideo.py b/animatediff/nodes_fancyvideo.py index 6d8189f..4b2709e 100644 --- a/animatediff/nodes_fancyvideo.py +++ b/animatediff/nodes_fancyvideo.py @@ -10,7 +10,7 @@ from .utils_model import BIGMIN, BIGMAX, get_available_motion_models from .utils_motion import ADKeyframeGroup, InputPIA, InputPIA_Multival, extend_list_to_batch_size, extend_to_batch_size, prepare_mask_batch from .motion_lora import MotionLoraList -from .model_injection import MotionModelGroup, MotionModelPatcher, load_motion_module_gen2, inject_pia_conv_in_into_model +from .model_injection import MotionModelGroup, MotionModelPatcher, get_mm_attachment, load_motion_module_gen2, inject_pia_conv_in_into_model from .motion_module_ad import AnimateDiffFormat from .nodes_gen2 import ApplyAnimateDiffModelNode, ADKeyframeNode @@ -59,6 +59,7 @@ def apply_motion_model(self, motion_model: MotionModelPatcher, image: Tensor, va # confirm that model is FancyVideo if curr_model.model.mm_info.mm_format != AnimateDiffFormat.FANCYVIDEO: raise Exception(f"Motion model '{curr_model.model.mm_info.mm_name}' is not a FancyVideo model; cannot be used with Apply AD-FancyModel Model node.") - curr_model.orig_fancy_images = image - curr_model.fancy_vae = vae + attachment = get_mm_attachment(curr_model) + attachment.orig_fancy_images = image + attachment.fancy_vae = vae return new_m_models diff --git a/animatediff/nodes_gen2.py b/animatediff/nodes_gen2.py index b2e8b6a..19ed454 100644 --- a/animatediff/nodes_gen2.py +++ b/animatediff/nodes_gen2.py @@ -11,7 +11,7 @@ from .motion_lora import MotionLoraList from .motion_module_ad import AllPerBlocks from .model_injection import (ModelPatcherHelper, - InjectionParams, MotionModelGroup, MotionModelPatcher, create_fresh_motion_module, + InjectionParams, MotionModelGroup, MotionModelPatcher, get_mm_attachment, create_fresh_motion_module, load_motion_module_gen2, load_motion_lora_as_patches, validate_model_compatibility_gen2, validate_per_block_compatibility) from .sample_settings import SampleSettings from .sampling import outer_sample_wrapper, sliding_calc_cond_batch @@ -128,13 +128,14 @@ def apply_motion_model(self, motion_model: MotionModelPatcher, start_percent: fl if motion_lora is not None: for lora in motion_lora.loras: load_motion_lora_as_patches(motion_model, lora) - motion_model.scale_multival = scale_multival - motion_model.effect_multival = effect_multival + attachment = get_mm_attachment(motion_model) + attachment.scale_multival = scale_multival + attachment.effect_multival = effect_multival if per_block is not None: validate_per_block_compatibility(motion_model=motion_model, all_per_blocks=per_block) - motion_model.per_block_list = per_block.per_block_list - motion_model.keyframes = ad_keyframes.clone() if ad_keyframes else ADKeyframeGroup() - motion_model.timestep_percent_range = (start_percent, end_percent) + attachment.per_block_list = per_block.per_block_list + attachment.keyframes = 
ad_keyframes.clone() if ad_keyframes else ADKeyframeGroup() + attachment.timestep_percent_range = (start_percent, end_percent) # add to beginning, so that after injection, it will be the earliest of prev_m_models to be run prev_m_models.add_to_start(mm=motion_model) return (prev_m_models,) diff --git a/animatediff/nodes_pia.py b/animatediff/nodes_pia.py index d5d9beb..a7b9716 100644 --- a/animatediff/nodes_pia.py +++ b/animatediff/nodes_pia.py @@ -10,7 +10,7 @@ from .utils_model import BIGMIN, BIGMAX, get_available_motion_models from .utils_motion import ADKeyframeGroup, InputPIA, InputPIA_Multival, extend_list_to_batch_size, extend_to_batch_size, prepare_mask_batch from .motion_lora import MotionLoraList -from .model_injection import MotionModelGroup, MotionModelPatcher, load_motion_module_gen2, inject_pia_conv_in_into_model +from .model_injection import MotionModelGroup, MotionModelPatcher, get_mm_attachment, load_motion_module_gen2, inject_pia_conv_in_into_model from .motion_module_ad import AnimateDiffFormat from .nodes_gen2 import ApplyAnimateDiffModelNode, ADKeyframeNode @@ -148,11 +148,12 @@ def apply_motion_model(self, motion_model: MotionModelPatcher, image: Tensor, va # confirm that model is PIA if curr_model.model.mm_info.mm_format != AnimateDiffFormat.PIA: raise Exception(f"Motion model '{curr_model.model.mm_info.mm_name}' is not a PIA model; cannot be used with Apply AnimateDiff-PIA Model node.") - curr_model.orig_pia_images = image - curr_model.pia_vae = vae + attachment = get_mm_attachment(curr_model) + attachment.orig_pia_images = image + attachment.pia_vae = vae if pia_input is None: pia_input = InputPIA_Multival(1.0) - curr_model.pia_input = pia_input + attachment.pia_input = pia_input #curr_model.pia_multival = ref_multival return new_m_models diff --git a/animatediff/utils_model.py b/animatediff/utils_model.py index 009aff8..54bef00 100644 --- a/animatediff/utils_model.py +++ b/animatediff/utils_model.py @@ -37,7 +37,6 @@ def vae_encode_raw_dynamic_batched(vae: VAE, pixels: Tensor, max_batch=16, min_b b, h, w, c = pixels.shape actual_size = h*w actual_batch_size = int(max(min_batch, min(max_batch, max_batch // max((actual_size / max_size), 1.0)))) - logger.info(f"actual_batch_size: {actual_batch_size}") return vae_encode_raw_batched(vae=vae, pixels=pixels, per_batch=actual_batch_size, show_pbar=show_pbar) From 5ee8cf677e5a40ff28b2480ba32da4ca296cc1b2 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 28 Nov 2024 13:52:02 -0600 Subject: [PATCH 39/43] Pass transformer_options into VanillaTemporalModule and all subsequent layers, as well as mm_kwargs where applicable, so that behavior can be more easily extended in the future, both from ADE code and other node packs --- animatediff/model_injection.py | 2 +- animatediff/motion_module_ad.py | 29 ++++++++++++++++++----------- animatediff/utils_motion.py | 2 +- 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index 4a6cad2..3e48fc5 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -177,7 +177,7 @@ def create_forward_timestep_embed_patch(): return (VanillaTemporalModule, forward_timestep_embed_patch_ade) def forward_timestep_embed_patch_ade(layer, x, emb, context, transformer_options, output_shape, time_context, num_video_frames, image_only_indicator, *args, **kwargs): - return layer(x, context) + return layer(x, context, transformer_options=transformer_options) def create_MotionModelPatcher(model,
load_device, offload_device) -> MotionModelPatcher: diff --git a/animatediff/motion_module_ad.py b/animatediff/motion_module_ad.py index da861b8..fe66742 100644 --- a/animatediff/motion_module_ad.py +++ b/animatediff/motion_module_ad.py @@ -921,7 +921,7 @@ def should_handle_img_features(self): def should_handle_camera_features(self): return self.camera_features is not None and self.block_type != BlockType.MID# and self.module_idx == 0 - def forward(self, input_tensor: Tensor, encoder_hidden_states=None, attention_mask=None): + def forward(self, input_tensor: Tensor, encoder_hidden_states=None, attention_mask=None, transformer_options=None): #logger.info(f"block_type: {self.block_type}, block_idx: {self.block_idx}, module_idx: {self.module_idx}") mm_kwargs = None if self.should_handle_camera_features(): @@ -930,7 +930,7 @@ def forward(self, input_tensor: Tensor, encoder_hidden_states=None, attention_ma # do AnimateLCM-I2V stuff if needed if self.should_handle_img_features(): input_tensor += self.img_features[self.block_idx] - return self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask, self.view_options, mm_kwargs) + return self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask, self.view_options, mm_kwargs, transformer_options) # return weighted average of input_tensor and AD output if type(self.effect) != Tensor: effect = self.effect @@ -944,8 +944,8 @@ def forward(self, input_tensor: Tensor, encoder_hidden_states=None, attention_ma effect = self.get_effect_mask(input_tensor) # do AnimateLCM-I2V stuff if needed if self.should_handle_img_features(): - return input_tensor*(1.0-effect) + self.temporal_transformer(input_tensor+self.img_features[self.block_idx], encoder_hidden_states, attention_mask, self.view_options, mm_kwargs)*effect - return input_tensor*(1.0-effect) + self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask, self.view_options, mm_kwargs)*effect + return input_tensor*(1.0-effect) + self.temporal_transformer(input_tensor+self.img_features[self.block_idx], encoder_hidden_states, attention_mask, self.view_options, mm_kwargs, transformer_options)*effect + return input_tensor*(1.0-effect) + self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask, self.view_options, mm_kwargs, transformer_options)*effect class TemporalTransformer3DModel(nn.Module): @@ -1174,7 +1174,7 @@ def get_cameractrl_effect(self, hidden_states: Tensor) -> Union[float, Tensor, N return self.temp_cameractrl_effect[:, self.sub_idxs, :] return self.temp_cameractrl_effect - def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, view_options: ContextOptions=None, mm_kwargs: dict[str]=None): + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, view_options: ContextOptions=None, mm_kwargs: dict[str]=None, transformer_options=None): batch, channel, height, width = hidden_states.shape residual = hidden_states scale_masks = self.get_scale_masks(hidden_states) @@ -1197,7 +1197,8 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None scale_masks=scale_masks, cameractrl_effect=cameractrl_effect, view_options=view_options, - mm_kwargs=mm_kwargs + mm_kwargs=mm_kwargs, + transformer_options=transformer_options, ) # output @@ -1288,6 +1289,7 @@ def forward( cameractrl_effect: Union[float, Tensor] = None, view_options: Union[ContextOptions, None]=None, mm_kwargs: dict[str]=None, + transformer_options: dict[str]=None, ): if scale_masks is None: 
scale_masks = [None] * len(self.attention_blocks) @@ -1310,7 +1312,8 @@ def forward( video_length=video_length, scale_mask=scale_mask, cameractrl_effect=cameractrl_effect, - mm_kwargs=mm_kwargs + mm_kwargs=mm_kwargs, + transformer_options=transformer_options, ) + hidden_states ) else: @@ -1345,7 +1348,8 @@ def forward( video_length=len(sub_idxs), scale_mask=scale_mask[:, sub_idxs, :] if scale_mask is not None else scale_mask, cameractrl_effect=cameractrl_effect[:, sub_idxs, :] if type(cameractrl_effect) == Tensor else cameractrl_effect, - mm_kwargs=mm_kwargs + mm_kwargs=mm_kwargs, + transformer_options=transformer_options, ) + sub_hidden_states ) sub_hidden_states = rearrange(sub_hidden_states, "(b f) d c -> b f d c", f=len(sub_idxs)) @@ -1407,7 +1411,7 @@ def __init__(self, d_model, dropout=0.0, max_len=24): def set_sub_idxs(self, sub_idxs: list[int]): self.sub_idxs = sub_idxs - def forward(self, x: Tensor): + def forward(self, x: Tensor, mm_kwargs: dict[str]={}, transformer_options: dict[str]=None): #if self.sub_idxs is not None: # x = x + self.pe[:, self.sub_idxs] #else: @@ -1474,6 +1478,7 @@ def forward( scale_mask=None, cameractrl_effect: Union[float, Tensor] = 1.0, mm_kwargs: dict[str]={}, + transformer_options: dict[str]=None, ): if self.attention_mode != "Temporal": raise NotImplementedError @@ -1484,7 +1489,7 @@ def forward( ) if self.pos_encoder is not None: - hidden_states = self.pos_encoder(hidden_states).to(hidden_states.dtype) + hidden_states = self.pos_encoder(hidden_states, mm_kwargs, transformer_options).to(hidden_states.dtype) encoder_hidden_states = ( repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) @@ -1502,6 +1507,8 @@ def forward( value=None, mask=attention_mask, scale_mask=scale_mask, + mm_kwargs=mm_kwargs, + transformer_options=transformer_options, ) hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) @@ -1581,7 +1588,7 @@ def __init__( def create(cls, in_channels, block_type: str, block_idx: int, module_idx: int, ops=comfy.ops.disable_weight_init): return cls(in_channels=in_channels, block_type=block_type, block_idx=block_idx, module_idx=module_idx, ops=ops) - def forward(self, input_tensor: Tensor, encoder_hidden_states=None, attention_mask=None): + def forward(self, input_tensor: Tensor, encoder_hidden_states=None, attention_mask=None, transformer_options=None): if self.effect is None: # do AnimateLCM-I2V stuff if needed if self.should_handle_img_features(): diff --git a/animatediff/utils_motion.py b/animatediff/utils_motion.py index 501a7f6..e6f69a0 100644 --- a/animatediff/utils_motion.py +++ b/animatediff/utils_motion.py @@ -57,7 +57,7 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. 
def reset_attention_type(self): self.actual_attention = optimized_attention_mm - def forward(self, x, context=None, value=None, mask=None, scale_mask=None): + def forward(self, x, context=None, value=None, mask=None, scale_mask=None, mm_kwargs=None, transformer_options=None): q = self.to_q(x) context = default(context, x) k: Tensor = self.to_k(context) From 95cfd12a8199c4ff9fd08d5e0bc940928b078cca Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 29 Nov 2024 01:33:49 -0600 Subject: [PATCH 40/43] Generalized AnimateDiffModel initialization so mm_state_dict can be used to determine if down/up blocks should be present and the attention block count, added init_kwargs input to AnimateDiffModel so some behavior can be adapted by unusual models --- animatediff/motion_module_ad.py | 107 +++++++++++++++++++------------- 1 file changed, 64 insertions(+), 43 deletions(-) diff --git a/animatediff/motion_module_ad.py b/animatediff/motion_module_ad.py index fe66742..dc55ea2 100644 --- a/animatediff/motion_module_ad.py +++ b/animatediff/motion_module_ad.py @@ -156,10 +156,16 @@ def is_fancyvideo(mm_state_dict: dict[str, Tensor]) -> bool: def get_down_block_max(mm_state_dict: dict[str, Tensor]) -> int: + return get_block_max(mm_state_dict, "down_blocks") + +def get_up_block_max(mm_state_dict: dict[str, Tensor]) -> int: + return get_block_max(mm_state_dict, "up_blocks") + +def get_block_max(mm_state_dict: dict[str, Tensor], block_name: str) -> int: # keep track of biggest down_block count in module - biggest_block = 0 + biggest_block = -1 for key in mm_state_dict.keys(): - if "down_blocks" in key: + if block_name in key: try: block_int = key.split(".")[1] block_num = int(block_int) @@ -169,7 +175,6 @@ def get_down_block_max(mm_state_dict: dict[str, Tensor]) -> int: pass return biggest_block - def has_mid_block(mm_state_dict: dict[str, Tensor]): # check if keys contain mid_block for key in mm_state_dict.keys(): @@ -177,6 +182,17 @@ def has_mid_block(mm_state_dict: dict[str, Tensor]): return True return False +_regex_attention_blocks_num = re.compile(r'\.attention_blocks\.(\d+)\.') +def get_attention_block_max_len(mm_state_dict: dict[str, Tensor]): + biggest_attention = -1 + for key in mm_state_dict.keys(): + found = _regex_attention_blocks_num.search(key) + if found: + attention_num = int(found.group(1)) + if attention_num > biggest_attention: + biggest_attention = attention_num + return biggest_attention + 1 + def get_position_encoding_max_len(mm_state_dict: dict[str, Tensor], mm_name: str, mm_format: str) -> Union[int, None]: # use pos_encoder.pe entries to determine max length - [1, {max_length}, {320|640|1280}] @@ -337,21 +353,33 @@ def convert_hellomeme_state_dict(mm_state_dict: dict[str, Tensor]): del mm_state_dict[key] +class InitKwargs: + GET_UNET_FUNC = "get_unet_func" + ATTN_BLOCK_TYPE = "attn_block_type" + + class BlockType: UP = "up" DOWN = "down" MID = "mid" +def get_unet_default(wrapper: 'AnimateDiffModel', model: ModelPatcher): + return model.model.diffusion_model + + class AnimateDiffModel(nn.Module): - def __init__(self, mm_state_dict: dict[str, Tensor], mm_info: AnimateDiffInfo): + def __init__(self, mm_state_dict: dict[str, Tensor], mm_info: AnimateDiffInfo, init_kwargs: dict[str]={}): super().__init__() self.mm_info = mm_info - self.down_blocks: Iterable[MotionModule] = nn.ModuleList([]) - self.up_blocks: Iterable[MotionModule] = nn.ModuleList([]) + self.down_blocks: list[MotionModule] = None + self.up_blocks: list[MotionModule] = None self.mid_block: Union[MotionModule, None] = 
None self.encoding_max_len = get_position_encoding_max_len(mm_state_dict, mm_info.mm_name, mm_info.mm_format) self.has_position_encoding = self.encoding_max_len is not None + self.attn_len = get_attention_block_max_len(mm_state_dict) + self.attn_type = init_kwargs.get(InitKwargs.ATTN_BLOCK_TYPE, "Temporal_Self") + self.attn_block_types = tuple([self.attn_type] * self.attn_len) # determine ops to use (to support fp8 properly) if comfy.model_management.unet_manual_cast(comfy.model_management.unet_dtype(), comfy.model_management.get_torch_device()) is None: ops = comfy.ops.disable_weight_init @@ -364,16 +392,24 @@ def __init__(self, mm_state_dict: dict[str, Tensor], mm_info: AnimateDiffInfo): else: layer_channels = (320, 640, 1280, 1280) self.layer_channels = layer_channels + self.middle_channel = 1280 # fill out down/up blocks and middle block, if present - for idx, c in enumerate(layer_channels): - self.down_blocks.append(MotionModule(c, temporal_pe=self.has_position_encoding, - temporal_pe_max_len=self.encoding_max_len, block_type=BlockType.DOWN, block_idx=idx, ops=ops)) - for idx, c in enumerate(list(reversed(layer_channels))): - self.up_blocks.append(MotionModule(c, temporal_pe=self.has_position_encoding, - temporal_pe_max_len=self.encoding_max_len, block_type=BlockType.UP, block_idx=idx, ops=ops)) + if get_down_block_max(mm_state_dict) > -1: + self.down_blocks = nn.ModuleList([]) + for idx, c in enumerate(layer_channels): + self.down_blocks.append(MotionModule(c, temporal_pe=self.has_position_encoding, + temporal_pe_max_len=self.encoding_max_len, block_type=BlockType.DOWN, block_idx=idx, + attention_block_types=self.attn_block_types, ops=ops)) + if get_up_block_max(mm_state_dict) > -1: + self.up_blocks = nn.ModuleList([]) + for idx, c in enumerate(list(reversed(layer_channels))): + self.up_blocks.append(MotionModule(c, temporal_pe=self.has_position_encoding, + temporal_pe_max_len=self.encoding_max_len, block_type=BlockType.UP, block_idx=idx, + attention_block_types=self.attn_block_types, ops=ops)) if has_mid_block(mm_state_dict): - self.mid_block = MotionModule(1280, temporal_pe=self.has_position_encoding, - temporal_pe_max_len=self.encoding_max_len, block_type=BlockType.MID, ops=ops) + self.mid_block = MotionModule(self.middle_channel, temporal_pe=self.has_position_encoding, + temporal_pe_max_len=self.encoding_max_len, block_type=BlockType.MID, + attention_block_types=self.attn_block_types, ops=ops) self.AD_video_length: int = 24 self.effect_model = 1.0 self.effect_per_block_list = None @@ -395,6 +431,8 @@ def __init__(self, mm_state_dict: dict[str, Tensor], mm_info: AnimateDiffInfo): self.init_fps_embedding(mm_state_dict) if has_motion_embedding(mm_state_dict): self.init_motion_embedding(mm_state_dict) + # get_unet_func initialization + self.get_unet_func = init_kwargs.get(InitKwargs.GET_UNET_FUNC, get_unet_default) def init_img_encoder(self): del self.img_encoder @@ -501,7 +539,7 @@ def cleanup(self): self.img_encoder.cleanup() def inject(self, model: ModelPatcher): - unet: openaimodel.UNetModel = model.model.diffusion_model + unet: openaimodel.UNetModel = self.get_unet_func(self, model) # inject input (down) blocks # SD15 mm contains 4 downblocks, each with 2 TemporalTransformers - 8 in total # SDXL mm contains 3 downblocks, each with 2 TemporalTransformers - 6 in total @@ -555,7 +593,7 @@ def _inject(self, unet_blocks: nn.ModuleList, mm_blocks: nn.ModuleList): unet_idx += 1 def eject(self, model: ModelPatcher): - unet: openaimodel.UNetModel = model.model.diffusion_model + unet: 
openaimodel.UNetModel = self.get_unet_func(self, model) # remove from input blocks (downblocks) self._eject(unet.input_blocks) # remove from output blocks (upblocks) @@ -715,23 +753,24 @@ def __init__(self, temporal_pe_max_len=24, block_type: str=BlockType.DOWN, block_idx: int=0, + attention_block_types=("Temporal_Self", "Temporal_Self"), ops=comfy.ops.disable_weight_init ): super().__init__() if block_type == BlockType.MID: # mid blocks contain only a single VanillaTemporalModule - self.motion_modules: Iterable[VanillaTemporalModule] = nn.ModuleList([get_motion_module(in_channels, block_type, block_idx, module_idx=0, temporal_pe=temporal_pe, temporal_pe_max_len=temporal_pe_max_len, ops=ops)]) + self.motion_modules: list[VanillaTemporalModule] = nn.ModuleList([get_motion_module(in_channels, block_type, block_idx, module_idx=0, attention_block_types=attention_block_types, temporal_pe=temporal_pe, temporal_pe_max_len=temporal_pe_max_len, ops=ops)]) else: # down blocks contain two VanillaTemporalModules - self.motion_modules: Iterable[VanillaTemporalModule] = nn.ModuleList( + self.motion_modules: list[VanillaTemporalModule] = nn.ModuleList( [ - get_motion_module(in_channels, block_type, block_idx, module_idx=0, temporal_pe=temporal_pe, temporal_pe_max_len=temporal_pe_max_len, ops=ops), - get_motion_module(in_channels, block_type, block_idx, module_idx=1, temporal_pe=temporal_pe, temporal_pe_max_len=temporal_pe_max_len, ops=ops) + get_motion_module(in_channels, block_type, block_idx, module_idx=0, attention_block_types=attention_block_types, temporal_pe=temporal_pe, temporal_pe_max_len=temporal_pe_max_len, ops=ops), + get_motion_module(in_channels, block_type, block_idx, module_idx=1, attention_block_types=attention_block_types, temporal_pe=temporal_pe, temporal_pe_max_len=temporal_pe_max_len, ops=ops) ] ) # up blocks contain one additional VanillaTemporalModule if block_type == BlockType.UP: - self.motion_modules.append(get_motion_module(in_channels, block_type, block_idx, module_idx=2, temporal_pe=temporal_pe, temporal_pe_max_len=temporal_pe_max_len, ops=ops)) + self.motion_modules.append(get_motion_module(in_channels, block_type, block_idx, module_idx=2, attention_block_types=attention_block_types, temporal_pe=temporal_pe, temporal_pe_max_len=temporal_pe_max_len, ops=ops)) def set_video_length(self, video_length: int, full_length: int): for motion_module in self.motion_modules: @@ -772,8 +811,10 @@ def reset_temp_vars(self): def get_motion_module(in_channels, block_type: str, block_idx: int, module_idx: int, + attention_block_types: list[str], temporal_pe, temporal_pe_max_len, ops=comfy.ops.disable_weight_init): return VanillaTemporalModule(in_channels=in_channels, block_type=block_type, block_idx=block_idx, module_idx=module_idx, + attention_block_types=attention_block_types, temporal_pe=temporal_pe, temporal_pe_max_len=temporal_pe_max_len, ops=ops) @@ -1324,7 +1365,6 @@ def forward( hidden_states = rearrange(hidden_states, "(b f) d c -> b f d c", f=video_length) value_final = torch.zeros_like(hidden_states) count_final = torch.zeros_like(hidden_states) - # bias_final = [0.0] * video_length batched_conds = hidden_states.size(1) // video_length # store original camera_feature, if present has_camera_feature = False @@ -1354,23 +1394,6 @@ def forward( ) sub_hidden_states = rearrange(sub_hidden_states, "(b f) d c -> b f d c", f=len(sub_idxs)) - # if view_options.fuse_method == ContextFuseMethod.RELATIVE: - # for pos, idx in enumerate(sub_idxs): - # # bias is the influence of a specific index 
in relation to the whole context window - # bias = 1 - abs(idx - (sub_idxs[0] + sub_idxs[-1]) / 2) / ((sub_idxs[-1] - sub_idxs[0] + 1e-2) / 2) - # bias = max(1e-2, bias) - # # take weighted averate relative to total bias of current idx - # bias_total = bias_final[idx] - # prev_weight = torch.tensor([bias_total / (bias_total + bias)], - # dtype=value_final.dtype, device=value_final.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - # #prev_weight = torch.cat([prev_weight]*value_final.shape[1], dim=1) - # new_weight = torch.tensor([bias / (bias_total + bias)], - # dtype=value_final.dtype, device=value_final.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) - # #new_weight = torch.cat([new_weight]*value_final.shape[1], dim=1) - # test = value_final[:, idx:idx+1, :, :] - # value_final[:, idx:idx+1, :, :] = value_final[:, idx:idx+1, :, :] * prev_weight + sub_hidden_states[:, pos:pos+1, : ,:] * new_weight - # bias_final[idx] = bias_total + bias - # else: weights = get_context_weights(len(sub_idxs), view_options.fuse_method) * batched_conds weights_tensor = torch.Tensor(weights).to(device=hidden_states.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) value_final[:, sub_idxs] += sub_hidden_states * weights_tensor @@ -1379,13 +1402,11 @@ def forward( if has_camera_feature: mm_kwargs["camera_feature"] = orig_camera_feature del orig_camera_feature - # get weighted average of sub_hidden_states, if fuse method requires it - # if view_options.fuse_method != ContextFuseMethod.RELATIVE: + # get weighted average of sub_hidden_states hidden_states = value_final / count_final hidden_states = rearrange(hidden_states, "b f d c -> (b f) d c") del value_final del count_final - # del bias_final hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states @@ -1521,7 +1542,7 @@ def forward( class EncoderOnlyAnimateDiffModel(AnimateDiffModel): def __init__(self, mm_state_dict: dict[str, Tensor], mm_info: AnimateDiffInfo): super().__init__(mm_state_dict=mm_state_dict, mm_info=mm_info) - self.down_blocks: Iterable[EncoderOnlyMotionModule] = nn.ModuleList([]) + self.down_blocks: list[EncoderOnlyMotionModule] = nn.ModuleList([]) self.up_blocks = None self.mid_block = None # fill out down/up blocks and middle block, if present From 28b841f083147da24dd2106fea0a0f3c35238bc4 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 29 Nov 2024 02:04:40 -0600 Subject: [PATCH 41/43] Fixed special model features not being changed to use attachment from refactored MotionModelPatchers --- animatediff/sampling.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/animatediff/sampling.py b/animatediff/sampling.py index e5995a4..80cac68 100644 --- a/animatediff/sampling.py +++ b/animatediff/sampling.py @@ -24,7 +24,7 @@ from .sample_settings import SampleSettings, NoisedImageToInject from .utils_model import MachineState, vae_encode_raw_batched, vae_decode_raw_batched from .utils_motion import composite_extend, prepare_mask_batch, extend_to_batch_size -from .model_injection import InjectionParams, ModelPatcherHelper, MotionModelGroup +from .model_injection import InjectionParams, ModelPatcherHelper, MotionModelGroup, get_mm_attachment from .motion_module_ad import AnimateDiffFormat, AnimateDiffInfo, AnimateDiffVersion from .logger import logger @@ -67,15 +67,16 @@ def perform_special_model_features(self, model: BaseModel, conds: list, x_in: Te if len(special_models) > 0: for special_model in special_models: if special_model.model.is_in_effect(): - if special_model.is_pia(): + attachment = 
get_mm_attachment(special_model) + if attachment.is_pia(): special_model.model.inject_unet_conv_in_pia_fancyvideo(model) conds = get_conds_with_c_concat(conds, - special_model.get_pia_c_concat(model, x_in)) - elif special_model.is_fancyvideo(): + attachment.get_pia_c_concat(model, x_in)) + elif attachment.is_fancyvideo(): # TODO: handle other weights special_model.model.inject_unet_conv_in_pia_fancyvideo(model) conds = get_conds_with_c_concat(conds, - special_model.get_fancy_c_concat(model, x_in)) + attachment.get_fancy_c_concat(model, x_in)) # add fps_embedding/motion_embedding patches emb_patches = special_model.model.get_fancyvideo_emb_patches(dtype=x_in.dtype, device=x_in.device) transformer_patches = model_options["transformer_options"].get("patches", {}) @@ -88,9 +89,10 @@ def restore_special_model_features(self, model: BaseModel): special_models = self.motion_models.get_special_models() if len(special_models) > 0: for special_model in reversed(special_models): - if special_model.is_pia(): + attachment = get_mm_attachment(special_model) + if attachment.is_pia(): special_model.model.restore_unet_conv_in_pia_fancyvideo(model) - elif special_model.is_fancyvideo(): + elif attachment.is_fancyvideo(): # TODO: fill out special_model.model.restore_unet_conv_in_pia_fancyvideo(model) From a6c94002a7cf82421f1b8503cee9d1e6adcf77ec Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 29 Nov 2024 22:37:02 -0600 Subject: [PATCH 42/43] Register AnimateDiffModel and AnimateDiffInfo on DinkLink --- animatediff/dinklink.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/animatediff/dinklink.py b/animatediff/dinklink.py index d27a1c2..36063d5 100644 --- a/animatediff/dinklink.py +++ b/animatediff/dinklink.py @@ -13,27 +13,41 @@ from __future__ import annotations import comfy.hooks +from .motion_module_ad import AnimateDiffModel, AnimateDiffInfo + DINKLINK = "__DINKLINK" + def init_dinklink(): - if not hasattr(comfy.hooks, DINKLINK): - setattr(comfy.hooks, DINKLINK, {}) + create_dinklink() prepare_dinklink() +def create_dinklink(): + if not hasattr(comfy.hooks, DINKLINK): + setattr(comfy.hooks, DINKLINK, {}) def get_dinklink() -> dict[str, dict[str]]: + create_dinklink() return getattr(comfy.hooks, DINKLINK) class DinkLinkConst: VERSION = "version" + # ACN ACN = "ACN" ACN_CREATE_OUTER_SAMPLE_WRAPPER = "create_outer_sample_wrapper" - + # ADE + ADE = "ADE" + ADE_ANIMATEDIFFMODEL = "AnimateDiffModel" + ADE_ANIMATEDIFFINFO = "AnimateDiffInfo" def prepare_dinklink(): - pass - + # expose classes + d = get_dinklink() + link_ade = d.setdefault(DinkLinkConst.ADE, {}) + link_ade[DinkLinkConst.VERSION] = 10000 + link_ade[DinkLinkConst.ADE_ANIMATEDIFFMODEL] = AnimateDiffModel + link_ade[DinkLinkConst.ADE_ANIMATEDIFFINFO] = AnimateDiffInfo def get_acn_outer_sample_wrapper(throw_exception=True): d = get_dinklink() From 852fe3fc8169a94b4045ac288148fce615877a44 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Sun, 1 Dec 2024 19:43:44 -0600 Subject: [PATCH 43/43] Add create_MotionModelPatcher to DinkLink, make eject function check if blocks variables are actually defined --- __init__.py | 2 ++ animatediff/dinklink.py | 1 + animatediff/model_injection.py | 9 ++++++++- animatediff/motion_module_ad.py | 27 ++++++++++++++++----------- 4 files changed, 27 insertions(+), 12 deletions(-) diff --git a/__init__.py b/__init__.py index ea6806c..64cf281 100644 --- a/__init__.py +++ b/__init__.py @@ -1,6 +1,7 @@ import folder_paths from .animatediff.logger import logger from 
.animatediff.utils_model import get_available_motion_models, Folders +from .animatediff.model_injection import prepare_dinklink_register_definitions from .animatediff.nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS from .animatediff import documentation from .animatediff.dinklink import init_dinklink @@ -13,3 +14,4 @@ documentation.format_descriptions(NODE_CLASS_MAPPINGS) init_dinklink() +prepare_dinklink_register_definitions() diff --git a/animatediff/dinklink.py b/animatediff/dinklink.py index 36063d5..66ffdd6 100644 --- a/animatediff/dinklink.py +++ b/animatediff/dinklink.py @@ -40,6 +40,7 @@ class DinkLinkConst: ADE = "ADE" ADE_ANIMATEDIFFMODEL = "AnimateDiffModel" ADE_ANIMATEDIFFINFO = "AnimateDiffInfo" + ADE_CREATE_MOTIONMODELPATCHER = "create_MotionModelPatcher" def prepare_dinklink(): # expose classes diff --git a/animatediff/model_injection.py b/animatediff/model_injection.py index 3e48fc5..5987d40 100644 --- a/animatediff/model_injection.py +++ b/animatediff/model_injection.py @@ -31,7 +31,14 @@ from .motion_lora import MotionLoraInfo, MotionLoraList from .utils_model import get_motion_lora_path, get_motion_model_path, get_sd_model_type, vae_encode_raw_batched from .sample_settings import SampleSettings, SeedNoiseGeneration -from .dinklink import get_acn_outer_sample_wrapper +from .dinklink import DinkLinkConst, get_dinklink, get_acn_outer_sample_wrapper + + +def prepare_dinklink_register_definitions(): + # expose create_MotionModelPatcher + d = get_dinklink() + link_ade = d.setdefault(DinkLinkConst.ADE, {}) + link_ade[DinkLinkConst.ADE_CREATE_MOTIONMODELPATCHER] = create_MotionModelPatcher class MotionModelPatcher(ModelPatcher): diff --git a/animatediff/motion_module_ad.py b/animatediff/motion_module_ad.py index dc55ea2..21c265b 100644 --- a/animatediff/motion_module_ad.py +++ b/animatediff/motion_module_ad.py @@ -354,6 +354,7 @@ def convert_hellomeme_state_dict(mm_state_dict: dict[str, Tensor]): class InitKwargs: + OPS = "ops" GET_UNET_FUNC = "get_unet_func" ATTN_BLOCK_TYPE = "attn_block_type" @@ -381,11 +382,12 @@ def __init__(self, mm_state_dict: dict[str, Tensor], mm_info: AnimateDiffInfo, i self.attn_type = init_kwargs.get(InitKwargs.ATTN_BLOCK_TYPE, "Temporal_Self") self.attn_block_types = tuple([self.attn_type] * self.attn_len) # determine ops to use (to support fp8 properly) - if comfy.model_management.unet_manual_cast(comfy.model_management.unet_dtype(), comfy.model_management.get_torch_device()) is None: - ops = comfy.ops.disable_weight_init - else: - ops = comfy.ops.manual_cast - self.ops = ops + self.ops = init_kwargs.get(InitKwargs.OPS, None) + if self.ops is None: + if comfy.model_management.unet_manual_cast(comfy.model_management.unet_dtype(), comfy.model_management.get_torch_device()) is None: + self.ops = comfy.ops.disable_weight_init + else: + self.ops = comfy.ops.manual_cast # SDXL has 3 up/down blocks, SD1.5 has 4 up/down blocks if mm_info.sd_type == ModelTypeSD.SDXL: layer_channels = (320, 640, 1280) @@ -399,17 +401,17 @@ def __init__(self, mm_state_dict: dict[str, Tensor], mm_info: AnimateDiffInfo, i for idx, c in enumerate(layer_channels): self.down_blocks.append(MotionModule(c, temporal_pe=self.has_position_encoding, temporal_pe_max_len=self.encoding_max_len, block_type=BlockType.DOWN, block_idx=idx, - attention_block_types=self.attn_block_types, ops=ops)) + attention_block_types=self.attn_block_types, ops=self.ops)) if get_up_block_max(mm_state_dict) > -1: self.up_blocks = nn.ModuleList([]) for idx, c in 
enumerate(list(reversed(layer_channels))): self.up_blocks.append(MotionModule(c, temporal_pe=self.has_position_encoding, temporal_pe_max_len=self.encoding_max_len, block_type=BlockType.UP, block_idx=idx, - attention_block_types=self.attn_block_types, ops=ops)) + attention_block_types=self.attn_block_types, ops=self.ops)) if has_mid_block(mm_state_dict): self.mid_block = MotionModule(self.middle_channel, temporal_pe=self.has_position_encoding, temporal_pe_max_len=self.encoding_max_len, block_type=BlockType.MID, - attention_block_types=self.attn_block_types, ops=ops) + attention_block_types=self.attn_block_types, ops=self.ops) self.AD_video_length: int = 24 self.effect_model = 1.0 self.effect_per_block_list = None @@ -595,11 +597,14 @@ def _inject(self, unet_blocks: nn.ModuleList, mm_blocks: nn.ModuleList): def eject(self, model: ModelPatcher): unet: openaimodel.UNetModel = self.get_unet_func(self, model) # remove from input blocks (downblocks) - self._eject(unet.input_blocks) + if hasattr(unet, "input_blocks"): + self._eject(unet.input_blocks) # remove from output blocks (upblocks) - self._eject(unet.output_blocks) + if hasattr(unet, "output_blocks"): + self._eject(unet.output_blocks) # remove from middle block (encapsulate in list to make compatible) - self._eject([unet.middle_block]) + if hasattr(unet, "middle_block"): + self._eject([unet.middle_block]) del unet def _eject(self, unet_blocks: nn.ModuleList):