cherry-pick @Any-Winter-4079's invoke-ai/InvokeAI#540. This is a collaboration incorporating a lot of people's contributions -- including, for example, @Doggettx and the original code from @neonsecret on which the Doggettx optimizations were based (see invoke-ai/InvokeAI#431, https://github.com/sd-webui/stable-diffusion-webui/pull/771#issuecomment-1239716055). Takes exactly the same amount of time to run 8 steps as the original CompVis code does (10.4 secs, ~1.25s/it). (#1177)
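For context on what is actually being optimized: the changes below compute cross-attention in slices along the query dimension, so the full (batch·heads) × tokens × tokens attention matrix is never materialized at once. A minimal standalone sketch of that idea, assuming q, k, v have already been rearranged to (batch·heads, tokens, dim_head) as in forward() below (the function name and default slice size are illustrative, not the repository's exact code):

import torch
from torch import einsum

def sliced_attention(q, k, v, scale, slice_size=1024):
    # Equivalent to softmax(q @ k.transpose(1, 2) * scale) @ v, computed one block
    # of query rows at a time: softmax is taken independently per query row, so
    # slicing changes peak memory but not the result.
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        end = min(q.shape[1], i + slice_size)
        s = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
        s = s.softmax(dim=-1)
        r[:, i:end] = einsum('b i j, b j d -> b i d', s, v)
        del s
    return r

The einsum_op_mps_v1/einsum_op_mps_v2/einsum_op_cuda methods in the diff differ mainly in how the slice size is chosen: from total system RAM on MPS, and from free CUDA memory on CUDA.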

Co-authored-by: Alex Birch <birch-san@users.noreply.github.com>
codedealer and Birch-san authored Sep 16, 2022
1 parent 4efe62b commit c465891
Showing 1 changed file with 72 additions and 35 deletions.
107 changes: 72 additions & 35 deletions ldm/modules/attention.py
@@ -1,4 +1,3 @@
import gc
from inspect import isfunction
import math
import torch
@@ -8,6 +7,8 @@

from ldm.modules.diffusionmodules.util import checkpoint

import psutil


def exists(val):
return val is not None
@@ -151,14 +152,13 @@ def forward(self, x):


class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., att_step=1):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)

self.scale = dim_head ** -0.5
self.heads = heads
self.att_step = att_step

self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
@@ -169,23 +169,50 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
nn.Dropout(dropout)
)

def forward(self, x, context=None, mask=None):
h = self.heads

q_in = self.to_q(x)
context = default(context, x)

k_in = self.to_k(context)
v_in = self.to_v(context)

del context, x

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in


r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)

if torch.cuda.is_available():
self.einsum_op = self.einsum_op_cuda
else:
self.mem_total = psutil.virtual_memory().total / (1024**3)
self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2

def einsum_op_compvis(self, q, k, v, r1):
s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1 = einsum('b i j, b j d -> b i d', s2, v)
del s2
return r1

def einsum_op_mps_v1(self, q, k, v, r1):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
r1 = self.einsum_op_compvis(q, k, v, r1)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
return r1

def einsum_op_mps_v2(self, q, k, v, r1):
if self.mem_total >= 8 and q.shape[1] <= 4096:
r1 = self.einsum_op_compvis(q, k, v, r1)
else:
slice_size = 1
for i in range(0, q.shape[0], slice_size):
end = min(q.shape[0], i + slice_size)
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
return r1

def einsum_op_cuda(self, q, k, v, r1):
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
@@ -200,30 +227,39 @@ def forward(self, x, context=None, mask=None):

if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')

slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
end = min(q.shape[1], i + slice_size)
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale

s2 = s1.softmax(dim=-1)
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
del s1

r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
return r1

del q, k, v
def forward(self, x, context=None, mask=None):
h = self.heads

q = self.to_q(x)
context = default(context, x)
del x
k = self.to_k(context)
v = self.to_v(context)
del context

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
r1 = self.einsum_op(q, k, v, r1)
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1

return self.to_out(r2)
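A note on the einsum_op_cuda path above: it sizes the slices from free GPU memory, rounding mem_required / mem_free_total up to the next power of two (and bailing out above 64 steps with the "Not enough memory" error). A small self-contained illustration of that arithmetic, with made-up byte counts rather than real measurements:

import math

def choose_steps(mem_required, mem_free_total):
    # Mirrors the heuristic in einsum_op_cuda: slice only if the attention
    # tensors would not fit, and round the ratio up to the next power of two.
    if mem_required <= mem_free_total:
        return 1
    return 2 ** math.ceil(math.log(mem_required / mem_free_total, 2))

gb = 1024 ** 3
print(choose_steps(12 * gb, 5 * gb))  # 4 -> the softmax is computed in 4 slices
print(choose_steps(3 * gb, 5 * gb))   # 1 -> everything fits, no slicing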


@@ -243,9 +279,10 @@ def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

def _forward(self, x, context=None):
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
x = x.contiguous() if x.device.type == 'mps' else x
x += self.attn1(self.norm1(x))
x += self.attn2(self.norm2(x), context=context)
x += self.ff(self.norm3(x))
return x


@@ -292,4 +329,4 @@ def forward(self, x, context=None):
x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
x = self.proj_out(x)
return x + x_in
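As a closing sanity check on the sliced attention introduced in this commit: slicing the query dimension should be numerically equivalent to the unsliced CompVis computation. A quick illustrative check with arbitrary shapes and slice size (not code from the repository):

import torch
from torch import einsum

b_h, n, d = 16, 1024, 40                        # illustrative (batch*heads, tokens, dim_head)
q, k, v = (torch.randn(b_h, n, d) for _ in range(3))
scale = d ** -0.5

# Unsliced reference, as in einsum_op_compvis.
full = einsum('b i j, b j d -> b i d',
              (einsum('b i d, b j d -> b i j', q, k) * scale).softmax(dim=-1), v)

# Same computation, 256 query rows at a time.
sliced = torch.zeros_like(full)
for i in range(0, n, 256):
    s = (einsum('b i d, b j d -> b i j', q[:, i:i + 256], k) * scale).softmax(dim=-1)
    sliced[:, i:i + 256] = einsum('b i j, b j d -> b i d', s, v)

print(torch.allclose(full, sliced, atol=1e-5))  # expected: True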
