Fix some memory issues in sub quad attention.
comfyanonymous committed Oct 30, 2023
1 parent 125b03e commit c837a17
Showing 1 changed file with 11 additions and 24 deletions.
35 changes: 11 additions & 24 deletions comfy/ldm/modules/attention.py
@@ -160,32 +160,19 @@ def attention_sub_quad(query, key, value, heads, mask=None):
 
     mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
 
-    chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD
-
     kv_chunk_size_min = None
+    kv_chunk_size = None
+    query_chunk_size = None
 
-    #not sure at all about the math here
-    #TODO: tweak this
-    if mem_free_total > 8192 * 1024 * 1024 * 1.3:
-        query_chunk_size_x = 1024 * 4
-    elif mem_free_total > 4096 * 1024 * 1024 * 1.3:
-        query_chunk_size_x = 1024 * 2
-    else:
-        query_chunk_size_x = 1024
-    kv_chunk_size_min_x = None
-    kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
-    if kv_chunk_size_x < 1024:
-        kv_chunk_size_x = None
-
-    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
-        # the big matmul fits into our memory limit; do everything in 1 chunk,
-        # i.e. send it down the unchunked fast-path
-        query_chunk_size = q_tokens
-        kv_chunk_size = k_tokens
-    else:
-        query_chunk_size = query_chunk_size_x
-        kv_chunk_size = kv_chunk_size_x
-        kv_chunk_size_min = kv_chunk_size_min_x
+    for x in [4096, 2048, 1024, 512, 256]:
+        count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
+        if count >= k_tokens:
+            kv_chunk_size = k_tokens
+            query_chunk_size = x
+            break
+
+    if query_chunk_size is None:
+        query_chunk_size = 512
 
     hidden_states = efficient_dot_product_attention(
         query,
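For reference, a minimal standalone sketch of the chunk-size heuristic the new code uses: walk down a list of candidate query chunk sizes and take the largest one for which the full set of key tokens still fits in free memory (with a 4x overhead factor on the raw score-matrix size), otherwise fall back to a 512-token query chunk. The helper name pick_sub_quad_chunk_sizes and the numbers in the demo call are illustrative, not part of the repository.

def pick_sub_quad_chunk_sizes(mem_free_total, batch_x_heads, bytes_per_token, k_tokens):
    # Try query chunk sizes from largest to smallest and stop at the first one
    # where every key token can be attended to without chunking the KV dimension.
    kv_chunk_size = None
    query_chunk_size = None

    for x in [4096, 2048, 1024, 512, 256]:
        # How many key tokens fit in free memory for a query chunk of x tokens,
        # assuming roughly 4x the raw size of the attention-score matrix.
        count = mem_free_total / (batch_x_heads * bytes_per_token * x * 4.0)
        if count >= k_tokens:
            kv_chunk_size = k_tokens      # all keys handled in a single chunk
            query_chunk_size = x
            break

    if query_chunk_size is None:
        # Nothing fit: use a small query chunk and let the attention kernel
        # pick its own KV chunking.
        query_chunk_size = 512

    return query_chunk_size, kv_chunk_size

# Hypothetical example: ~4 GiB free, batch*heads = 16, fp16 (2 bytes), 4096 key tokens.
print(pick_sub_quad_chunk_sizes(4 * 1024**3, 16, 2, 4096))  # -> (4096, 4096)

Compared with the removed per-VRAM-tier branches, this sizes the chunks directly from the measured free memory and only skips KV chunking when the whole score matrix for a query chunk actually fits.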
