[tx] Add cuDNN flash attention #879
Conversation
Code Review
This pull request introduces an excellent optimization by adding support for cuDNN flash attention, which will significantly reduce memory usage during training and prefill on GPUs. The implementation is robust, automatically handling both left- and right-padded sequences by cleverly converting them to the right-padded format required by cuDNN. The fallback to a standard mask-based attention for the decode phase and non-GPU backends is a sound design choice. The new functionality is accompanied by a comprehensive set of unit tests for both CPU and GPU, ensuring numerical correctness across different scenarios. The addition of a dedicated GPU CI workflow is also a great step towards maintaining the stability of this feature. I have one suggestion to further improve the test coverage for an edge case.
/gemini review
Code Review
The pull request introduces cuDNN flash attention support for JAX models, significantly reducing memory usage for causal attention during training and prefill. It includes a new shared attention module that intelligently selects between cuDNN flash attention and a mask-based fallback, handling left-padded sequences by internally converting them to a right-padded format. New GPU CI workflows and comprehensive tests for the new attention mechanism and sequence shifting utility have also been added. The changes are well-structured and integrate smoothly into the existing model architectures.
- Use seq_lengths instead of attention_mask for attention computation
- On GPU: use cuDNN flash attention with query_seq_lengths/key_value_seq_lengths
- On CPU/TPU: fall back to mask-based attention (construct mask from seq_lengths)
- cuDNN flash attention provides O(seq) memory vs O(seq²) for standard attention

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
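As a rough illustration of the dispatch described in this commit, the call site might look something like the sketch below. The shapes, seeds, and mask construction are illustrative rather than the PR's actual code, and it assumes a recent JAX where `jax.nn.dot_product_attention` accepts `query_seq_lengths`, `key_value_seq_lengths`, and `implementation`.

```python
import jax
import jax.numpy as jnp

# Illustrative shapes, not taken from the PR.
B, S, N, H = 2, 128, 8, 64
q = jax.random.normal(jax.random.key(0), (B, S, N, H), dtype=jnp.bfloat16)
k = jax.random.normal(jax.random.key(1), (B, S, N, H), dtype=jnp.bfloat16)
v = jax.random.normal(jax.random.key(2), (B, S, N, H), dtype=jnp.bfloat16)
seq_lengths = jnp.array([128, 96], dtype=jnp.int32)  # valid tokens per (right-padded) row

if jax.default_backend() == "gpu":
    # cuDNN flash attention: takes per-row valid lengths directly, O(seq) memory.
    out = jax.nn.dot_product_attention(
        q, k, v,
        is_causal=True,
        query_seq_lengths=seq_lengths,
        key_value_seq_lengths=seq_lengths,
        implementation="cudnn",
    )
else:
    # CPU/TPU fallback: build an explicit causal + padding mask from seq_lengths.
    pos = jnp.arange(S)
    valid = pos[None, :] < seq_lengths[:, None]                       # (B, S)
    mask = valid[:, None, None, :] & (pos[:, None] >= pos[None, :])   # (B, 1, S, S)
    out = jax.nn.dot_product_attention(q, k, v, mask=mask, implementation="xla")
```

The fallback still materializes a (B, 1, S, S) mask, which is where the O(seq²) memory of the standard path comes from.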
- Extract shared attention logic to tx/models/attention.py
- Use cuDNN flash attention only for right-padded sequences on GPU
- Fall back to mask-based attention for left-padded (generation) or CPU/TPU
- Fixes generation bug where cuDNN received wrong valid positions
Shift left-padded sequences to right-padded before applying cuDNN flash attention, then shift the output back. This enables O(S^2) -> O(S) memory savings for inference prefill while keeping mask-based attention for decode (where flash attention provides minimal benefit).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Use argmax to find first valid token position (0 for right-padded, >0 for left-padded)
- Always apply shift (no-op when shift=0), avoiding dual-branch compilation
- Document that attention_mask must have at least one valid token per batch

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
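A minimal sketch of the shifting described here could look like the following; the function names, signature, and the per-row `jnp.roll` are illustrative assumptions, not the PR's actual `_shift_sequences` implementation.

```python
import jax
import jax.numpy as jnp

def shift_sequences(x: jax.Array, attention_mask: jax.Array) -> tuple[jax.Array, jax.Array]:
    # First valid token per row via argmax: 0 for right-padded rows, > 0 for left-padded
    # ones. Assumes every row has at least one valid token, as documented above.
    shift = jnp.argmax(attention_mask, axis=1)  # (B,)
    # Always roll; rolling by 0 is a no-op, so both paddings share one compiled branch.
    shifted = jax.vmap(lambda row, s: jnp.roll(row, -s, axis=0))(x, shift)
    return shifted, shift

def unshift_sequences(x: jax.Array, shift: jax.Array) -> jax.Array:
    # Undo the shift on the attention output so it lines up with the original layout.
    return jax.vmap(lambda row, s: jnp.roll(row, s, axis=0))(x, shift)
```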
- Add CPU tests for _shift_sequences and basic attention (tests/models/)
- Add GPU tests for cuDNN vs mask-based numerical equivalence (tests/gpu/)
- Add gpu_skyrl_tx.yaml workflow using Anyscale for GPU testing
- Update cpu_skyrl_tx.yaml to exclude tests/gpu/

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This is in preparation for merging #879. Updated version of #918, rebased on top of #919. The reason we do this: while it is slightly less natural for sampling prefill to be left aligned (right padded), it makes things more uniform, since that is the same alignment used during training, and the JAX cuDNN flash attention doesn't currently support left padding as far as we are aware. It is a small change and actually a little simpler, so what's not to like? Plus, the work on this PR uncovered a bug in the prompt logprobs.
Resolve conflicts in llama3.py and qwen3.py by keeping both dot_product_attention and LogitsProcessorMixin imports. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Main branch updated inference to use right padding, so the shifting logic to convert left-padded to right-padded format is no longer needed.

- Remove _shift_sequences() function from attention.py
- Simplify dot_product_attention() to directly use cuDNN
- Remove CPU tests for shift_sequences
- Update GPU tests to only test right-padded sequences

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
skyrl-tx/ci/anyscale_gpu_ci.yaml
    ray_version: "2.51.1"
    compute_config: l4_ci
    working_dir: .
    env_vars:
I'm pretty sure this is not needed any more on modern Ray versions :)
yeah, copied from other files. removed all
In #927, we left aligned the prompts for sampling prefill in preparation for #879. Here we shift them after the prefill so they will again be right aligned for the decoding. This will allow us to use the cudnn attention #927 for both prefill and decoding. --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This PR writes the new decoded token into the kv cache in such a way that the whole sequence is left aligned. This is needed so that the CUDNN attention #879 truly works without attention mask.
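To make the mechanism concrete, a sketch of such a left-aligned cache update is below; the helper name, shapes, and the use of `dynamic_update_slice` are illustrative assumptions rather than the PR's actual code.

```python
import jax
import jax.numpy as jnp

def append_to_kv_cache(cache: jax.Array, new_kv: jax.Array, cache_lens: jax.Array):
    """Write one decode-step key/value right after each row's valid prefix.

    cache: (B, S_max, K, H), new_kv: (B, 1, K, H), cache_lens: (B,) valid entries per row.
    Writing at index cache_lens keeps every sequence left aligned (padding only on the right).
    """
    def write_row(row_cache, row_kv, length):
        return jax.lax.dynamic_update_slice(row_cache, row_kv, (length, 0, 0))

    cache = jax.vmap(write_row)(cache, new_kv, cache_lens)
    return cache, cache_lens + 1
```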
/gemini review
Code Review
This pull request introduces a significant performance improvement by adding support for cuDNN's flash attention on GPUs. The new dot_product_attention utility is well-designed, providing a clean abstraction with a fallback for other backends. The models are correctly updated to use this new utility, and the changes are accompanied by a comprehensive suite of GPU-specific tests that validate the numerical equivalence with the existing XLA implementation.
My review focuses on the new test code, where I've suggested a couple of improvements for maintainability and adherence to modern JAX practices. Overall, this is a solid contribution.
    q = jax.random.normal(jax.random.key(0), (batch, seq_len, num_heads, head_dim), dtype=dtype)
    k = jax.random.normal(jax.random.key(1), (batch, seq_len, num_kv_heads, head_dim), dtype=dtype)
    v = jax.random.normal(jax.random.key(2), (batch, seq_len, num_kv_heads, head_dim), dtype=dtype)
jax.random.key is an alias for jax.random.PRNGKey and its usage is discouraged in new code. Please use jax.random.PRNGKey for clarity and future compatibility.
Suggested change:

    - q = jax.random.normal(jax.random.key(0), (batch, seq_len, num_heads, head_dim), dtype=dtype)
    - k = jax.random.normal(jax.random.key(1), (batch, seq_len, num_kv_heads, head_dim), dtype=dtype)
    - v = jax.random.normal(jax.random.key(2), (batch, seq_len, num_kv_heads, head_dim), dtype=dtype)
    + q = jax.random.normal(jax.random.PRNGKey(0), (batch, seq_len, num_heads, head_dim), dtype=dtype)
    + k = jax.random.normal(jax.random.PRNGKey(1), (batch, seq_len, num_kv_heads, head_dim), dtype=dtype)
    + v = jax.random.normal(jax.random.PRNGKey(2), (batch, seq_len, num_kv_heads, head_dim), dtype=dtype)
    q = jax.random.normal(jax.random.key(0), (batch, 1, num_heads, head_dim), dtype=jnp.bfloat16)
    k = jax.random.normal(jax.random.key(1), (batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)
    v = jax.random.normal(jax.random.key(2), (batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)
jax.random.key is an alias for jax.random.PRNGKey and its usage is discouraged in new code. Please use jax.random.PRNGKey for clarity and future compatibility.
Additionally, this block duplicates tensor creation logic from make_qkv and test_gqa_decode. Consider refactoring make_qkv to accept separate q_len and kv_len to centralize this logic and improve maintainability.
Suggested change:

    - q = jax.random.normal(jax.random.key(0), (batch, 1, num_heads, head_dim), dtype=jnp.bfloat16)
    - k = jax.random.normal(jax.random.key(1), (batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)
    - v = jax.random.normal(jax.random.key(2), (batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)
    + q = jax.random.normal(jax.random.PRNGKey(0), (batch, 1, num_heads, head_dim), dtype=jnp.bfloat16)
    + k = jax.random.normal(jax.random.PRNGKey(1), (batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)
    + v = jax.random.normal(jax.random.PRNGKey(2), (batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)
    q = jax.random.normal(jax.random.key(0), (batch, 1, num_heads, head_dim), dtype=jnp.bfloat16)
    k = jax.random.normal(jax.random.key(1), (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)
    v = jax.random.normal(jax.random.key(2), (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)
jax.random.key is an alias for jax.random.PRNGKey and its usage is discouraged in new code. Please use jax.random.PRNGKey for clarity and future compatibility.
As mentioned in test_decode, this is another instance of duplicated tensor creation logic. Refactoring make_qkv would benefit this test as well.
Suggested change:

    - q = jax.random.normal(jax.random.key(0), (batch, 1, num_heads, head_dim), dtype=jnp.bfloat16)
    - k = jax.random.normal(jax.random.key(1), (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)
    - v = jax.random.normal(jax.random.key(2), (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)
    + q = jax.random.normal(jax.random.PRNGKey(0), (batch, 1, num_heads, head_dim), dtype=jnp.bfloat16)
    + k = jax.random.normal(jax.random.PRNGKey(1), (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)
    + v = jax.random.normal(jax.random.PRNGKey(2), (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)
## Summary

- Add `dot_product_attention()` in `tx/layers/attention.py` that uses cuDNN on GPU for both causal (prefill) and non-causal (decode) attention, with XLA fallback for CPU/TPU
- Update Llama3 and Qwen3 models to use the shared attention function
- Add GPU CI workflow via Anyscale

## Test plan

- [ ] GPU tests verify cuDNN output matches XLA mask-based attention
- [ ] Tests cover: causal with padding, no padding, mixed lengths, GQA, decode (non-causal)
- [ ] CPU CI excludes GPU tests (`--ignore=tests/gpu`)

Benchmark results #891 (comment)

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>
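As a rough illustration of the equivalence check in the test plan, a GPU test might look like the sketch below; the shapes, tolerance, and mask construction are illustrative and assume a JAX version whose `jax.nn.dot_product_attention` supports `implementation="cudnn"` with `query_seq_lengths`/`key_value_seq_lengths`.

```python
import jax
import jax.numpy as jnp
import pytest

@pytest.mark.skipif(jax.default_backend() != "gpu", reason="requires a CUDA GPU with cuDNN")
def test_cudnn_matches_xla_causal():
    B, S, N, H = 2, 64, 4, 64
    q = jax.random.normal(jax.random.key(0), (B, S, N, H), dtype=jnp.bfloat16)
    k = jax.random.normal(jax.random.key(1), (B, S, N, H), dtype=jnp.bfloat16)
    v = jax.random.normal(jax.random.key(2), (B, S, N, H), dtype=jnp.bfloat16)
    seq_lengths = jnp.array([64, 40], dtype=jnp.int32)  # right-padded valid lengths

    out_cudnn = jax.nn.dot_product_attention(
        q, k, v, is_causal=True,
        query_seq_lengths=seq_lengths, key_value_seq_lengths=seq_lengths,
        implementation="cudnn",
    )

    # Reference: XLA path with an explicit causal + padding mask built from seq_lengths.
    pos = jnp.arange(S)
    valid = pos[None, :] < seq_lengths[:, None]                       # (B, S)
    mask = valid[:, None, None, :] & (pos[:, None] >= pos[None, :])   # (B, 1, S, S)
    out_xla = jax.nn.dot_product_attention(q, k, v, mask=mask, implementation="xla")

    # Compare only the valid (non-padded) query positions.
    err = jnp.abs(out_cudnn - out_xla) * valid[:, :, None, None]
    assert float(err.max()) < 2e-2
```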