-
Notifications
You must be signed in to change notification settings - Fork 261
[tx] Add cuDNN flash attention #879
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
235511c
d7f8caf
4c48cc6
63619a6
a9fca15
5fbfefe
defed63
775b746
1174531
7f88099
ffed524
a244a09
dd63e80
c8e7095
13922a9
a242427
48d07db
e84c8c7
38276cc
2e0bb7c
ef5c22f
76dd7f2
fb067ee
28117db
5c3ef62
d468ccb
061f0ca
eae82ff
748ca39
6920b6d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| name: SkyRL-tx-GPU | ||
|
|
||
| on: | ||
| push: | ||
| branches: [ main ] | ||
| paths: | ||
| - 'skyrl-tx/**' | ||
| - '.github/workflows/gpu_skyrl_tx.yaml' | ||
| pull_request: | ||
| paths: | ||
| - 'skyrl-tx/**' | ||
| - '.github/workflows/gpu_skyrl_tx.yaml' | ||
| workflow_dispatch: | ||
|
|
||
| permissions: | ||
| checks: write | ||
| contents: read | ||
|
|
||
| concurrency: | ||
| group: skyrl-tx-gpu-${{ github.workflow }}-${{ github.ref }} | ||
| cancel-in-progress: true | ||
|
|
||
| jobs: | ||
| skyrl_tx_gpu_tests: | ||
| runs-on: ubuntu-latest | ||
| defaults: | ||
| run: | ||
| shell: bash | ||
| working-directory: ./skyrl-tx | ||
|
|
||
| steps: | ||
| - uses: actions/checkout@v4 | ||
| - name: Set up Python | ||
| uses: actions/setup-python@v5 | ||
| with: | ||
| python-version: '3.12' | ||
| cache: 'pip' | ||
| - name: Install the latest version of uv | ||
| uses: astral-sh/setup-uv@v6 | ||
| with: | ||
| activate-environment: true | ||
| - name: Install dependencies | ||
| run: uv pip install anyscale==0.24.79 typer==0.9.0 | ||
| - name: GPU tests | ||
| env: | ||
| ANYSCALE_CLI_TOKEN: ${{ secrets.ANYSCALE_CLI_TOKEN }} | ||
| ANYSCALE_HOST: https://console.anyscale.com | ||
| run: | | ||
| anyscale job submit -f ci/anyscale_gpu_ci.yaml --timeout 10000 | ||
| anyscale job wait --cloud sky-anyscale-aws-us-east-1 --name skyrl-tx-gpu-ci --timeout 10000 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| name: skyrl-tx-gpu-ci | ||
| entrypoint: bash ci/gpu_ci_run.sh | ||
| image_uri: novaskyai/skyrl-train-ray-2.51.1-py3.12-cu12.8 | ||
| cloud: sky-anyscale-aws-us-east-1 | ||
| ray_version: "2.51.1" | ||
| compute_config: l4_ci | ||
| working_dir: . | ||
| max_retries: 0 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| #!/usr/bin/env bash | ||
| set -xeuo pipefail | ||
|
|
||
| export CI=true | ||
|
|
||
| # Run GPU-specific tests | ||
| uv run --extra gpu --extra tinker --extra dev pytest tests/gpu |
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,123 @@ | ||||||||||||||
| """GPU tests for flash attention. | ||||||||||||||
|
|
||||||||||||||
| These tests require a GPU and verify that cuDNN flash attention produces | ||||||||||||||
| numerically equivalent results to the mask-based implementation. | ||||||||||||||
| """ | ||||||||||||||
|
|
||||||||||||||
| import jax | ||||||||||||||
| import jax.numpy as jnp | ||||||||||||||
| import pytest | ||||||||||||||
|
|
||||||||||||||
| from tx.layers.attention import dot_product_attention | ||||||||||||||
|
|
||||||||||||||
| # Skip all tests if not on GPU | ||||||||||||||
| pytestmark = pytest.mark.skipif(jax.default_backend() != "gpu", reason="GPU tests require CUDA") | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
def make_qkv(batch, seq_len, num_heads, head_dim, num_kv_heads=None, dtype=jnp.bfloat16, q_len=None):
    """Create random Q, K, V tensors with fixed seeds (keys 0, 1, 2).

    Args:
        batch: Batch size.
        seq_len: Key/value sequence length.
        num_heads: Number of query heads.
        head_dim: Dimension of each attention head.
        num_kv_heads: Number of key/value heads. Defaults to ``num_heads``;
            pass a smaller value for GQA.
        dtype: Element dtype of all three tensors.
        q_len: Query sequence length. Defaults to ``seq_len``; pass 1 to build
            decode-mode inputs (single query token against a full KV cache),
            avoiding hand-rolled tensor creation in decode tests.

    Returns:
        Tuple ``(q, k, v)`` with shapes ``[batch, q_len, num_heads, head_dim]``
        and ``[batch, seq_len, num_kv_heads, head_dim]`` for k and v.
    """
    if num_kv_heads is None:
        num_kv_heads = num_heads
    if q_len is None:
        q_len = seq_len
    q = jax.random.normal(jax.random.key(0), (batch, q_len, num_heads, head_dim), dtype=dtype)
    k = jax.random.normal(jax.random.key(1), (batch, seq_len, num_kv_heads, head_dim), dtype=dtype)
    v = jax.random.normal(jax.random.key(2), (batch, seq_len, num_kv_heads, head_dim), dtype=dtype)
    return q, k, v
