refactor custom_vjp use for type agreement
When defining a forward rule for `jax.custom_vjp(primal_fn)`, if `primal_fn`
has output type `T` then we need the forward rule to have output type `(T, R)`
for some `R`. That is, we need the first output of the forward rule to look
like the full output of `primal_fn`. (Here the `R` values represent the
'residuals' computed on the forward pass to save for use on the backward pass.)
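
For concreteness, here is a minimal, self-contained sketch of that contract; `f`, `_f`, `f_fwd`, and `f_bwd` are hypothetical names used for illustration, not code from this repo:

import jax
import jax.numpy as jnp

# Hypothetical helper mirroring the pattern used in this PR: it returns the
# primal output together with a value worth caching for the backward pass.
def _f(x):
    out = jnp.sin(x)
    return out, (jnp.cos(x),)       # (output, residuals)

# The custom_vjp-decorated primal has output type T (just `out`)...
@jax.custom_vjp
def f(x):
    out, _ = _f(x)
    return out

# ...so its forward rule must have output type (T, R): the first element
# matches f's output exactly, the second carries the residuals R.
def f_fwd(x):
    out, (cos_x,) = _f(x)
    return out, (cos_x,)

def f_bwd(residuals, g):
    cos_x, = residuals
    return (cos_x * g,)             # one cotangent per primal input

f.defvjp(f_fwd, f_bwd)

print(jax.grad(f)(0.3))             # cos(0.3), as expected

Before this change, the decorated attention functions below returned `(out, (row_sum, ...))` while their forward rules' first output was only `out`; that is the mismatch this PR removes.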

This PR fixes a disagreement between `custom_vjp`-decorated functions and their
corresponding forward rules.

The disagreement caused some interesting behavior! Discussed in #7.

Separately, I'm going to try to get JAX to raise a better error message in this
case; the error message was some really confusing JAX-internals thing.
mattjj committed Sep 29, 2022
1 parent e5efb90 commit 78ce0a9
Showing 3 changed files with 23 additions and 11 deletions.
10 changes: 7 additions & 3 deletions flash_attention_jax/causal_flash_attention.py
@@ -69,8 +69,7 @@ def chunk_scanner(carries, _):
 
     return out, row_sum, row_max
 
-@custom_vjp
-def causal_flash_attention(q, k, v):
+def _causal_flash_attention(q, k, v):
     q_len, dim, k_len, v_dim = *q.shape, *v.shape
 
     q_range = jnp.arange(q_len).reshape(q_len, 1) + (k_len - q_len)
@@ -92,9 +91,14 @@ def chunk_scanner(chunk_idx, _):
 
     return out, (row_sum, row_max)
 
+@custom_vjp
+def causal_flash_attention(q, k, v):
+    out, _ = _causal_flash_attention(q, k, v)
+    return out
+
 @jit
 def flash_attention_forward(q, k, v):
-    out, (row_sum, row_max) = causal_flash_attention(q, k, v)
+    out, (row_sum, row_max) = _causal_flash_attention(q, k, v)
     return out, (q, k, v, out, row_sum, row_max)
 
 def _query_chunk_flash_attention_backward(query_range_chunk, key_range, q, k, v, o, do, l, m):
10 changes: 7 additions & 3 deletions flash_attention_jax/cosine_sim_flash_attention.py
@@ -63,8 +63,7 @@ def cosine_sim_flash_attention(q, k, v, key_mask):
     q, k = map(l2norm, (q, k))
     return cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)
 
-@custom_vjp
-def cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask):
+def _cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask):
     q_len, dim, v_dim = *q.shape, v.shape[-1]
 
     def chunk_scanner(chunk_idx, _):
@@ -81,9 +80,14 @@ def chunk_scanner(chunk_idx, _):
 
     return out, (row_sum,)
 
+@custom_vjp
+def cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask):
+    out, _ = _cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)
+    return out
+
 @jit
 def flash_attention_forward(q, k, v, key_mask):
-    out, (row_sum,) = cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)
+    out, (row_sum,) = _cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)
     return out, (q, k, v, key_mask, out, row_sum)
 
 def _query_chunk_flash_attention_backward(q, k, v, key_mask,o, do, l):
14 changes: 9 additions & 5 deletions flash_attention_jax/flash_attention.py
@@ -68,9 +68,7 @@ def chunk_scanner(carries, _):
 
     return out, row_sum, row_max
 
-@custom_vjp
-@jit
-def flash_attention(q, k, v, key_mask):
+def _flash_attention(q, k, v, key_mask):
     batch, heads, q_len, dim, v_dim = *q.shape, v.shape[-1]
 
     def chunk_scanner(chunk_idx, _):
@@ -91,12 +89,18 @@ def chunk_scanner(chunk_idx, _):
 
     return out, (row_sum, row_max)
 
+@custom_vjp
+@jit
+def flash_attention(q, k, v, key_mask):
+    out, _ = _flash_attention(q, k, v, key_mask)
+    return out
+
 @jit
 def flash_attention_forward(q, k, v, key_mask):
-    out, (row_sum, row_max) = flash_attention(q, k, v, key_mask)
+    out, (row_sum, row_max) = _flash_attention(q, k, v, key_mask)
     return out, (q, k, v, key_mask, out, row_sum, row_max)
 
-def _query_chunk_flash_attention_backward(q, k, v, key_mask,o, do, l, m):
+def _query_chunk_flash_attention_backward(q, k, v, key_mask, o, do, l, m):
     q_len, batch, heads, dim, k_len, v_dim = *q.shape, v.shape[0], v.shape[-1]
 
     scale = 1 / jnp.sqrt(dim)
