Skip to content

Commit

Permalink
Merge PR #498 from Kosinkadink/rework-modelpatcher
Browse files Browse the repository at this point in the history
Rework ModelPatcher for upcoming ComfyUI update
  • Loading branch information
Kosinkadink authored Dec 2, 2024
2 parents b3e508a + 852fe3f commit 76d8664
Show file tree
Hide file tree
Showing 31 changed files with 1,865 additions and 2,064 deletions.
5 changes: 5 additions & 0 deletions __init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import folder_paths
from .animatediff.logger import logger
from .animatediff.utils_model import get_available_motion_models, Folders
from .animatediff.model_injection import prepare_dinklink_register_definitions
from .animatediff.nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
from .animatediff import documentation
from .animatediff.dinklink import init_dinklink

if len(get_available_motion_models()) == 0:
logger.error(f"No motion models found. Please download one and place in: {folder_paths.get_folder_paths(Folders.ANIMATEDIFF_MODELS)}")

WEB_DIRECTORY = "./web"
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
documentation.format_descriptions(NODE_CLASS_MAPPINGS)

init_dinklink()
prepare_dinklink_register_definitions()
40 changes: 40 additions & 0 deletions animatediff/adapter_fancyvideo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from torch import nn

import comfy.ops


FancyVideoKeys = [
'fps_embedding.linear.bias',
'fps_embedding.linear.weight',
'motion_embedding.linear.bias',
'motion_embedding.linear.weight',
'conv_in.bias',
'conv_in.weight',
]


def initialize_weights_to_zero(m):
if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
nn.init.constant_(m.weight, 0)
if m.bias is not None:
nn.init.constant_(m.bias, 0)


class FancyVideoCondEmbedding(nn.Module):
def __init__(self, in_channels: int, cond_embed_dim: int, act_fn: str = "silu", ops=comfy.ops.disable_weight_init):
super().__init__()

self.linear = ops.Linear(in_channels, cond_embed_dim)
self.act = None
if act_fn == "silu":
self.act = nn.SiLU()
elif act_fn == "mish":
self.act = nn.Mish()

def forward(self, sample):
sample = self.linear(sample)

if self.act is not None:
sample = self.act(sample)

return sample
11 changes: 9 additions & 2 deletions animatediff/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,9 +604,14 @@ def draw_view(window: list[int], gd: GridDisplay):
draw_subidxs(window=window, gd=gd, y_grid_offset=2, color=gd.vs.view_color)


def generate_context_visualization(context_opts: ContextOptionsGroup, model: ModelPatcher, sampler_name: str=None, scheduler: str=None,
def generate_context_visualization(model: ModelPatcher, context_opts: ContextOptionsGroup=None, sampler_name: str=None, scheduler: str=None,
width=1440, height=200, video_length=32,
steps=None, start_step=None, end_step=None, sigmas=None, force_full_denoise=False, denoise=None):
if context_opts is None:
context_opts = ContextOptionsGroup.default()
params = model.get_attachment("ADE_params")
if params is not None:
context_opts = params.context_options
context_opts = context_opts.clone()
vs = VisualizeSettings(width, video_length)
all_imgs = []
Expand Down Expand Up @@ -642,7 +647,9 @@ def generate_context_visualization(context_opts: ContextOptionsGroup, model: Mod

# check if context should even be active in this case
context_active = True
if video_length < context_opts.context_length:
if context_opts.context_length is None:
context_active = False
elif video_length < context_opts.context_length:
context_active = False
elif video_length == context_opts.context_length and not context_opts.use_on_equal_length:
context_active = False
Expand Down
62 changes: 62 additions & 0 deletions animatediff/dinklink.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
####################################################################################################
# DinkLink is my method of sharing classes/functions between my nodes.
#
# My DinkLink-compatible nodes will inject comfy.hooks with a __DINKLINK attr
# that stores a dictionary, where any of my node packs can store their stuff.
#
# It is not intended to be accessed by node packs that I don't develop, so things may change
# at any time.
#
# DinkLink also serves as a proof-of-concept for a future ComfyUI implementation of
# purposely exposing node pack classes/functions with other node packs.
####################################################################################################
from __future__ import annotations
import comfy.hooks

from .motion_module_ad import AnimateDiffModel, AnimateDiffInfo

DINKLINK = "__DINKLINK"


def init_dinklink():
create_dinklink()
prepare_dinklink()

def create_dinklink():
if not hasattr(comfy.hooks, DINKLINK):
setattr(comfy.hooks, DINKLINK, {})

def get_dinklink() -> dict[str, dict[str]]:
create_dinklink()
return getattr(comfy.hooks, DINKLINK)


class DinkLinkConst:
VERSION = "version"
# ACN
ACN = "ACN"
ACN_CREATE_OUTER_SAMPLE_WRAPPER = "create_outer_sample_wrapper"
# ADE
ADE = "ADE"
ADE_ANIMATEDIFFMODEL = "AnimateDiffModel"
ADE_ANIMATEDIFFINFO = "AnimateDiffInfo"
ADE_CREATE_MOTIONMODELPATCHER = "create_MotionModelPatcher"

def prepare_dinklink():
# expose classes
d = get_dinklink()
link_ade = d.setdefault(DinkLinkConst.ADE, {})
link_ade[DinkLinkConst.VERSION] = 10000
link_ade[DinkLinkConst.ADE_ANIMATEDIFFMODEL] = AnimateDiffModel
link_ade[DinkLinkConst.ADE_ANIMATEDIFFINFO] = AnimateDiffInfo

def get_acn_outer_sample_wrapper(throw_exception=True):
d = get_dinklink()
try:
link_acn = d[DinkLinkConst.ACN]
return link_acn[DinkLinkConst.ACN_CREATE_OUTER_SAMPLE_WRAPPER]
except KeyError:
if throw_exception:
raise Exception("Advanced-ControlNet nodes need to be installed to make use of ContextRef; " + \
"they are either not installed or are of an insufficient version.")
return None
4 changes: 3 additions & 1 deletion animatediff/freeinit.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class FreeInitFilter:
LIST = [GAUSSIAN, BUTTERWORTH, IDEAL, BOX]


def freq_mix_3d(x, noise, LPF):
def freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor):
"""
Noise reinitialization.
Expand All @@ -33,6 +33,8 @@ def freq_mix_3d(x, noise, LPF):
noise: randomly sampled noise
LPF: low pass filter
"""
noise = noise.to(dtype=x.dtype, device=x.device)
LPF = LPF.to(dtype=x.dtype, device=x.device)
# FFT
x_freq = fft.fftn(x, dim=(-4, -2, -1))
x_freq = fft.fftshift(x_freq, dim=(-4, -2, -1))
Expand Down
Loading

0 comments on commit 76d8664

Please sign in to comment.