|
|
||||||||||||||
|
|
||||||||||||||
def make_right_padded_mask(batch, seq_len, seq_lengths):
    """Build a [len(seq_lengths), seq_len] float mask: 1 = valid, 0 = padding.

    Position p of row b is valid iff p < seq_lengths[b], so each row looks
    like [1, 1, ..., 1, 0, 0] (right padding).
    """
    lengths = jnp.array(seq_lengths)
    positions = jnp.arange(seq_len)
    valid = positions[None, :] < lengths[:, None]
    return valid.astype(jnp.float32)
|
|
||||||||||||||
|
|
||||||||||||||
def assert_attention_match(q, k, v, mask, is_causal, head_dim, seq_lengths=None):
    """Run both attention implementations and assert they match.

    Args:
        seq_lengths: If provided, only compare valid positions per batch element.
            If None, compare all positions.
    """
    scale = 1.0 / head_dim**0.5
    result = dot_product_attention(q, k, v, mask, is_causal=is_causal, head_dim=head_dim)
    bool_mask = mask[:, None, None, :].astype(bool)
    expected = jax.nn.dot_product_attention(q, k, v, scale=scale, mask=bool_mask, is_causal=is_causal)

    # bfloat16 has ~7 bits of mantissa (epsilon ≈ 2^-7 = 0.0078)
    # Attention chains multiple ops, so errors compound to ~2^-6 = 0.0156
    atol = 0.02

    if seq_lengths is None:
        assert jnp.allclose(result, expected, atol=atol)
        return
    # Compare only the valid (unpadded) prefix of each batch element.
    for b, length in enumerate(seq_lengths):
        assert jnp.allclose(result[b, :length], expected[b, :length], atol=atol), f"Mismatch at batch {b}"
