Merged
Commits (30)
235511c
feat: use cuDNN flash attention on GPU with CPU/TPU fallback
raulchen Jan 12, 2026
d7f8caf
fix: handle left-padded sequences in flash attention
raulchen Jan 14, 2026
4c48cc6
feat: enable cuDNN flash attention for left-padded prefill
raulchen Jan 14, 2026
63619a6
refactor: remove jax.lax.cond by using argmax for shift computation
raulchen Jan 14, 2026
a9fca15
test: add attention tests and GPU CI workflow
raulchen Jan 15, 2026
5fbfefe
lint
raulchen Jan 15, 2026
defed63
refine TestShiftSequences
raulchen Jan 15, 2026
775b746
remove abbreviations
raulchen Jan 15, 2026
1174531
refine TestDotProductAttentionCPU
raulchen Jan 16, 2026
7f88099
remove unnecessary CPU tests
raulchen Jan 16, 2026
ffed524
reduce duplication
raulchen Jan 16, 2026
a244a09
lint
raulchen Jan 16, 2026
dd63e80
address comments
raulchen Jan 21, 2026
c8e7095
lint
raulchen Jan 21, 2026
13922a9
remove assertion
raulchen Jan 21, 2026
a242427
move files
raulchen Jan 21, 2026
48d07db
add todo
raulchen Jan 21, 2026
e84c8c7
Merge main into flash-attention
raulchen Jan 23, 2026
38276cc
Remove left-padding shift logic, assume right-padded sequences
raulchen Jan 23, 2026
2e0bb7c
simplify
raulchen Jan 23, 2026
ef5c22f
fix dtype
raulchen Jan 23, 2026
76dd7f2
remove duplication
raulchen Jan 23, 2026
fb067ee
Merge branch 'flash-attention' of https://github.com/raulchen/SkyRL i…
raulchen Jan 23, 2026
28117db
fix
raulchen Jan 23, 2026
5c3ef62
test f32
raulchen Jan 23, 2026
d468ccb
remove hook
raulchen Jan 23, 2026
061f0ca
fix
raulchen Jan 23, 2026
eae82ff
gqa tests
raulchen Jan 23, 2026
748ca39
Merge branch 'main' into flash-attention
pcmoritz Jan 24, 2026
6920b6d
Merge branch 'main' into flash-attention
pcmoritz Jan 24, 2026
2 changes: 1 addition & 1 deletion .github/workflows/cpu_skyrl_tx.yaml
@@ -47,7 +47,7 @@ jobs:
      # uv run --extra tinker --extra dev ty check
      - name: Run pytest
        run: |
-          uv run --extra tinker --extra dev pytest --forked -s tests
+          uv run --extra tinker --extra dev pytest --forked -s tests --ignore=tests/gpu
      - name: Run a single training step
        run: |
          uv run tx train --model pcmoritz/qwen3-tiny-test --dataset mahiatlinux/TinyStories-GPT4-V2-50K-SUBSET --output-dir /tmp --batch-size 2 --max-steps 1 --optimizer-args '{"learning_rate": 0.002, "weight_decay": 0.1}'
50 changes: 50 additions & 0 deletions .github/workflows/gpu_skyrl_tx.yaml
@@ -0,0 +1,50 @@
name: SkyRL-tx-GPU

on:
  push:
    branches: [ main ]
    paths:
      - 'skyrl-tx/**'
      - '.github/workflows/gpu_skyrl_tx.yaml'
  pull_request:
    paths:
      - 'skyrl-tx/**'
      - '.github/workflows/gpu_skyrl_tx.yaml'
  workflow_dispatch:

permissions:
  checks: write
  contents: read

concurrency:
  group: skyrl-tx-gpu-${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

jobs:
  skyrl_tx_gpu_tests:
    runs-on: ubuntu-latest
    defaults:
      run:
        shell: bash
        working-directory: ./skyrl-tx

    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'
          cache: 'pip'
      - name: Install the latest version of uv
        uses: astral-sh/setup-uv@v6
        with:
          activate-environment: true
      - name: Install dependencies
        run: uv pip install anyscale==0.24.79 typer==0.9.0
      - name: GPU tests
        env:
          ANYSCALE_CLI_TOKEN: ${{ secrets.ANYSCALE_CLI_TOKEN }}
          ANYSCALE_HOST: https://console.anyscale.com
        run: |
          anyscale job submit -f ci/anyscale_gpu_ci.yaml --timeout 10000
          anyscale job wait --cloud sky-anyscale-aws-us-east-1 --name skyrl-tx-gpu-ci --timeout 10000
6 changes: 2 additions & 4 deletions skyrl-train/ci/anyscale_gpu_ci.yaml
@@ -3,8 +3,6 @@ entrypoint: bash ci/gpu_ci_run.sh
image_uri: novaskyai/skyrl-train-ray-2.51.1-py3.12-cu12.8 # (Optional) Exclusive with `containerfile`.
cloud: sky-anyscale-aws-us-east-1
ray_version: "2.51.1"
compute_config: l4_ci
working_dir: . # (Optional) Use current working directory "." as the working_dir. Can be any local path or remote .zip file in cloud storage.
env_vars:
  RAY_RUNTIME_ENV_HOOK: ray._private.runtime_env.uv_runtime_env_hook.hook
max_retries: 0 # (Optional) Maximum number of times the job will be retried before being marked failed. Defaults to `1`.
6 changes: 2 additions & 4 deletions skyrl-train/ci/anyscale_gpu_ci_megatron.yaml
@@ -3,8 +3,6 @@ entrypoint: bash ci/gpu_ci_run_megatron.sh
image_uri: novaskyai/skyrl-train-ray-2.51.1-py3.12-cu12.8-megatron # (Optional) Exclusive with `containerfile`.
cloud: sky-anyscale-aws-us-east-1
ray_version: "2.51.1"
compute_config: l4_ci
working_dir: . # (Optional) Use current working directory "." as the working_dir. Can be any local path or remote .zip file in cloud storage.
env_vars:
  RAY_RUNTIME_ENV_HOOK: ray._private.runtime_env.uv_runtime_env_hook.hook
max_retries: 0 # (Optional) Maximum number of times the job will be retried before being marked failed. Defaults to `1`.
5 changes: 2 additions & 3 deletions skyrl-train/ci/anyscale_gpu_e2e_test.yaml
@@ -3,10 +3,9 @@ entrypoint: bash ci/gpu_e2e_test_run.sh
image_uri: novaskyai/skyrl-train-ray-2.51.1-py3.12-cu12.8 # (Optional) Exclusive with `containerfile`.
cloud: sky-anyscale-aws-us-east-1
ray_version: "2.51.1"
compute_config: l4_ci
working_dir: . # (Optional) Use current working directory "." as the working_dir. Can be any local path or remote .zip file in cloud storage.
env_vars:
  RAY_OVERRIDE_JOB_RUNTIME_ENV: "1"
  WANDB_API_KEY: $WANDB_API_KEY
  RAY_RUNTIME_ENV_HOOK: ray._private.runtime_env.uv_runtime_env_hook.hook
max_retries: 1 # (Optional) Maximum number of times the job will be retried before being marked failed. Defaults to `1`.
5 changes: 2 additions & 3 deletions skyrl-train/ci/anyscale_gpu_e2e_test_megatron.yaml
@@ -3,10 +3,9 @@ entrypoint: bash ci/gpu_e2e_test_run_megatron.sh
image_uri: novaskyai/skyrl-train-ray-2.51.1-py3.12-cu12.8-megatron # (Optional) Exclusive with `containerfile`.
cloud: sky-anyscale-aws-us-east-1
ray_version: "2.51.1"
compute_config: l4_ci
working_dir: . # (Optional) Use current working directory "." as the working_dir. Can be any local path or remote .zip file in cloud storage.
env_vars:
  RAY_OVERRIDE_JOB_RUNTIME_ENV: "1"
  WANDB_API_KEY: $WANDB_API_KEY
  RAY_RUNTIME_ENV_HOOK: ray._private.runtime_env.uv_runtime_env_hook.hook
max_retries: 1 # (Optional) Maximum number of times the job will be retried before being marked failed. Defaults to `1`.
8 changes: 8 additions & 0 deletions skyrl-tx/ci/anyscale_gpu_ci.yaml
@@ -0,0 +1,8 @@
name: skyrl-tx-gpu-ci
entrypoint: bash ci/gpu_ci_run.sh
image_uri: novaskyai/skyrl-train-ray-2.51.1-py3.12-cu12.8
cloud: sky-anyscale-aws-us-east-1
ray_version: "2.51.1"
compute_config: l4_ci
working_dir: .
max_retries: 0
7 changes: 7 additions & 0 deletions skyrl-tx/ci/gpu_ci_run.sh
@@ -0,0 +1,7 @@
#!/usr/bin/env bash
set -xeuo pipefail

export CI=true

# Run GPU-specific tests
uv run --extra gpu --extra tinker --extra dev pytest tests/gpu
Empty file added skyrl-tx/tests/gpu/__init__.py
123 changes: 123 additions & 0 deletions skyrl-tx/tests/gpu/test_attention.py
@@ -0,0 +1,123 @@
"""GPU tests for flash attention.

These tests require a GPU and verify that cuDNN flash attention produces
numerically equivalent results to the mask-based implementation.
"""

import jax
import jax.numpy as jnp
import pytest

from tx.layers.attention import dot_product_attention

# Skip all tests if not on GPU
pytestmark = pytest.mark.skipif(jax.default_backend() != "gpu", reason="GPU tests require CUDA")


def make_qkv(batch, seq_len, num_heads, head_dim, num_kv_heads=None, dtype=jnp.bfloat16):
"""Create random Q, K, V tensors."""
if num_kv_heads is None:
num_kv_heads = num_heads
q = jax.random.normal(jax.random.key(0), (batch, seq_len, num_heads, head_dim), dtype=dtype)
k = jax.random.normal(jax.random.key(1), (batch, seq_len, num_kv_heads, head_dim), dtype=dtype)
v = jax.random.normal(jax.random.key(2), (batch, seq_len, num_kv_heads, head_dim), dtype=dtype)
Comment on lines +21 to +23 (Contributor review, severity: medium)

jax.random.key is an alias for jax.random.PRNGKey and its usage is discouraged in new code. Please use jax.random.PRNGKey for clarity and future compatibility.

Suggested change:
-    q = jax.random.normal(jax.random.key(0), (batch, seq_len, num_heads, head_dim), dtype=dtype)
-    k = jax.random.normal(jax.random.key(1), (batch, seq_len, num_kv_heads, head_dim), dtype=dtype)
-    v = jax.random.normal(jax.random.key(2), (batch, seq_len, num_kv_heads, head_dim), dtype=dtype)
+    q = jax.random.normal(jax.random.PRNGKey(0), (batch, seq_len, num_heads, head_dim), dtype=dtype)
+    k = jax.random.normal(jax.random.PRNGKey(1), (batch, seq_len, num_kv_heads, head_dim), dtype=dtype)
+    v = jax.random.normal(jax.random.PRNGKey(2), (batch, seq_len, num_kv_heads, head_dim), dtype=dtype)
    return q, k, v


def make_right_padded_mask(batch, seq_len, seq_lengths):
"""Create right-padded mask: [1,1,1,...,0,0]."""
seq_lengths = jnp.array(seq_lengths)
return (jnp.arange(seq_len)[None, :] < seq_lengths[:, None]).astype(jnp.float32)


def assert_attention_match(q, k, v, mask, is_causal, head_dim, seq_lengths=None):
"""Run both attention implementations and assert they match.

Args:
seq_lengths: If provided, only compare valid positions per batch element.
If None, compare all positions.
"""
scale = 1.0 / head_dim**0.5
result = dot_product_attention(q, k, v, mask, is_causal=is_causal, head_dim=head_dim)
expected = jax.nn.dot_product_attention(
q, k, v, scale=scale, mask=mask[:, None, None, :].astype(bool), is_causal=is_causal
)

# bfloat16 has ~7 bits of mantissa (epsilon ≈ 2^-7 = 0.0078)
# Attention chains multiple ops, so errors compound to ~2^-6 = 0.0156
atol = 0.02

if seq_lengths is None:
assert jnp.allclose(result, expected, atol=atol)
else:
for b, length in enumerate(seq_lengths):
assert jnp.allclose(result[b, :length], expected[b, :length], atol=atol), f"Mismatch at batch {b}"


class TestFlashAttention:
"""Verify cuDNN flash attention matches mask-based attention."""

@pytest.mark.parametrize("seq_len", [32, 128, 512])
def test_padded_equivalence(self, seq_len):
"""cuDNN matches mask-based for right-padded sequences."""
batch, num_heads, head_dim = 2, 4, 64
q, k, v = make_qkv(batch, seq_len, num_heads, head_dim)
seq_lengths = [seq_len - 4, seq_len - 8]
mask = make_right_padded_mask(batch, seq_len, seq_lengths)
assert_attention_match(q, k, v, mask, is_causal=True, head_dim=head_dim, seq_lengths=seq_lengths)

def test_no_padding(self):
"""Full sequences (no padding) work correctly."""
batch, seq_len, num_heads, head_dim = 2, 64, 4, 64
q, k, v = make_qkv(batch, seq_len, num_heads, head_dim)
mask = jnp.ones((batch, seq_len))
assert_attention_match(q, k, v, mask, is_causal=True, head_dim=head_dim)

@pytest.mark.parametrize(
"seq_lengths",
[
[128, 96, 64, 32], # decreasing lengths
[32, 64, 96, 128], # increasing lengths
[128, 128, 128, 128], # all full (no padding)
[1, 1, 1, 1], # minimal valid sequences
],
)
def test_mixed_seq_lengths(self, seq_lengths):
"""Batch with varying sequence lengths."""
batch, seq_len, num_heads, head_dim = 4, 128, 4, 64
q, k, v = make_qkv(batch, seq_len, num_heads, head_dim)
mask = make_right_padded_mask(batch, seq_len, seq_lengths)
assert_attention_match(q, k, v, mask, is_causal=True, head_dim=head_dim, seq_lengths=seq_lengths)

def test_decode(self):
"""Decode mode (is_causal=False, single query token)."""
batch, kv_len, num_heads, head_dim = 2, 128, 4, 64
q = jax.random.normal(jax.random.key(0), (batch, 1, num_heads, head_dim), dtype=jnp.bfloat16)
k = jax.random.normal(jax.random.key(1), (batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)
v = jax.random.normal(jax.random.key(2), (batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)
Comment on lines +95 to +97 (Contributor review, severity: medium)

jax.random.key is an alias for jax.random.PRNGKey and its usage is discouraged in new code. Please use jax.random.PRNGKey for clarity and future compatibility.

Additionally, this block duplicates tensor creation logic from make_qkv and test_gqa_decode. Consider refactoring make_qkv to accept separate q_len and kv_len to centralize this logic and improve maintainability (a sketch of this refactor follows the end of this file's diff).

Suggested change:
-        q = jax.random.normal(jax.random.key(0), (batch, 1, num_heads, head_dim), dtype=jnp.bfloat16)
-        k = jax.random.normal(jax.random.key(1), (batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)
-        v = jax.random.normal(jax.random.key(2), (batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)
+        q = jax.random.normal(jax.random.PRNGKey(0), (batch, 1, num_heads, head_dim), dtype=jnp.bfloat16)
+        k = jax.random.normal(jax.random.PRNGKey(1), (batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)
+        v = jax.random.normal(jax.random.PRNGKey(2), (batch, kv_len, num_heads, head_dim), dtype=jnp.bfloat16)

        mask = make_right_padded_mask(batch, kv_len, [100, 80])
        assert_attention_match(q, k, v, mask, is_causal=False, head_dim=head_dim)

    def test_float32_fallback(self):
        """float32 (unsupported by cuDNN) uses mask-based fallback."""
        batch, seq_len, num_heads, head_dim = 2, 64, 4, 64
        q, k, v = make_qkv(batch, seq_len, num_heads, head_dim, dtype=jnp.float32)
        mask = jnp.ones((batch, seq_len))
        assert_attention_match(q, k, v, mask, is_causal=True, head_dim=head_dim)

    def test_gqa_decode(self):
        """GQA decode mode (8 Q heads, 2 KV heads, single query token)."""
        batch, kv_len, num_heads, num_kv_heads, head_dim = 2, 128, 8, 2, 64
        q = jax.random.normal(jax.random.key(0), (batch, 1, num_heads, head_dim), dtype=jnp.bfloat16)
        k = jax.random.normal(jax.random.key(1), (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)
        v = jax.random.normal(jax.random.key(2), (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)
Comment on lines +111 to +113 (Contributor review, severity: medium)

jax.random.key is an alias for jax.random.PRNGKey and its usage is discouraged in new code. Please use jax.random.PRNGKey for clarity and future compatibility.

As mentioned in test_decode, this is another instance of duplicated tensor creation logic. Refactoring make_qkv would benefit this test as well.

Suggested change:
-        q = jax.random.normal(jax.random.key(0), (batch, 1, num_heads, head_dim), dtype=jnp.bfloat16)
-        k = jax.random.normal(jax.random.key(1), (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)
-        v = jax.random.normal(jax.random.key(2), (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)
+        q = jax.random.normal(jax.random.PRNGKey(0), (batch, 1, num_heads, head_dim), dtype=jnp.bfloat16)
+        k = jax.random.normal(jax.random.PRNGKey(1), (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)
+        v = jax.random.normal(jax.random.PRNGKey(2), (batch, kv_len, num_kv_heads, head_dim), dtype=jnp.bfloat16)

        mask = make_right_padded_mask(batch, kv_len, [100, 80])
        assert_attention_match(q, k, v, mask, is_causal=False, head_dim=head_dim)

    def test_gqa_prefill(self):
        """GQA prefill mode with right-padded sequences (8 Q heads, 2 KV heads)."""
        batch, seq_len, num_heads, num_kv_heads, head_dim = 2, 128, 8, 2, 64
        q, k, v = make_qkv(batch, seq_len, num_heads, head_dim, num_kv_heads)
        seq_lengths = [100, 80]
        mask = make_right_padded_mask(batch, seq_len, seq_lengths)
        assert_attention_match(q, k, v, mask, is_causal=True, head_dim=head_dim, seq_lengths=seq_lengths)
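To make the reviewer's make_qkv suggestion concrete, here is a minimal sketch of the helper extended with separate query and key/value lengths. The q_len/kv_len parameter names and the default behavior are assumptions for illustration, not code from this PR:

def make_qkv(batch, q_len, num_heads, head_dim, kv_len=None, num_kv_heads=None, dtype=jnp.bfloat16):
    """Create random Q, K, V tensors; kv_len defaults to q_len (self-attention)."""
    # Hypothetical q_len/kv_len split, following the review suggestion.
    if kv_len is None:
        kv_len = q_len
    if num_kv_heads is None:
        num_kv_heads = num_heads
    q = jax.random.normal(jax.random.key(0), (batch, q_len, num_heads, head_dim), dtype=dtype)
    k = jax.random.normal(jax.random.key(1), (batch, kv_len, num_kv_heads, head_dim), dtype=dtype)
    v = jax.random.normal(jax.random.key(2), (batch, kv_len, num_kv_heads, head_dim), dtype=dtype)
    return q, k, v

With this, test_decode and test_gqa_decode could build their tensors via make_qkv(batch, 1, num_heads, head_dim, kv_len=kv_len, num_kv_heads=num_kv_heads) instead of three inline jax.random.normal calls.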
54 changes: 54 additions & 0 deletions skyrl-tx/tx/layers/attention.py
@@ -0,0 +1,54 @@
"""Shared attention utilities for transformer models."""

import jax
import jax.numpy as jnp

# cuDNN flash attention supported dtypes
# https://github.com/jax-ml/jax/blob/8b1f782540f71fbe230a2dccd331975faafc6c83/jax/_src/cudnn/fused_attention_stablehlo.py#L290
_CUDNN_SUPPORTED_DTYPES = (jnp.float16, jnp.bfloat16, jnp.float8_e4m3fn, jnp.float8_e5m2)


def dot_product_attention(
    q: jax.Array,
    k: jax.Array,
    v: jax.Array,
    attention_mask: jax.Array,
    is_causal: bool,
    head_dim: int,
) -> jax.Array:
    """Compute dot-product attention with automatic backend selection.

    Uses cuDNN on GPU for memory-efficient attention. Falls back to XLA for CPU/TPU.

    Args:
        q: Query tensor of shape [batch, q_len, num_heads, head_dim]
        k: Key tensor of shape [batch, kv_len, num_kv_heads, head_dim]
        v: Value tensor of shape [batch, kv_len, num_kv_heads, head_dim]
        attention_mask: Mask of shape [batch, kv_len] where 1 = valid, 0 = masked.
            Sequences must be right-padded (valid tokens first, then padding).
        is_causal: Whether to apply causal masking (for prefill/training)
        head_dim: Dimension of each attention head (for scaling)

    Returns:
        Attention output of shape [batch, q_len, num_heads, head_dim]
    """
    scale = 1.0 / head_dim**0.5

    if jax.default_backend() == "gpu" and q.dtype in _CUDNN_SUPPORTED_DTYPES:
        kv_seq_lengths = attention_mask.sum(axis=1).astype(jnp.int32)
        q_seq_lengths = jnp.minimum(kv_seq_lengths, q.shape[1])
        return jax.nn.dot_product_attention(
            q,
            k,
            v,
            scale=scale,
            is_causal=is_causal,
            query_seq_lengths=q_seq_lengths,
            key_value_seq_lengths=kv_seq_lengths,
            implementation="cudnn",
        )

    # CPU/TPU fallback
    return jax.nn.dot_product_attention(
        q, k, v, scale=scale, mask=attention_mask[:, None, None, :].astype(bool), is_causal=is_causal
    )
12 changes: 3 additions & 9 deletions skyrl-tx/tx/models/llama3.py
@@ -7,6 +7,7 @@
from tx.layers.lora import LoRAEmbed, LoRALinear
from tx.layers.rotary_embedding import apply_rope
from tx.layers.layernorm import RMSNorm
+from tx.layers.attention import dot_product_attention
from tx.utils.logits_processor import LogitsProcessorMixin, LMHead
from tx.models.types import CausalLMOutput, ModelOutput
from tx.utils.generator import GeneratorMixin, KVCache
@@ -101,15 +102,8 @@ def __call__(

        updated_cache = (k, v)

-        # Attention (causal only during prefill, GQA handled natively by dot_product_attention)
-        attn_output = jax.nn.dot_product_attention(
-            q,
-            k,
-            v,
-            scale=1.0 / self.head_dim**0.5,
-            mask=attention_mask[:, None, None, :].astype(bool),
-            is_causal=kv_cache is None,
-        )
+        is_causal = kv_cache is None
+        attn_output = dot_product_attention(q, k, v, attention_mask, is_causal, self.head_dim)

        output = attn_output.reshape(B, T, self.num_heads * self.head_dim)
        return self.o_proj(output, adapter_indices=adapter_indices), updated_cache
14 changes: 4 additions & 10 deletions skyrl-tx/tx/models/qwen3.py
@@ -7,8 +7,9 @@
from tx.layers.util import prepare_routing, shard_map_ep
from tx.layers.rotary_embedding import apply_rope
from tx.utils.logits_processor import LogitsProcessorMixin, LMHead
-from tx.models.configs import Qwen3Config
from tx.layers.layernorm import RMSNorm
+from tx.layers.attention import dot_product_attention
+from tx.models.configs import Qwen3Config
from tx.models.types import CausalLMOutput, ModelOutput
from tx.utils.generator import GeneratorMixin, KVCache

@@ -102,15 +103,8 @@ def __call__(

        updated_cache = (k, v)

-        # Attention (causal only during prefill, GQA handled natively by dot_product_attention)
-        attn_output = jax.nn.dot_product_attention(
-            q,
-            k,
-            v,
-            scale=1.0 / self.head_dim**0.5,
-            mask=attention_mask[:, None, None, :].astype(bool),
-            is_causal=kv_cache is None,
-        )
+        is_causal = kv_cache is None
+        attn_output = dot_product_attention(q, k, v, attention_mask, is_causal, self.head_dim)

        output = attn_output.reshape(B, T, self.num_heads * self.head_dim)
        return self.o_proj(output, adapter_indices=adapter_indices), updated_cache