Skip to content

Commit

Permalink
Refactor comfy.ops
Browse files Browse the repository at this point in the history
comfy.ops -> comfy.ops.disable_weight_init

This should make it more clear what they actually do.

Some unused code has also been removed.
  • Loading branch information
comfyanonymous committed Dec 12, 2023
1 parent b0aab1e commit 77755ab
Show file tree
Hide file tree
Showing 10 changed files with 77 additions and 153 deletions.
2 changes: 1 addition & 1 deletion comfy/cldm/cldm.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(
transformer_depth_middle=None,
transformer_depth_output=None,
device=None,
operations=comfy.ops,
operations=comfy.ops.disable_weight_init,
**kwargs,
):
super().__init__()
Expand Down
2 changes: 1 addition & 1 deletion comfy/clip_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self, json_config):
if comfy.model_management.should_use_fp16(self.load_device, prioritize_performance=False):
self.dtype = torch.float16

self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops)
self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.disable_weight_init)

self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
Expand Down
27 changes: 7 additions & 20 deletions comfy/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,9 +208,9 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,

def forward(self, input):
if self.up is not None:
return torch.nn.functional.linear(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
return torch.nn.functional.linear(input, self.weight.to(dtype=input.dtype, device=input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
else:
return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
return torch.nn.functional.linear(input, self.weight.to(dtype=input.dtype, device=input.device), self.bias)

class Conv2d(torch.nn.Module):
def __init__(
Expand Down Expand Up @@ -247,24 +247,9 @@ def __init__(

def forward(self, input):
if self.up is not None:
return torch.nn.functional.conv2d(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
return torch.nn.functional.conv2d(input, self.weight.to(dtype=input.dtype, device=input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
else:
return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)

def conv_nd(self, dims, *args, **kwargs):
if dims == 2:
return self.Conv2d(*args, **kwargs)
else:
raise ValueError(f"unsupported dimensions: {dims}")

class Conv3d(comfy.ops.Conv3d):
pass

class GroupNorm(comfy.ops.GroupNorm):
pass

class LayerNorm(comfy.ops.LayerNorm):
pass
return torch.nn.functional.conv2d(input, self.weight.to(dtype=input.dtype, device=input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)


class ControlLora(ControlNet):
Expand All @@ -278,7 +263,9 @@ def pre_run(self, model, percent_to_timestep_function):
controlnet_config = model.model_config.unet_config.copy()
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
controlnet_config["operations"] = ControlLoraOps()
class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
pass
controlnet_config["operations"] = control_lora_ops
self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
dtype = model.get_dtype()
self.control_model.to(dtype)
Expand Down
13 changes: 7 additions & 6 deletions comfy/ldm/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from comfy.cli_args import args
import comfy.ops
ops = comfy.ops.disable_weight_init

# CrossAttn precision handling
if args.dont_upcast_attention:
Expand Down Expand Up @@ -55,7 +56,7 @@ def init_(tensor):

# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=comfy.ops):
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
super().__init__()
self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)

Expand All @@ -65,7 +66,7 @@ def forward(self, x):


class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=comfy.ops):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=ops):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
Expand Down Expand Up @@ -356,7 +357,7 @@ def optimized_attention_for_device(device, mask=False):


class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
Expand Down Expand Up @@ -389,7 +390,7 @@ def forward(self, x, context=None, value=None, mask=None):

class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=comfy.ops):
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops):
super().__init__()

self.ff_in = ff_in or inner_dim is not None
Expand Down Expand Up @@ -558,7 +559,7 @@ class SpatialTransformer(nn.Module):
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
disable_self_attn=False, use_linear=False,
use_checkpoint=True, dtype=None, device=None, operations=comfy.ops):
use_checkpoint=True, dtype=None, device=None, operations=ops):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim] * depth
Expand Down Expand Up @@ -632,7 +633,7 @@ def __init__(
disable_self_attn=False,
disable_temporal_crossattention=False,
max_time_embed_period: int = 10000,
dtype=None, device=None, operations=comfy.ops
dtype=None, device=None, operations=ops
):
super().__init__(
in_channels,
Expand Down
39 changes: 20 additions & 19 deletions comfy/ldm/modules/diffusionmodules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from comfy import model_management
import comfy.ops
ops = comfy.ops.disable_weight_init

if model_management.xformers_enabled_vae():
import xformers
Expand Down Expand Up @@ -48,7 +49,7 @@ def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = comfy.ops.Conv2d(in_channels,
self.conv = ops.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=1,
Expand Down Expand Up @@ -78,7 +79,7 @@ def __init__(self, in_channels, with_conv):
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = comfy.ops.Conv2d(in_channels,
self.conv = ops.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=2,
Expand All @@ -105,30 +106,30 @@ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,

self.swish = torch.nn.SiLU(inplace=True)
self.norm1 = Normalize(in_channels)
self.conv1 = comfy.ops.Conv2d(in_channels,
self.conv1 = ops.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if temb_channels > 0:
self.temb_proj = comfy.ops.Linear(temb_channels,
self.temb_proj = ops.Linear(temb_channels,
out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout, inplace=True)
self.conv2 = comfy.ops.Conv2d(out_channels,
self.conv2 = ops.Conv2d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = comfy.ops.Conv2d(in_channels,
self.conv_shortcut = ops.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = comfy.ops.Conv2d(in_channels,
self.nin_shortcut = ops.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
Expand Down Expand Up @@ -245,22 +246,22 @@ def __init__(self, in_channels):
self.in_channels = in_channels

self.norm = Normalize(in_channels)
self.q = comfy.ops.Conv2d(in_channels,
self.q = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = comfy.ops.Conv2d(in_channels,
self.k = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = comfy.ops.Conv2d(in_channels,
self.v = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = comfy.ops.Conv2d(in_channels,
self.proj_out = ops.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
Expand Down Expand Up @@ -312,14 +313,14 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList([
comfy.ops.Linear(self.ch,
ops.Linear(self.ch,
self.temb_ch),
comfy.ops.Linear(self.temb_ch,
ops.Linear(self.temb_ch,
self.temb_ch),
])

# downsampling
self.conv_in = comfy.ops.Conv2d(in_channels,
self.conv_in = ops.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
Expand Down Expand Up @@ -388,7 +389,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,

# end
self.norm_out = Normalize(block_in)
self.conv_out = comfy.ops.Conv2d(block_in,
self.conv_out = ops.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
Expand Down Expand Up @@ -461,7 +462,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
self.in_channels = in_channels

# downsampling
self.conv_in = comfy.ops.Conv2d(in_channels,
self.conv_in = ops.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
Expand Down Expand Up @@ -506,7 +507,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,

# end
self.norm_out = Normalize(block_in)
self.conv_out = comfy.ops.Conv2d(block_in,
self.conv_out = ops.Conv2d(block_in,
2*z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
Expand Down Expand Up @@ -541,7 +542,7 @@ class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
conv_out_op=comfy.ops.Conv2d,
conv_out_op=ops.Conv2d,
resnet_op=ResnetBlock,
attn_op=AttnBlock,
**ignorekwargs):
Expand All @@ -565,7 +566,7 @@ def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
self.z_shape, np.prod(self.z_shape)))

# z to block_in
self.conv_in = comfy.ops.Conv2d(z_channels,
self.conv_in = ops.Conv2d(z_channels,
block_in,
kernel_size=3,
stride=1,
Expand Down
14 changes: 7 additions & 7 deletions comfy/ldm/modules/diffusionmodules/openaimodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
checkpoint,
avg_pool_nd,
zero_module,
normalization,
timestep_embedding,
AlphaBlender,
)
from ..attention import SpatialTransformer, SpatialVideoTransformer, default
from comfy.ldm.util import exists
import comfy.ops
ops = comfy.ops.disable_weight_init

class TimestepBlock(nn.Module):
"""
Expand Down Expand Up @@ -70,7 +70,7 @@ class Upsample(nn.Module):
upsampling occurs in the inner-two dimensions.
"""

def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
Expand Down Expand Up @@ -106,7 +106,7 @@ class Downsample(nn.Module):
downsampling occurs in the inner-two dimensions.
"""

def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=comfy.ops):
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1, dtype=None, device=None, operations=ops):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
Expand Down Expand Up @@ -159,7 +159,7 @@ def __init__(
skip_t_emb=False,
dtype=None,
device=None,
operations=comfy.ops
operations=ops
):
super().__init__()
self.channels = channels
Expand Down Expand Up @@ -284,7 +284,7 @@ def __init__(
down: bool = False,
dtype=None,
device=None,
operations=comfy.ops
operations=ops
):
super().__init__(
channels,
Expand Down Expand Up @@ -434,7 +434,7 @@ def __init__(
disable_temporal_crossattention=False,
max_ddpm_temb_period=10000,
device=None,
operations=comfy.ops,
operations=ops,
):
super().__init__()
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
Expand Down Expand Up @@ -581,7 +581,7 @@ def get_resblock(
up=False,
dtype=None,
device=None,
operations=comfy.ops
operations=ops
):
if self.use_temporal_resblocks:
return VideoResBlock(
Expand Down
Loading

0 comments on commit 77755ab

Please sign in to comment.