|
|
||||||||||||||
|
|
||||||||||||||
class TestFlashAttention:
    """Verify cuDNN flash attention matches mask-based attention.

    Each test builds inputs, a right-padded mask, and defers the actual
    comparison to ``assert_attention_match``.
    """

    @staticmethod
    def _make_decode_qkv(batch, kv_len, num_heads, num_kv_heads, head_dim):
        """Create decode-mode inputs: a single-token Q plus full-length K/V.

        Shared by the decode tests so the tensor-creation logic is not
        duplicated per test (keys 0, 1, 2 match ``make_qkv``).
        """
        q = jax.random.normal(jax.random.key(0), (batch, 1, num_heads, head_dim), dtype=jnp.bfloat16)
        k = jax.random.normal(jax.random.key(1), (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)
        v = jax.random.normal(jax.random.key(2), (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)
        return q, k, v

    @pytest.mark.parametrize("seq_len", [32, 128, 512])
    def test_padded_equivalence(self, seq_len):
        """cuDNN matches mask-based for right-padded sequences."""
        batch, num_heads, head_dim = 2, 4, 64
        q, k, v = make_qkv(batch, seq_len, num_heads, head_dim)
        seq_lengths = [seq_len - 4, seq_len - 8]
        mask = make_right_padded_mask(batch, seq_len, seq_lengths)
        assert_attention_match(q, k, v, mask, is_causal=True, head_dim=head_dim, seq_lengths=seq_lengths)

    def test_no_padding(self):
        """Full sequences (no padding) work correctly."""
        batch, seq_len, num_heads, head_dim = 2, 64, 4, 64
        q, k, v = make_qkv(batch, seq_len, num_heads, head_dim)
        mask = jnp.ones((batch, seq_len))
        assert_attention_match(q, k, v, mask, is_causal=True, head_dim=head_dim)

    @pytest.mark.parametrize(
        "seq_lengths",
        [
            [128, 96, 64, 32],  # decreasing lengths
            [32, 64, 96, 128],  # increasing lengths
            [128, 128, 128, 128],  # all full (no padding)
            [1, 1, 1, 1],  # minimal valid sequences
        ],
    )
    def test_mixed_seq_lengths(self, seq_lengths):
        """Batch with varying sequence lengths."""
        batch, seq_len, num_heads, head_dim = 4, 128, 4, 64
        q, k, v = make_qkv(batch, seq_len, num_heads, head_dim)
        mask = make_right_padded_mask(batch, seq_len, seq_lengths)
        assert_attention_match(q, k, v, mask, is_causal=True, head_dim=head_dim, seq_lengths=seq_lengths)

    def test_decode(self):
        """Decode mode (is_causal=False, single query token)."""
        batch, kv_len, num_heads, head_dim = 2, 128, 4, 64
        q, k, v = self._make_decode_qkv(batch, kv_len, num_heads, num_heads, head_dim)
        mask = make_right_padded_mask(batch, kv_len, [100, 80])
        assert_attention_match(q, k, v, mask, is_causal=False, head_dim=head_dim)

    def test_float32_fallback(self):
        """float32 (unsupported by cuDNN) uses mask-based fallback."""
        batch, seq_len, num_heads, head_dim = 2, 64, 4, 64
        q, k, v = make_qkv(batch, seq_len, num_heads, head_dim, dtype=jnp.float32)
        mask = jnp.ones((batch, seq_len))
        assert_attention_match(q, k, v, mask, is_causal=True, head_dim=head_dim)

    def test_gqa_decode(self):
        """GQA decode mode (8 Q heads, 2 KV heads, single query token)."""
        batch, kv_len, num_heads, num_kv_heads, head_dim = 2, 128, 8, 2, 64
        q, k, v = self._make_decode_qkv(batch, kv_len, num_heads, num_kv_heads, head_dim)
        mask = make_right_padded_mask(batch, kv_len, [100, 80])
        assert_attention_match(q, k, v, mask, is_causal=False, head_dim=head_dim)

    def test_gqa_prefill(self):
        """GQA prefill mode with right-padded sequences (8 Q heads, 2 KV heads)."""
        batch, seq_len, num_heads, num_kv_heads, head_dim = 2, 128, 8, 2, 64
        q, k, v = make_qkv(batch, seq_len, num_heads, head_dim, num_kv_heads)
        seq_lengths = [100, 80]
        mask = make_right_padded_mask(batch, seq_len, seq_lengths)
        assert_attention_match(q, k, v, mask, is_causal=True, head_dim=head_dim, seq_lengths=seq_lengths)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| """Shared attention utilities for transformer models.""" | ||
|
|
||
| import jax | ||
| import jax.numpy as jnp | ||
|
|
||
| # cuDNN flash attention supported dtypes | ||
| # https://github.com/jax-ml/jax/blob/8b1f782540f71fbe230a2dccd331975faafc6c83/jax/_src/cudnn/fused_attention_stablehlo.py#L290 | ||
| _CUDNN_SUPPORTED_DTYPES = (jnp.float16, jnp.bfloat16, jnp.float8_e4m3fn, jnp.float8_e5m2) | ||
|
|
||
|
|
||
| def dot_product_attention( | ||
| q: jax.Array, | ||
| k: jax.Array, | ||
| v: jax.Array, | ||
| attention_mask: jax.Array, | ||
| is_causal: bool, | ||
| head_dim: int, | ||
| ) -> jax.Array: | ||
| """Compute dot-product attention with automatic backend selection. | ||
|
|
||
| Uses cuDNN on GPU for memory-efficient attention. Falls back to XLA for CPU/TPU. | ||
|
|
||
| Args: | ||
| q: Query tensor of shape [batch, q_len, num_heads, head_dim] | ||
| k: Key tensor of shape [batch, kv_len, num_kv_heads, head_dim] | ||
| v: Value tensor of shape [batch, kv_len, num_kv_heads, head_dim] | ||
| attention_mask: Mask of shape [batch, kv_len] where 1 = valid, 0 = masked. | ||
| Sequences must be right-padded (valid tokens first, then padding). | ||
| is_causal: Whether to apply causal masking (for prefill/training) | ||
| head_dim: Dimension of each attention head (for scaling) | ||
|
|
||
| Returns: | ||
| Attention output of shape [batch, q_len, num_heads, head_dim] | ||
| """ | ||
| scale = 1.0 / head_dim**0.5 | ||
|
|
||
| if jax.default_backend() == "gpu" and q.dtype in _CUDNN_SUPPORTED_DTYPES: | ||
| kv_seq_lengths = attention_mask.sum(axis=1).astype(jnp.int32) | ||
| q_seq_lengths = jnp.minimum(kv_seq_lengths, q.shape[1]) | ||
| return jax.nn.dot_product_attention( | ||
| q, | ||
| k, | ||
| v, | ||
| scale=scale, | ||
| is_causal=is_causal, | ||
| query_seq_lengths=q_seq_lengths, | ||
| key_value_seq_lengths=kv_seq_lengths, | ||
| implementation="cudnn", | ||
| ) | ||
|
|
||
| # CPU/TPU fallback | ||
| return jax.nn.dot_product_attention( | ||
| q, k, v, scale=scale, mask=attention_mask[:, None, None, :].astype(bool), is_causal=is_causal | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
NOTE(review): `jax.random.key` is the new-style typed PRNG key constructor in JAX, while `jax.random.PRNGKey` produces legacy raw `uint32` keys and is the one discouraged for new code — so the usage of `jax.random.key` here follows current JAX recommendations and should be kept.