-
Notifications
You must be signed in to change notification settings - Fork 4
/
controller.py
156 lines (141 loc) · 7.07 KB
/
controller.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import numpy as np
import torch
import math
import xformers
class DummyController:
    """No-op controller: calling it hands back its first argument untouched."""

    def __init__(self):
        # Attention-layer count; overwritten by register_attention_disentangled_control.
        self.num_att_layers = 0

    def __call__(self, *args):
        passthrough = args[0]
        return passthrough
class GroupedCAController:
    """Cross-attention controller that sums one masked attention result per group.

    Each entry of ``mask_list`` is a 2-D spatial mask. It is paired by position
    with one (keys, values) set in :meth:`ca_forward_decom`. Query positions
    outside a group's mask get zero contribution from that group.
    """

    def __init__(self, mask_list=None):
        # mask_list: sequence of 2-D mask tensors, or None (decomposition off).
        self.mask_list = mask_list
        self.is_decom = mask_list is not None

    def mask_img_to_mask_vec(self, mask, length):
        """Resize a 2-D ``mask`` to ``length`` x ``length`` and flatten it.

        Uses ``torch.nn.functional.interpolate`` with its default (nearest)
        mode. Non-floating masks (e.g. bool) are cast to float first to avoid
        dtype-support gaps in ``interpolate``.

        Returns:
            1-D tensor with ``length * length`` elements.
        """
        if not torch.is_floating_point(mask):
            mask = mask.float()
        resized = torch.nn.functional.interpolate(
            mask.unsqueeze(0).unsqueeze(0), (length, length)
        ).squeeze()
        return resized.flatten()

    def ca_forward_decom(self, q, k_list, v_list, scale, place_in_unet):
        """Masked cross-attention summed over (mask, keys, values) groups.

        Args:
            q: queries, shape [Bh, N, d]. N is treated as a square
                sqrt(N) x sqrt(N) spatial grid.
            k_list: per-group keys, each [Bh, P, d], paired positionally with
                ``self.mask_list``.
            v_list: per-group values, each [Bh, P, d], paired the same way.
            scale: multiplier applied to the q.k logits.
            place_in_unet: unused here; kept for interface compatibility.

        Returns:
            Tensor [Bh, N, d]: the sum over groups of softmax(q k^T * scale) v.
            Rows are zeroed where that group's resized mask is 0.

        Raises:
            ValueError: if no masks are configured, or if the numbers of
                masks, keys and values differ (``zip`` would silently drop
                groups).
        """
        if self.mask_list is None or len(self.mask_list) == 0:
            raise ValueError("ca_forward_decom requires a non-empty mask_list")
        if not len(self.mask_list) == len(k_list) == len(v_list):
            raise ValueError(
                f"got {len(self.mask_list)} masks, {len(k_list)} keys and "
                f"{len(v_list)} values; counts must match"
            )

        side = math.isqrt(q.shape[1])
        out = None
        for mask, k, v in zip(self.mask_list, k_list, v_list):
            # [N] -> [1, N, 1] so the mask broadcasts over heads and key positions.
            # Move it to q's device so masked_fill works when masks live on CPU.
            mask_vec = self.mask_img_to_mask_vec(mask, side).to(q.device)
            mask_vec = mask_vec.unsqueeze(0).unsqueeze(-1)
            sim = torch.einsum("b i d, b j d -> b i j", q, k) * scale  # [Bh, N, P]
            attn = sim.softmax(dim=-1)
            attn = attn.masked_fill(mask_vec == 0, 0)
            masked_out = torch.einsum("b i j, b j d -> b i d", attn, v)  # [Bh, N, d]
            out = masked_out if out is None else out + masked_out
        return out

    def reshape_heads_to_batch_dim(self):
        """Return a function mapping [B, S, D] -> [B * heads, S, D // heads].

        ``self.num_heads`` is read when the returned function runs. It is not
        set by ``__init__``, so the caller must assign it.
        """
        def func(tensor):
            batch_size, seq_len, dim = tensor.shape
            head_size = self.num_heads
            tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
            return tensor.permute(0, 2, 1, 3).reshape(
                batch_size * head_size, seq_len, dim // head_size
            )
        return func

    def reshape_batch_dim_to_heads(self):
        """Return a function mapping [B * heads, S, d] -> [B, S, d * heads].

        This is the inverse of :meth:`reshape_heads_to_batch_dim`.
        ``self.num_heads`` must be assigned by the caller.
        """
        def func(tensor):
            batch_size, seq_len, dim = tensor.shape
            head_size = self.num_heads
            tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
            return tensor.permute(0, 2, 1, 3).reshape(
                batch_size // head_size, seq_len, dim * head_size
            )
        return func
