diff --git a/flash_attention_jax/causal_flash_attention.py b/flash_attention_jax/causal_flash_attention.py
index f998b94..61cc17f 100644
--- a/flash_attention_jax/causal_flash_attention.py
+++ b/flash_attention_jax/causal_flash_attention.py
@@ -69,8 +69,7 @@ def chunk_scanner(carries, _):
 
     return out, row_sum, row_max
 
-@custom_vjp
-def causal_flash_attention(q, k, v):
+def _causal_flash_attention(q, k, v):
     q_len, dim, k_len, v_dim = *q.shape, *v.shape
 
     q_range = jnp.arange(q_len).reshape(q_len, 1) + (k_len - q_len)
@@ -92,9 +91,14 @@ def chunk_scanner(chunk_idx, _):
 
     return out, (row_sum, row_max)
 
+@custom_vjp
+def causal_flash_attention(q, k, v):
+    out, _ = _causal_flash_attention(q, k, v)
+    return out
+
 @jit
 def flash_attention_forward(q, k, v):
-    out, (row_sum, row_max) = causal_flash_attention(q, k, v)
+    out, (row_sum, row_max) = _causal_flash_attention(q, k, v)
     return out, (q, k, v, out, row_sum, row_max)
 
 def _query_chunk_flash_attention_backward(query_range_chunk, key_range, q, k, v, o, do, l, m):
diff --git a/flash_attention_jax/cosine_sim_flash_attention.py b/flash_attention_jax/cosine_sim_flash_attention.py
index c1cb6e8..6e3c7f8 100644
--- a/flash_attention_jax/cosine_sim_flash_attention.py
+++ b/flash_attention_jax/cosine_sim_flash_attention.py
@@ -63,8 +63,7 @@ def cosine_sim_flash_attention(q, k, v, key_mask):
     q, k = map(l2norm, (q, k))
     return cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)
 
-@custom_vjp
-def cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask):
+def _cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask):
     q_len, dim, v_dim = *q.shape, v.shape[-1]
 
     def chunk_scanner(chunk_idx, _):
@@ -81,9 +80,14 @@ def chunk_scanner(chunk_idx, _):
 
     return out, (row_sum,)
 
+@custom_vjp
+def cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask):
+    out, _ = _cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)
+    return out
+
 @jit
 def flash_attention_forward(q, k, v, key_mask):
-    out, (row_sum,) = cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)
+    out, (row_sum,) = _cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)
     return out, (q, k, v, key_mask, out, row_sum)
 
 def _query_chunk_flash_attention_backward(q, k, v, key_mask,o, do, l):
diff --git a/flash_attention_jax/flash_attention.py b/flash_attention_jax/flash_attention.py
index e36b284..f8f92e5 100644
--- a/flash_attention_jax/flash_attention.py
+++ b/flash_attention_jax/flash_attention.py
@@ -68,9 +68,7 @@ def chunk_scanner(carries, _):
 
     return out, row_sum, row_max
 
-@custom_vjp
-@jit
-def flash_attention(q, k, v, key_mask):
+def _flash_attention(q, k, v, key_mask):
     batch, heads, q_len, dim, v_dim = *q.shape, v.shape[-1]
 
     def chunk_scanner(chunk_idx, _):
@@ -91,12 +89,18 @@ def chunk_scanner(chunk_idx, _):
 
     return out, (row_sum, row_max)
 
+@custom_vjp
+@jit
+def flash_attention(q, k, v, key_mask):
+    out, _ = _flash_attention(q, k, v, key_mask)
+    return out
+
 @jit
 def flash_attention_forward(q, k, v, key_mask):
-    out, (row_sum, row_max) = flash_attention(q, k, v, key_mask)
+    out, (row_sum, row_max) = _flash_attention(q, k, v, key_mask)
     return out, (q, k, v, key_mask, out, row_sum, row_max)
 
-def _query_chunk_flash_attention_backward(q, k, v, key_mask,o, do, l, m):
+def _query_chunk_flash_attention_backward(q, k, v, key_mask, o, do, l, m):
     q_len, batch, heads, dim, k_len, v_dim = *q.shape, v.shape[0], v.shape[-1]
 
     scale = 1 / jnp.sqrt(dim)
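
The patch applies the same refactor to all three attention variants: the chunked implementation that returns the softmax statistics (row_sum, row_max) is renamed to a private helper, and the @custom_vjp-decorated public function now returns only the attention output, so the cotangent handed to the backward rule matches the primal output alone; the statistics still reach the backward pass through the residuals saved by flash_attention_forward. Below is a minimal sketch of that pattern, not part of the patch: it uses plain unchunked attention and hypothetical names (_toy_attention_impl, toy_attention, toy_attention_fwd, toy_attention_bwd) in place of the repository's chunked kernels.

import jax
import jax.numpy as jnp
from jax import custom_vjp

def _toy_attention_impl(q, k, v):
    # stand-in for the chunked flash-attention kernel: returns the output
    # together with the per-row softmax statistics needed by the backward pass
    sim = q @ k.T / jnp.sqrt(q.shape[-1])
    row_max = sim.max(axis = -1, keepdims = True)
    exp_sim = jnp.exp(sim - row_max)
    row_sum = exp_sim.sum(axis = -1, keepdims = True)
    out = (exp_sim / row_sum) @ v
    return out, (row_sum, row_max)

@custom_vjp
def toy_attention(q, k, v):
    # public function returns only the primal output, mirroring the patch
    out, _ = _toy_attention_impl(q, k, v)
    return out

def toy_attention_fwd(q, k, v):
    # forward rule re-runs the implementation and saves residuals for backward
    out, (row_sum, row_max) = _toy_attention_impl(q, k, v)
    return out, (q, k, v, out, row_sum, row_max)

def toy_attention_bwd(res, do):
    q, k, v, o, row_sum, row_max = res
    scale = 1 / jnp.sqrt(q.shape[-1])
    # recompute attention probabilities from the saved statistics
    sim = q @ k.T * scale
    p = jnp.exp(sim - row_max) / row_sum
    dv = p.T @ do
    dp = do @ v.T
    D = (do * o).sum(axis = -1, keepdims = True)
    ds = p * (dp - D) * scale
    dq = ds @ k
    dk = ds.T @ q
    return dq, dk, dv

toy_attention.defvjp(toy_attention_fwd, toy_attention_bwd)

if __name__ == '__main__':
    q = k = v = jnp.ones((4, 8))
    # jax.grad now routes through toy_attention_bwd via the custom VJP
    dq, dk, dv = jax.grad(lambda q, k, v: toy_attention(q, k, v).sum(), argnums = (0, 1, 2))(q, k, v)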