44from typing import List , Union
55
66import torch .nn as nn
7- from diffusers .models import AutoencoderKL , UNet2DModel
7+ from diffusers .models import AutoencoderKL , UNet2DConditionModel
88
99
1010def _conv_forward_asymmetric (self , input , weight , bias ):
@@ -25,12 +25,53 @@ def _conv_forward_asymmetric(self, input, weight, bias):
2525
2626
2727@contextmanager
28- def set_seamless (model : Union [UNet2DModel , AutoencoderKL ], seamless_axes : List [str ]):
28+ def set_seamless (model : Union [UNet2DConditionModel , AutoencoderKL ], seamless_axes : List [str ]):
2929 try :
3030 to_restore = []
3131
32- for m in model .modules ():
32+ for m_name , m in model .named_modules ():
33+ if isinstance (model , UNet2DConditionModel ):
34+ if ".attentions." in m_name :
35+ continue
36+
37+ if ".resnets." in m_name :
38+ if ".conv2" in m_name :
39+ continue
40+ if ".conv_shortcut" in m_name :
41+ continue
42+
43+ """
44+ if isinstance(model, UNet2DConditionModel):
45+ if False and ".upsamplers." in m_name:
46+ continue
47+
48+ if False and ".downsamplers." in m_name:
49+ continue
50+
51+ if True and ".resnets." in m_name:
52+ if True and ".conv1" in m_name:
53+ if False and "down_blocks" in m_name:
54+ continue
55+ if False and "mid_block" in m_name:
56+ continue
57+ if False and "up_blocks" in m_name:
58+ continue
59+
60+ if True and ".conv2" in m_name:
61+ continue
62+
63+ if True and ".conv_shortcut" in m_name:
64+ continue
65+
66+ if True and ".attentions." in m_name:
67+ continue
68+
69+ if False and m_name in ["conv_in", "conv_out"]:
70+ continue
71+ """
72+
3373 if isinstance (m , (nn .Conv2d , nn .ConvTranspose2d )):
74+ print (f"applied - { m_name } " )
3475 m .asymmetric_padding_mode = {}
3576 m .asymmetric_padding = {}
3677 m .asymmetric_padding_mode ["x" ] = "circular" if ("x" in seamless_axes ) else "constant"
0 commit comments