
[tx] Add cuDNN flash attention #879

Merged
pcmoritz merged 30 commits into NovaSky-AI:main from raulchen:flash-attention on Jan 24, 2026
Conversation

@raulchen (Contributor) commented Jan 15, 2026

Summary

  • Add dot_product_attention() in tx/layers/attention.py that
    uses cuDNN on GPU for both causal (prefill) and non-causal
    (decode) attention, with XLA fallback for CPU/TPU
  • Update Llama3 and Qwen3 models to use the shared attention
    function
  • Add GPU CI workflow via Anyscale

Test plan

  • GPU tests verify cuDNN output matches XLA mask-based
    attention
  • Tests cover: causal with padding, no padding, mixed lengths,
    GQA, decode (non-causal)
  • CPU CI excludes GPU tests (--ignore=tests/gpu)

Benchmark results #891 (comment)
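
For context, a minimal sketch of the dispatch the summary describes. This is not the code from `tx/layers/attention.py`; the reliance on `jax.nn.dot_product_attention` and the `query_seq_lengths`/`key_value_seq_lengths` arguments follow the commit messages below, while the wrapper signature itself is an assumption.

```python
import jax


def dot_product_attention(q, k, v, *, is_causal, query_seq_lengths, key_value_seq_lengths):
    """Sketch only: cuDNN flash attention on GPU, XLA mask-based attention elsewhere.

    q: [batch, q_len, num_heads, head_dim]; k/v: [batch, kv_len, num_kv_heads, head_dim]
    (GQA is handled by jax.nn.dot_product_attention when num_kv_heads divides num_heads).
    """
    # cuDNN flash attention keeps memory O(seq); the XLA path materializes the mask.
    impl = "cudnn" if jax.default_backend() == "gpu" else "xla"
    return jax.nn.dot_product_attention(
        q, k, v,
        is_causal=is_causal,                        # True for training/prefill, False for decode
        query_seq_lengths=query_seq_lengths,        # valid (non-padded) query tokens per row
        key_value_seq_lengths=key_value_seq_lengths,
        implementation=impl,
    )
```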

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces an excellent optimization by adding support for cuDNN flash attention, which will significantly reduce memory usage during training and prefill on GPUs. The implementation is robust, automatically handling both left- and right-padded sequences by cleverly converting them to the right-padded format required by cuDNN. The fallback to a standard mask-based attention for the decode phase and non-GPU backends is a sound design choice. The new functionality is accompanied by a comprehensive set of unit tests for both CPU and GPU, ensuring numerical correctness across different scenarios. The addition of a dedicated GPU CI workflow is also a great step towards maintaining the stability of this feature. I have one suggestion to further improve the test coverage for an edge case.

@pcmoritz added the tx label Jan 15, 2026
@raulchen (Author)

/gemini review

@gemini-code-assist (bot) left a comment

Code Review

The pull request introduces cuDNN flash attention support for JAX models, significantly reducing memory usage for causal attention during training and prefill. It includes a new shared attention module that intelligently selects between cuDNN flash attention and a mask-based fallback, handling left-padded sequences by internally converting them to a right-padded format. New GPU CI workflows and comprehensive tests for the new attention mechanism and sequence shifting utility have also been added. The changes are well-structured and integrate smoothly into the existing model architectures.

raulchen and others added 14 commits January 20, 2026 18:54
- Use seq_lengths instead of attention_mask for attention computation
- On GPU: use cuDNN flash attention with query_seq_lengths/key_value_seq_lengths
- On CPU/TPU: fall back to mask-based attention (construct mask from seq_lengths)
- cuDNN flash attention provides O(seq) memory vs O(seq²) for standard attention

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

- Extract shared attention logic to tx/models/attention.py
- Use cuDNN flash attention only for right-padded sequences on GPU
- Fall back to mask-based attention for left-padded (generation) or CPU/TPU
- Fixes generation bug where cuDNN received wrong valid positions

Shift left-padded sequences to right-padded before applying cuDNN flash
attention, then shift the output back. This enables O(S²) -> O(S) memory
savings for inference prefill while keeping mask-based attention for
decode (where flash attention provides minimal benefit).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

- Use argmax to find the first valid token position (0 for right-padded, >0 for left-padded)
- Always apply the shift (a no-op when shift=0), avoiding dual-branch compilation
- Document that attention_mask must have at least one valid token per batch row

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
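
As an aside, a rough illustration of the shifting approach the two commit messages above describe (this helper was later removed once inference switched to right padding; the name `_shift_sequences` comes from the commit messages, while the shapes and return values here are assumptions):

```python
import jax
import jax.numpy as jnp


def _shift_sequences(x, attention_mask):
    """Roll each left-padded row so its valid tokens start at index 0.

    x:              [batch, seq, ...] tensor to shift (e.g. hidden states or q/k/v)
    attention_mask: [batch, seq], 1 for valid tokens, 0 for padding; every row must
                    contain at least one valid token.
    Returns the shifted tensor plus the per-row shift so the caller can roll the
    attention output back afterwards.
    """
    # argmax returns the first index of the maximum, i.e. the first valid token:
    # 0 for right-padded rows, > 0 for left-padded rows.
    shift = jnp.argmax(attention_mask, axis=1)                       # [batch]
    # The roll is a no-op when shift == 0, so both paddings share one compiled branch.
    shifted = jax.vmap(lambda row, s: jnp.roll(row, -s, axis=0))(x, shift)
    return shifted, shift
```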
- Add CPU tests for _shift_sequences and basic attention (tests/models/)
- Add GPU tests for cuDNN vs mask-based numerical equivalence (tests/gpu/)
- Add gpu_skyrl_tx.yaml workflow using Anyscale for GPU testing
- Update cpu_skyrl_tx.yaml to exclude tests/gpu/

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
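
A hypothetical, trimmed-down version of the kind of equivalence check these GPU tests perform (the real tests live under `tests/gpu/` and their exact structure differs):

```python
import jax
import jax.numpy as jnp
import pytest


@pytest.mark.skipif(jax.default_backend() != "gpu", reason="requires a CUDA GPU with cuDNN")
def test_cudnn_matches_xla_mask_based_attention():
    batch, seq, heads, head_dim = 2, 128, 8, 64
    q = jax.random.normal(jax.random.key(0), (batch, seq, heads, head_dim), dtype=jnp.bfloat16)
    k = jax.random.normal(jax.random.key(1), (batch, seq, heads, head_dim), dtype=jnp.bfloat16)
    v = jax.random.normal(jax.random.key(2), (batch, seq, heads, head_dim), dtype=jnp.bfloat16)
    seq_lengths = jnp.array([128, 100])  # one full row, one right-padded row

    def attn(impl):
        return jax.nn.dot_product_attention(
            q, k, v, is_causal=True,
            query_seq_lengths=seq_lengths, key_value_seq_lengths=seq_lengths,
            implementation=impl,
        )

    cudnn_out, xla_out = attn("cudnn"), attn("xla")
    # Outputs at padded positions are unspecified; compare valid tokens only.
    for b, n in enumerate(seq_lengths.tolist()):
        assert jnp.allclose(cudnn_out[b, :n], xla_out[b, :n], atol=2e-2, rtol=2e-2)
```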
@raulchen force-pushed the flash-attention branch 2 times, most recently from 91b1fc0 to c8e7095 on January 21, 2026 02:54
pcmoritz added a commit that referenced this pull request Jan 23, 2026

This is in preparation for merging
#879. Updated version of #918, rebased
on top of #919.

The reason we do this: while it is slightly less natural for the sampling
prefill to be left aligned (right padded), it makes things more uniform,
since that is the same alignment used during training, and as far as we are
aware the JAX cuDNN flash attention doesn't currently support left padding.
It is a small change and actually a little simpler, so there is little
downside. As a bonus, the work on that PR uncovered a bug in the
prompt logprobs.
raulchen and others added 3 commits January 23, 2026 08:13
Resolve conflicts in llama3.py and qwen3.py by keeping both
dot_product_attention and LogitsProcessorMixin imports.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Main branch updated inference to use right padding, so the shifting
logic to convert left-padded to right-padded format is no longer needed.

- Remove _shift_sequences() function from attention.py
- Simplify dot_product_attention() to directly use cuDNN
- Remove CPU tests for shift_sequences
- Update GPU tests to only test right-padded sequences

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Review comment on the Anyscale GPU workflow config:

    ray_version: "2.51.1"
    compute_config: l4_ci
    working_dir: .
    env_vars:
Collaborator

I'm pretty sure this is not needed any more on modern Ray versions :)

Contributor (Author)

Yeah, copied from other files. Removed all of them.

pcmoritz added a commit that referenced this pull request Jan 24, 2026
In #927, we left aligned the
prompts for sampling prefill in preparation for
#879. Here we shift them after
the prefill so they will again be right aligned for the decoding. This
will allow us to use the cudnn attention
#927 for both prefill and
decoding.

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
pcmoritz added a commit that referenced this pull request Jan 24, 2026
This PR writes the new decoded token into the kv cache in such a way
that the whole sequence is left aligned. This is needed so that the
cuDNN attention #879 truly works
without an attention mask.
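
A sketch of the idea in that commit, assuming a [batch, max_seq, num_kv_heads, head_dim] cache layout; the helper and argument names below are illustrative, not taken from the repo:

```python
import jax.numpy as jnp


def append_to_kv_cache(k_cache, v_cache, new_k, new_v, seq_lengths):
    """Write the freshly decoded token at index seq_lengths[b] of each row so the
    cached sequence stays left aligned (valid tokens occupy [0, seq_len)).

    k_cache, v_cache: [batch, max_seq, num_kv_heads, head_dim]
    new_k, new_v:     [batch, num_kv_heads, head_dim] for the current decode step
    seq_lengths:      [batch] number of tokens already in the cache
    """
    rows = jnp.arange(k_cache.shape[0])
    k_cache = k_cache.at[rows, seq_lengths].set(new_k)
    v_cache = v_cache.at[rows, seq_lengths].set(new_v)
    # With a left-aligned cache, decode can pass key_value_seq_lengths = seq_lengths + 1
    # to the cuDNN kernel instead of building an explicit attention mask.
    return k_cache, v_cache, seq_lengths + 1
```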
@pcmoritz (Collaborator)

/gemini review

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a significant performance improvement by adding support for cuDNN's flash attention on GPUs. The new dot_product_attention utility is well-designed, providing a clean abstraction with a fallback for other backends. The models are correctly updated to use this new utility, and the changes are accompanied by a comprehensive suite of GPU-specific tests that validate the numerical equivalence with the existing XLA implementation.

My review focuses on the new test code, where I've suggested a couple of improvements for maintainability and adherence to modern JAX practices. Overall, this is a solid contribution.

@gemini-code-assist (bot) commented on lines +21 to +23 (medium):

    q = jax.random.normal(jax.random.key(0), (batch, seq_len, num_heads, head_dim), dtype=dtype)
    k = jax.random.normal(jax.random.key(1), (batch, seq_len, num_kv_heads, head_dim), dtype=dtype)
    v = jax.random.normal(jax.random.key(2), (batch, seq_len, num_kv_heads, head_dim), dtype=dtype)

jax.random.key is an alias for jax.random.PRNGKey and its usage is discouraged in new code. Please use jax.random.PRNGKey for clarity and future compatibility.

Suggested change:

    q = jax.random.normal(jax.random.PRNGKey(0), (batch, seq_len, num_heads, head_dim), dtype=dtype)
    k = jax.random.normal(jax.random.PRNGKey(1), (batch, seq_len, num_kv_heads, head_dim), dtype=dtype)
    v = jax.random.normal(jax.random.PRNGKey(2), (batch, seq_len, num_kv_heads, head_dim), dtype=dtype)

@gemini-code-assist (bot) commented on lines +95 to +97 (medium):

    q = jax.random.normal(jax.random.key(0), (batch, 1, num_heads, head_dim), dtype=jnp.bfloat16)
    k = jax.random.normal(jax.random.key(1), (batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)
    v = jax.random.normal(jax.random.key(2), (batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)

jax.random.key is an alias for jax.random.PRNGKey and its usage is discouraged in new code. Please use jax.random.PRNGKey for clarity and future compatibility.

Additionally, this block duplicates tensor creation logic from make_qkv and test_gqa_decode. Consider refactoring make_qkv to accept separate q_len and kv_len to centralize this logic and improve maintainability.

Suggested change:

    q = jax.random.normal(jax.random.PRNGKey(0), (batch, 1, num_heads, head_dim), dtype=jnp.bfloat16)
    k = jax.random.normal(jax.random.PRNGKey(1), (batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)
    v = jax.random.normal(jax.random.PRNGKey(2), (batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)

@gemini-code-assist (bot) commented on lines +111 to +113 (medium):

    q = jax.random.normal(jax.random.key(0), (batch, 1, num_heads, head_dim), dtype=jnp.bfloat16)
    k = jax.random.normal(jax.random.key(1), (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)
    v = jax.random.normal(jax.random.key(2), (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)

jax.random.key is an alias for jax.random.PRNGKey and its usage is discouraged in new code. Please use jax.random.PRNGKey for clarity and future compatibility.

As mentioned in test_decode, this is another instance of duplicated tensor creation logic. Refactoring make_qkv would benefit this test as well.

Suggested change:

    q = jax.random.normal(jax.random.PRNGKey(0), (batch, 1, num_heads, head_dim), dtype=jnp.bfloat16)
    k = jax.random.normal(jax.random.PRNGKey(1), (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)
    v = jax.random.normal(jax.random.PRNGKey(2), (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)

@pcmoritz merged commit 4feedf0 into NovaSky-AI:main on Jan 24, 2026
4 of 7 checks passed
tanmaysachan pushed commits to tanmaysachan/SkyRL referencing this pull request on Jan 25, 2026 (commit messages identical to those above).