Skip to content

Commit

Permalink
Update attention.py
Browse files Browse the repository at this point in the history
  • Loading branch information
neonsecret authored Sep 2, 2022
1 parent 1857272 commit 47f8784
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions ldm/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,23 +174,27 @@ def forward(self, x, context=None, mask=None):
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
del context, x

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale # (8, 4096, 40)
del q, k

if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
del mask

# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
# attention, what we cannot get enough of, by halves
sim[4:] = sim[4:].softmax(dim=-1)
sim[:4] = sim[:4].softmax(dim=-1)

out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)
sim = einsum('b i j, b j d -> b i d', sim, v)
sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
return self.to_out(sim)


class BasicTransformerBlock(nn.Module):
Expand Down Expand Up @@ -258,4 +262,4 @@ def forward(self, x, context=None):
x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = self.proj_out(x)
return x + x_in
return x + x_in

0 comments on commit 47f8784

Please sign in to comment.