From 78ce0a9d0f49039d28b96f64f424e2dc895eeb72 Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Thu, 29 Sep 2022 15:02:23 -0700
Subject: [PATCH] refactor custom_vjp use for type agreement

When defining a forward rule for `jax.custom_vjp(primal_fn)`, if `primal_fn`
has output type `T` then we need the forward rule to have output type `(T, R)`
for some `R`. That is, the first output of the forward rule must look like the
full output of `primal_fn`. (Here the `R` values represent the 'residuals'
computed on the forward pass and saved for use on the backward pass.)

This PR fixes a disagreement between `custom_vjp`-decorated functions and
their corresponding forward rules. The disagreement caused some interesting
behavior! Discussed on
https://github.com/lucidrains/flash-attention-jax/issues/7

Separately, I'm going to try to get JAX to raise a better error message in
this case; the error message is currently some really confusing JAX-internals
thing.
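For concreteness, here is a minimal sketch of that contract (illustrative
only, not code from this repo; `f`, `f_fwd`, and `f_bwd` are hypothetical
names):

    import jax
    import jax.numpy as jnp

    @jax.custom_vjp
    def f(x):
        return jnp.sin(x)  # primal output, type T

    def f_fwd(x):
        # First output must match f's output exactly; second holds the
        # residuals R saved for the backward pass.
        y = jnp.sin(x)
        return y, (x,)

    def f_bwd(residuals, g):
        # Takes the residuals R and the output cotangent, and returns one
        # cotangent per primal input.
        (x,) = residuals
        return (g * jnp.cos(x),)

    f.defvjp(f_fwd, f_bwd)

If `f` instead returned `(y, some_extra_stats)` while `f_fwd`'s first output
were just `y`, the types would disagree in exactly the way this patch fixes.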
---
 flash_attention_jax/causal_flash_attention.py     | 10 +++++++---
 flash_attention_jax/cosine_sim_flash_attention.py | 10 +++++++---
 flash_attention_jax/flash_attention.py            | 14 +++++++++-----
 3 files changed, 23 insertions(+), 11 deletions(-)

diff --git a/flash_attention_jax/causal_flash_attention.py b/flash_attention_jax/causal_flash_attention.py
index f998b94..61cc17f 100644
--- a/flash_attention_jax/causal_flash_attention.py
+++ b/flash_attention_jax/causal_flash_attention.py
@@ -69,8 +69,7 @@ def chunk_scanner(carries, _):

     return out, row_sum, row_max

-@custom_vjp
-def causal_flash_attention(q, k, v):
+def _causal_flash_attention(q, k, v):
     q_len, dim, k_len, v_dim = *q.shape, *v.shape

     q_range = jnp.arange(q_len).reshape(q_len, 1) + (k_len - q_len)
@@ -92,9 +91,14 @@ def chunk_scanner(chunk_idx, _):

     return out, (row_sum, row_max)

+@custom_vjp
+def causal_flash_attention(q, k, v):
+    out, _ = _causal_flash_attention(q, k, v)
+    return out
+
 @jit
 def flash_attention_forward(q, k, v):
-    out, (row_sum, row_max) = causal_flash_attention(q, k, v)
+    out, (row_sum, row_max) = _causal_flash_attention(q, k, v)
     return out, (q, k, v, out, row_sum, row_max)

 def _query_chunk_flash_attention_backward(query_range_chunk, key_range, q, k, v, o, do, l, m):
diff --git a/flash_attention_jax/cosine_sim_flash_attention.py b/flash_attention_jax/cosine_sim_flash_attention.py
index c1cb6e8..6e3c7f8 100644
--- a/flash_attention_jax/cosine_sim_flash_attention.py
+++ b/flash_attention_jax/cosine_sim_flash_attention.py
@@ -63,8 +63,7 @@ def cosine_sim_flash_attention(q, k, v, key_mask):
     q, k = map(l2norm, (q, k))
     return cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)

-@custom_vjp
-def cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask):
+def _cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask):
     q_len, dim, v_dim = *q.shape, v.shape[-1]

     def chunk_scanner(chunk_idx, _):
@@ -81,9 +80,14 @@ def chunk_scanner(chunk_idx, _):

     return out, (row_sum,)

+@custom_vjp
+def cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask):
+    out, _ = _cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)
+    return out
+
 @jit
 def flash_attention_forward(q, k, v, key_mask):
-    out, (row_sum,) = cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)
+    out, (row_sum,) = _cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)
     return out, (q, k, v, key_mask, out, row_sum)

 def _query_chunk_flash_attention_backward(q, k, v, key_mask,o, do, l):
diff --git a/flash_attention_jax/flash_attention.py b/flash_attention_jax/flash_attention.py
index e36b284..f8f92e5 100644
--- a/flash_attention_jax/flash_attention.py
+++ b/flash_attention_jax/flash_attention.py
@@ -68,9 +68,7 @@ def chunk_scanner(carries, _):

     return out, row_sum, row_max

-@custom_vjp
-@jit
-def flash_attention(q, k, v, key_mask):
+def _flash_attention(q, k, v, key_mask):
     batch, heads, q_len, dim, v_dim = *q.shape, v.shape[-1]

     def chunk_scanner(chunk_idx, _):
@@ -91,12 +89,18 @@ def chunk_scanner(chunk_idx, _):

     return out, (row_sum, row_max)

+@custom_vjp
+@jit
+def flash_attention(q, k, v, key_mask):
+    out, _ = _flash_attention(q, k, v, key_mask)
+    return out
+
 @jit
 def flash_attention_forward(q, k, v, key_mask):
-    out, (row_sum, row_max) = flash_attention(q, k, v, key_mask)
+    out, (row_sum, row_max) = _flash_attention(q, k, v, key_mask)
     return out, (q, k, v, key_mask, out, row_sum, row_max)

-def _query_chunk_flash_attention_backward(q, k, v, key_mask,o, do, l, m):
+def _query_chunk_flash_attention_backward(q, k, v, key_mask, o, do, l, m):
     q_len, batch, heads, dim, k_len, v_dim = *q.shape, v.shape[0], v.shape[-1]

     scale = 1 / jnp.sqrt(dim)