def register_attention_disentangled_control(unet, controller):
    """Patch ``unet``'s cross-attention layers so they run through ``controller``.

    A submodule is patched when its class is named ``Attention`` and its
    ``to_k`` projection takes ``unet.ca_dim`` input features. Its ``forward``
    is replaced by the closure built in ``ca_forward``. Only top-level children
    whose names contain "down", "up" or "mid" are searched (recursively).

    If ``controller`` is None, a ``DummyController`` is used and patched layers
    compute plain attention with xformers. With any other controller,
    cross-attention expects a list/tuple of encoder hidden states and is
    delegated to ``controller.ca_forward_decom``. Self-attention is not
    supported in that mode and raises ``NotImplementedError``.

    Args:
        unet: module providing ``named_children()`` and a ``ca_dim`` attribute.
        controller: controller instance, or None for plain attention.

    Side effects:
        Sets ``controller.num_att_layers`` to the number of patched layers.
    """
    # ``import xformers`` alone may not load the ``ops`` submodule; import it
    # explicitly so ``xformers.ops.memory_efficient_attention`` resolves.
    import xformers.ops

    if controller is None:
        controller = DummyController()

    def ca_forward(self, place_in_unet):
        # Presumably a ModuleList of [projection, dropout] — TODO confirm.
        # Only the first entry (the projection) is applied.
        to_out = self.to_out
        if isinstance(to_out, torch.nn.ModuleList):
            to_out = to_out[0]

        def forward(x, encoder_hidden_states=None, attention_mask=None,
                    **cross_attention_kwargs):
            # NOTE: attention_mask and any extra keyword arguments are accepted
            # for call-site compatibility but are ignored.
            is_cross = encoder_hidden_states is not None
            q = self.head_to_batch_dim(self.to_q(x))

            if isinstance(controller, DummyController):
                # Plain attention: over the encoder states for cross-attention,
                # over x itself for self-attention.
                context = encoder_hidden_states if is_cross else x
                k = self.head_to_batch_dim(self.to_k(context))
                v = self.head_to_batch_dim(self.to_v(context))
                out = xformers.ops.memory_efficient_attention(
                    q, k, v, attn_bias=None, op=None, scale=self.scale
                )
            elif is_cross:
                # Decomposed cross-attention: one (k, v) pair per encoder state.
                if not isinstance(encoder_hidden_states, (list, tuple)):
                    raise TypeError(
                        "decomposed cross-attention expects a list of encoder "
                        f"hidden states, got {type(encoder_hidden_states).__name__}"
                    )
                k_list, v_list = [], []
                for context in encoder_hidden_states:
                    k_list.append(self.head_to_batch_dim(self.to_k(context)))
                    v_list.append(self.head_to_batch_dim(self.to_v(context)))
                out = controller.ca_forward_decom(
                    q, k_list, v_list, self.scale, place_in_unet
                )  # [Bh, N, d]
            else:
                # Self-attention decomposition is not implemented.
                raise NotImplementedError("decomposing SA!")

            return to_out(self.batch_to_head_dim(out))

        return forward

    def register_recr(net_, count, place_in_unet):
        # Patch matching Attention modules; otherwise recurse into children.
        if net_.__class__.__name__ == 'Attention' and net_.to_k.in_features == unet.ca_dim:
            net_.forward = ca_forward(net_, place_in_unet)
            return count + 1
        if hasattr(net_, 'children'):
            for child in net_.children():
                count = register_recr(child, count, place_in_unet)
        return count

    cross_att_count = 0
    for name, module in unet.named_children():
        if "down" in name:
            cross_att_count += register_recr(module, 0, "down")
        elif "up" in name:
            cross_att_count += register_recr(module, 0, "up")
        elif "mid" in name:
            cross_att_count += register_recr(module, 0, "mid")
    controller.num_att_layers = cross_att_count