Merge pull request #8 from mattjj/custom-vjp-fix
refactor custom_vjp use for type agreement
lucidrains authored Sep 29, 2022
2 parents e5efb90 + 78ce0a9 commit 25df3b9
Showing 3 changed files with 23 additions and 11 deletions.
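
Why the wrappers: `jax.custom_vjp` expects the decorated primal function and the registered forward rule to agree on the primal output, but before this PR the decorated functions returned `(out, (row_sum, row_max))` while the forward rules saved the row statistics as residuals and treated only `out` as the primal output. The fix keeps the residual-producing implementation in a private `_...` function and wraps it in a thin `@custom_vjp` entry point that returns just `out`. A minimal sketch of the pattern, using illustrative names (`f`, `_f_impl`, `f_fwd`, `f_bwd` are not from the repository):

```python
import jax
import jax.numpy as jnp
from jax import custom_vjp

def _f_impl(x):
    # Full implementation: returns the output plus whatever the
    # backward pass will need (here, just the output itself).
    out = jnp.tanh(x)
    return out, out

@custom_vjp
def f(x):
    # Thin public wrapper: returns only the primal output, so its
    # type agrees with the primal output produced by f_fwd below.
    out, _ = _f_impl(x)
    return out

def f_fwd(x):
    out, residual = _f_impl(x)
    return out, (residual,)          # (primal output, residuals)

def f_bwd(residuals, g):
    (t,) = residuals
    return (g * (1 - t ** 2),)       # one cotangent per primal input

f.defvjp(f_fwd, f_bwd)

# jax.grad(f) now routes through f_fwd / f_bwd:
print(jax.grad(f)(jnp.float32(0.5)))
```

In the diffs below, `_causal_flash_attention`, `_cosine_sim_flash_attention_after_l2norm`, and `_flash_attention` play the role of `_f_impl`, and each `flash_attention_forward` is the corresponding forward rule.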
flash_attention_jax/causal_flash_attention.py (10 changes: 7 additions & 3 deletions)
@@ -69,8 +69,7 @@ def chunk_scanner(carries, _):
 
     return out, row_sum, row_max
 
-@custom_vjp
-def causal_flash_attention(q, k, v):
+def _causal_flash_attention(q, k, v):
     q_len, dim, k_len, v_dim = *q.shape, *v.shape
 
     q_range = jnp.arange(q_len).reshape(q_len, 1) + (k_len - q_len)
@@ -92,9 +91,14 @@ def chunk_scanner(chunk_idx, _):
 
     return out, (row_sum, row_max)
 
+@custom_vjp
+def causal_flash_attention(q, k, v):
+    out, _ = _causal_flash_attention(q, k, v)
+    return out
+
 @jit
 def flash_attention_forward(q, k, v):
-    out, (row_sum, row_max) = causal_flash_attention(q, k, v)
+    out, (row_sum, row_max) = _causal_flash_attention(q, k, v)
     return out, (q, k, v, out, row_sum, row_max)
 
 def _query_chunk_flash_attention_backward(query_range_chunk, key_range, q, k, v, o, do, l, m):
flash_attention_jax/cosine_sim_flash_attention.py (10 changes: 7 additions & 3 deletions)
@@ -63,8 +63,7 @@ def cosine_sim_flash_attention(q, k, v, key_mask):
     q, k = map(l2norm, (q, k))
     return cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)
 
-@custom_vjp
-def cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask):
+def _cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask):
     q_len, dim, v_dim = *q.shape, v.shape[-1]
 
     def chunk_scanner(chunk_idx, _):
@@ -81,9 +80,14 @@ def chunk_scanner(chunk_idx, _):
 
     return out, (row_sum,)
 
+@custom_vjp
+def cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask):
+    out, _ = _cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)
+    return out
+
 @jit
 def flash_attention_forward(q, k, v, key_mask):
-    out, (row_sum,) = cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)
+    out, (row_sum,) = _cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)
     return out, (q, k, v, key_mask, out, row_sum)
 
 def _query_chunk_flash_attention_backward(q, k, v, key_mask,o, do, l):
flash_attention_jax/flash_attention.py (14 changes: 9 additions & 5 deletions)
@@ -68,9 +68,7 @@ def chunk_scanner(carries, _):
 
     return out, row_sum, row_max
 
-@custom_vjp
-@jit
-def flash_attention(q, k, v, key_mask):
+def _flash_attention(q, k, v, key_mask):
     batch, heads, q_len, dim, v_dim = *q.shape, v.shape[-1]
 
     def chunk_scanner(chunk_idx, _):
@@ -91,12 +89,18 @@ def chunk_scanner(chunk_idx, _):
 
     return out, (row_sum, row_max)
 
+@custom_vjp
+@jit
+def flash_attention(q, k, v, key_mask):
+    out, _ = _flash_attention(q, k, v, key_mask)
+    return out
+
 @jit
 def flash_attention_forward(q, k, v, key_mask):
-    out, (row_sum, row_max) = flash_attention(q, k, v, key_mask)
+    out, (row_sum, row_max) = _flash_attention(q, k, v, key_mask)
     return out, (q, k, v, key_mask, out, row_sum, row_max)
 
-def _query_chunk_flash_attention_backward(q, k, v, key_mask,o, do, l, m):
+def _query_chunk_flash_attention_backward(q, k, v, key_mask, o, do, l, m):
     q_len, batch, heads, dim, k_len, v_dim = *q.shape, v.shape[0], v.shape[-1]
 
     scale = 1 / jnp.sqrt(dim)
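
A hedged usage sketch after this change (not part of the commit): because the public `flash_attention` now returns only the attention output, it composes directly with `jax.grad` once the forward and backward rules are registered via `defvjp` elsewhere in the module. The import path and shapes below are assumptions for illustration:

```python
import jax
import jax.numpy as jnp
from flash_attention_jax import flash_attention  # assumed package export

# Assumed layout: q/k/v are (batch, heads, seq_len, dim) and the key
# mask is (batch, seq_len), inferred from the shape unpacking above.
q = jnp.zeros((1, 8, 1024, 64))
k = jnp.zeros((1, 8, 1024, 64))
v = jnp.zeros((1, 8, 1024, 64))
key_mask = jnp.ones((1, 1024), dtype=bool)

# flash_attention now returns just `out`, so it can be differentiated directly.
loss = lambda q, k, v: flash_attention(q, k, v, key_mask).sum()
dq, dk, dv = jax.grad(loss, argnums=(0, 1, 2))(q, k, v)
```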
