
A faster flash attention bwd implementation #177

Open. tonywu95 wants to merge 4 commits into base: main.
Conversation

@tonywu95 (Contributor) commented Jun 22, 2023

  • Decompose the bwd kernel into two kernels: one for dq and one for dk and dv.
  • Add extra parallelism over the sequence-length axis.
  • On a benchmark with causal=True, it is close to 6x faster than the previous implementation and roughly 3x faster than the XLA bwd pass. A plain-JAX sketch of the two-kernel split follows below.

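For readers skimming the PR, here is a minimal plain-JAX sketch of the split described above. It is not the Pallas kernels from this PR; it only illustrates which gradients each of the two kernels is responsible for, with the sequence-axis blocking indicated in comments:

```python
import jax
import jax.numpy as jnp

def attention(q, k, v):
    # q, k, v: (seq_len, head_dim); a dense reference, not the fused kernel.
    s = (q @ k.T) / jnp.sqrt(q.shape[-1])
    p = jax.nn.softmax(s, axis=-1)
    return p @ v

def backward_split(q, k, v, do):
    # "Kernel" 1 computes dq only; in the Pallas version each program owns a
    # block of query rows, giving parallelism over seq_len // block_q.
    dq = jax.grad(lambda q_: jnp.vdot(attention(q_, k, v), do))(q)
    # "Kernel" 2 computes dk and dv; each program owns a block of key/value
    # rows, giving parallelism over seq_len // block_k.
    dk, dv = jax.grad(
        lambda k_, v_: jnp.vdot(attention(q, k_, v_), do), argnums=(0, 1)
    )(k, v)
    return dq, dk, dv
```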
@sharadmv (Collaborator) left a comment


High-level comment: the current backward pass is a fully fused kernel that parallelizes over only batch * num_heads program instances.

For attention shapes with a small batch and few heads (as is common in language-model training), this kernel will underutilize the GPU.

However, there are applications where this fused kernel is faster than the two-kernel variant.

Could you add the two-kernel version as a separate backward-pass impl, so the user has the option of selecting the one they want?
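
For a concrete (hypothetical) illustration of the trade-off: with batch=1, num_heads=16, seq_len=4096, and block_q=128, the fused backward pass launches only 1 * 16 = 16 programs, fewer than the 108 SMs on an A100, while the split version adds a sequence axis of 4096 / 128 = 32 to the grid, for 16 * 32 = 512 programs.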

[10 resolved review threads on jax_triton/pallas/ops/attention.py; 9 marked outdated]
@tonywu95 requested a review from sharadmv on June 23, 2023
@sharadmv (Collaborator) left a comment


Could you also add tests into pallas_test.py?
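
A rough sketch of the kind of test that could go there, checking the split backward pass against a pure-JAX reference. The array layout, the `mha` call signature (including the `backward_pass_impl` keyword and a default `sm_scale` of 1.0), and the tolerances are assumptions on my part, not copied from pallas_test.py:

```python
# Hedged sketch only; the actual harness in pallas_test.py may differ.
import jax
import jax.numpy as jnp
from jax_triton.pallas.ops import attention

def reference_mha(q, k, v):
    # Pure-JAX reference assuming a (batch, seq_len, num_heads, head_dim)
    # layout and no extra softmax scaling.
    logits = jnp.einsum("bqhd,bkhd->bhqk", q, k).astype(jnp.float32)
    probs = jax.nn.softmax(logits, axis=-1).astype(q.dtype)
    return jnp.einsum("bhqk,bkhd->bqhd", probs, v)

def test_split_bwd_matches_reference():
    k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
    shape = (2, 1024, 2, 64)  # (batch, seq_len, num_heads, head_dim)
    q = jax.random.normal(k1, shape, jnp.float16)
    k = jax.random.normal(k2, shape, jnp.float16)
    v = jax.random.normal(k3, shape, jnp.float16)
    do = jax.random.normal(k4, shape, jnp.float16)

    def loss(f):
        # Scalar loss so jax.grad gives us dq, dk, dv directly.
        return lambda q, k, v: jnp.sum(f(q, k, v) * do)

    grads = jax.grad(
        loss(lambda q, k, v: attention.mha(q, k, v,
                                           backward_pass_impl="triton_split")),
        argnums=(0, 1, 2),
    )(q, k, v)
    grads_ref = jax.grad(loss(reference_mha), argnums=(0, 1, 2))(q, k, v)

    for g, g_ref in zip(grads, grads_ref):
        assert jnp.allclose(g, g_ref, atol=5e-2, rtol=5e-2)
```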

pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
pl.BlockSpec(lambda j, k, _: (j, 0, k, 0), (None, seq_len, None, head_dim)),
Collaborator:

Can you rename to be (i, j, _)? Same below?
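
Presumably this refers only to the lambda parameter names, e.g. writing the last spec above as `pl.BlockSpec(lambda i, j, _: (i, 0, j, 0), (None, seq_len, None, head_dim))`; that is one reading of the request, and the rename does not change behavior.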

upper_bound = jt.cdiv(seq_len, block_k)
dq = lax.fori_loop(0, upper_bound, inner_loop, dq)
pl.store(dq_ref, (pl.ds(start_q * block_q, block_q),
slice(None)), dq, eviction_policy="evict_last")
Collaborator:

I don't think we need an eviction policy here.
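
That is, presumably just the same store without the keyword (a sketch of the suggested change, not the final code):

```python
# The store from the snippet above, minus eviction_policy.
pl.store(dq_ref, (pl.ds(start_q * block_q, block_q),
                  slice(None)), dq)
```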

Comment on lines +394 to +396
slice(None)), dv.astype(dv_ref.dtype))
pl.store(dk_ref, (pl.ds(start_k * block_k, block_k),
slice(None)), dk.astype(dk_ref.dtype))
Collaborator:

Nit: indentation

@@ -346,6 +450,65 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
num_warps=num_warps,
num_stages=1,
input_output_aliases={8: 0})(q, k, v, out, do_scaled, l, m, delta, dq)
elif backward_pass_impl == "triton_split":
# We accumulate into dq so we need to initialize it to zeros.
Collaborator:

The comment is not accurate here.

out_shapes_q = jax.ShapeDtypeStruct(q.shape, jnp.float32)
Collaborator:

I suspect we don't need dq to be f32 anymore. Could you try q.dtype?
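
That is, presumably the suggested change is the following (whether the accumulation still needs f32 is the open question the reviewer raises):

```python
# Sketch of the suggestion: allocate dq with the input's dtype rather than
# forcing float32.
out_shapes_q = jax.ShapeDtypeStruct(q.shape, q.dtype)
```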

@abhinavgoel95 commented Aug 9, 2023

@sharadmv Can this PR be merged? We see a big performance improvement from it on NVIDIA A100 GPUs.
Thank you.

@sharadmv (Collaborator) commented Aug 9, 2023

I left some comments. @tonywu95 do you have time to address them?

@skye (Member) commented Jun 13, 2024

Hey @tonywu95, is it ok if we take over this PR and put you as a co-author? We'd love to get it in!
