Commit d06a9c9

Fix ragged paged attention v2 to resolve the recompilation issue (#8797)
1 parent 9b61c1a commit d06a9c9
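
In short: `num_seqs` was previously passed around as a Python int, so its value got baked into the lowered graph (and into the Pallas grid via `cu_q_lens[num_seqs]`), and every new sequence count forced a recompilation. It now travels as a one-element `i32` tensor that the kernel reads at run time, and the grid is sized from the static padded token count instead. A minimal calling sketch follows; the tensor names and block sizes are illustrative only, not taken from this commit:

import torch
import torch_xla  # registers the "xla" device and torch.ops.xla.*

num_seqs = 3  # Python int coming from the batching logic
# After this commit, wrap it as a device tensor instead of passing the int:
num_seqs_xla = torch.tensor([num_seqs], dtype=torch.int32).to("xla")

# q_xla, k_pages_xla, v_pages_xla, kv_lens_xla, page_indices_xla and
# cu_q_lens_xla are assumed to already exist with the shapes documented in
# custom_kernel.py; the call itself is shown commented out.
# output = torch.ops.xla.ragged_paged_attention(
#     q_xla, k_pages_xla, v_pages_xla,
#     kv_lens_xla, page_indices_xla, cu_q_lens_xla,
#     num_seqs=num_seqs_xla,            # i32[1] tensor, no longer an int
#     num_kv_pages_per_block=128,
#     num_queries_per_block=16,
#     use_kernel=True)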

File tree

3 files changed (+46, -51 lines)


test/test_pallas.py

Lines changed: 11 additions & 7 deletions
@@ -676,6 +676,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
     kv_lens_xla = kv_lens.to("xla")
     page_indices_xla = page_indices.to("xla")
     cu_q_lens_xla = cu_q_lens.to("xla")
+    num_seqs_xla = torch.tensor([num_seqs], dtype=torch.int32).to("xla")
 
     output = ragged_paged_attention(
         q_xla,
@@ -684,7 +685,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
         kv_lens_xla,
         page_indices_xla,
         cu_q_lens_xla,
-        num_seqs=num_seqs,
+        num_seqs=num_seqs_xla,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
         use_kernel=True)[:cu_q_lens[num_seqs]]
@@ -696,7 +697,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
         kv_lens_xla,
         page_indices_xla,
         cu_q_lens_xla,
-        num_seqs=num_seqs,
+        num_seqs=num_seqs_xla,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
         use_kernel=False)
@@ -707,6 +708,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
     kv_lens_jax = jnp.array(kv_lens.numpy(), dtype=jnp.int32)
     page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
     cu_q_lens_jax = jnp.array(cu_q_lens.numpy(), dtype=jnp.int32)
+    num_seqs_jax = jnp.array([num_seqs], dtype=jnp.int32)
 
     expected_output = torch.from_numpy(
         np.array(
@@ -717,7 +719,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
             kv_lens_jax,
             page_indices_jax,
             cu_q_lens_jax,
-            num_seqs=num_seqs,
+            num_seqs=num_seqs_jax,
             num_kv_pages_per_block=num_kv_pages_per_block,
             num_queries_per_block=num_queries_per_block,
         )[:cu_q_lens[num_seqs]]))
@@ -765,6 +767,7 @@ def _verify_ragged_paged_attention_with_dynamo(
     kv_lens_xla = kv_lens.to("xla")
     page_indices_xla = page_indices.to("xla")
     cu_q_lens_xla = cu_q_lens.to("xla")
+    num_seqs_xla = torch.tensor([num_seqs], dtype=torch.int32).to("xla")
 
     kernel_output = torch.ops.xla.ragged_paged_attention(
         q_xla,
@@ -773,7 +776,7 @@ def _verify_ragged_paged_attention_with_dynamo(
         kv_lens_xla,
         page_indices_xla,
         cu_q_lens_xla,
-        num_seqs=num_seqs,
+        num_seqs=num_seqs_xla,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
         use_kernel=True,
@@ -787,7 +790,7 @@ def _verify_ragged_paged_attention_with_dynamo(
         kv_lens_xla,
         page_indices_xla,
         cu_q_lens_xla,
-        num_seqs=num_seqs,
+        num_seqs=num_seqs_xla,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
         use_kernel=False,
@@ -805,6 +808,7 @@ def _verify_ragged_paged_attention_with_dynamo(
     kv_lens_jax = jnp.array(kv_lens.numpy(), dtype=jnp.int32)
     page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
     cu_q_lens_jax = jnp.array(cu_q_lens.numpy(), dtype=jnp.int32)
+    num_seqs_jax = jnp.array([num_seqs], dtype=jnp.int32)
 
     from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention
     jax_kernel_output = torch.from_numpy(
@@ -816,7 +820,7 @@ def _verify_ragged_paged_attention_with_dynamo(
             kv_lens_jax,
             page_indices_jax,
             cu_q_lens_jax,
-            num_seqs=num_seqs,
+            num_seqs=num_seqs_jax,
             num_kv_pages_per_block=num_kv_pages_per_block,
             num_queries_per_block=num_queries_per_block,
             sm_scale=sm_scale,
@@ -867,7 +871,7 @@ def test_ragged_paged_attention_wrapper_no_padding_with_dynamo(self):
 
   @parameterized.product(
       seq_lens=[[(1, 1328), (5, 18), (500, 563)]],
-      num_queries_per_block=[16, 64, 128],
+      num_queries_per_block=[16, 32],
   )
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")

torch_xla/experimental/custom_kernel.py

Lines changed: 28 additions & 34 deletions
@@ -900,11 +900,11 @@ def validate_ragged_paged_attention_inputs(
     kv_lens,  # i32[max_num_seqs]
     page_indices,  # i32[max_num_seqs, pages_per_seq]
     cu_q_lens,  # i32[max_num_seqs + 1]
-    num_seqs,  # i32
+    num_seqs,  # i32[1]
 ):
-  max_num_batched_tokens, num_q_heads, head_dim = q.shape
-  _, page_size, num_kv_heads, head_dim_k = k_pages.shape
-  max_num_seqs, pages_per_seq = page_indices.shape
+  _, num_q_heads, head_dim = q.shape
+  _, _, num_kv_heads, head_dim_k = k_pages.shape
+  max_num_seqs, _ = page_indices.shape
   if k_pages.shape != v_pages.shape:
     raise ValueError(
         f"{k_pages.shape=} and {v_pages.shape=} must have the same shape.")
@@ -918,9 +918,6 @@ def validate_ragged_paged_attention_inputs(
     raise ValueError(
         f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where"
         " `max_num_seqs` is `page_indices.shape[0]`.")
-  if max_num_seqs > max_num_batched_tokens:
-    raise ValueError(
-        f"{max_num_seqs=} must be less or equal to {max_num_batched_tokens=}.")
   if (kv_lens.dtype != torch.int32 or page_indices.dtype != torch.int32 or
       cu_q_lens.dtype != torch.int32):
     raise ValueError(
@@ -931,24 +928,24 @@ def validate_ragged_paged_attention_inputs(
     raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}")
 
   # Must check below on runtime!
-  if num_seqs > max_num_seqs:
-    raise ValueError(f"{num_seqs=} must be less or equal to {max_num_seqs=}")
-  max_kv_len = torch.max(kv_lens)
-  min_pages_per_seq = ceil_div(max_kv_len, page_size)
-  if pages_per_seq < min_pages_per_seq:
-    raise ValueError(
-        f"{pages_per_seq=} must be greater or equal to"
-        f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}.")
-  if cu_q_lens[num_seqs] > max_num_batched_tokens:
-    raise ValueError(
-        f"Total q tokens {cu_q_lens[num_seqs]} must be less or equal to"
-        f" {max_num_batched_tokens=}.")
-  for i in range(num_seqs):
-    q_len = cu_q_lens[i + 1] - cu_q_lens[i]
-    kv_len = kv_lens[i]
-    if q_len > kv_len:
-      raise ValueError(
-          f"{q_len=} must be less or equal to {kv_len=} at sequence {i}.")
+  # if num_seqs > max_num_seqs:
+  #   raise ValueError(f"{num_seqs=} must be less or equal to {max_num_seqs=}")
+  # max_kv_len = torch.max(kv_lens)
+  # min_pages_per_seq = ceil_div(max_kv_len, page_size)
+  # if pages_per_seq < min_pages_per_seq:
+  #   raise ValueError(
+  #       f"{pages_per_seq=} must be greater or equal to"
+  #       f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}.")
+  # if cu_q_lens[num_seqs] > max_num_batched_tokens:
+  #   raise ValueError(
+  #       f"Total q tokens {cu_q_lens[num_seqs]} must be less or equal to"
+  #       f" {max_num_batched_tokens=}.")
+  # for i in range(num_seqs):
+  #   q_len = cu_q_lens[i + 1] - cu_q_lens[i]
+  #   kv_len = kv_lens[i]
+  #   if q_len > kv_len:
+  #     raise ValueError(
+  #         f"{q_len=} must be less or equal to {kv_len=} at sequence {i}.")
 
 
 def _ragged_paged_attention_nonkernel(
@@ -1001,7 +998,7 @@ def ragged_paged_attention(
     kv_lens,  # i32[max_num_seqs]
     page_indices,  # i32[max_num_seqs, pages_per_seq]
     cu_q_lens,  # i32[max_num_seqs + 1]
-    num_seqs,  # i32
+    num_seqs,  # i32[1]
     *,
     sm_scale=1.0,
     mask_value=None,
@@ -1022,7 +1019,7 @@ def ragged_paged_attention(
         kv_lens,
         page_indices,
         cu_q_lens,
-        num_seqs,
+        num_seqs.item(),
         sm_scale=sm_scale,
         mask_value=mask_value,
     )
@@ -1054,17 +1051,14 @@ def ragged_paged_attention(
       ],
   )
 
-  num_q_blks = ceil_div(cu_q_lens[num_seqs], num_queries_per_block)
   seq_buf_idx = torch.tensor([0, 0], dtype=torch.int32).to("xla")
-  num_seqs_ref = torch.tensor([num_seqs], dtype=torch.int32).to("xla")
   output = torch_xla._XLAC._xla_tpu_custom_call(
       [
-          num_q_blks,
           kv_lens,
           page_indices,
           cu_q_lens,
           seq_buf_idx,
-          num_seqs_ref,
+          num_seqs,
           q,
           k_pages,
           v_pages,
@@ -1733,7 +1727,7 @@ def multi_queries_paged_attention_non_xla(q: torch.Tensor,
 
 XLA_LIB.define(
     "ragged_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor kv_lens, Tensor page_indices, "
-    "Tensor cu_q_lens, int num_seqs, int num_kv_pages_per_block, int num_queries_per_block, bool use_kernel, "
+    "Tensor cu_q_lens, Tensor num_seqs, int num_kv_pages_per_block, int num_queries_per_block, bool use_kernel, "
     "float sm_scale=1.0, float? mask_value=None, int? vmem_limit_bytes=None) -> Tensor",
 )
 
@@ -1746,7 +1740,7 @@ def ragged_paged_attention_xla(
     kv_lens: torch.Tensor,
     page_indices: torch.Tensor,
     cu_q_lens: torch.Tensor,
-    num_seqs: int,
+    num_seqs: torch.Tensor,
     num_kv_pages_per_block: int,
     num_queries_per_block: int,
     use_kernel: bool,
@@ -1777,7 +1771,7 @@ def ragged_paged_attention_non_xla(q: torch.Tensor,
                                    kv_lens: torch.Tensor,
                                    page_indices: torch.Tensor,
                                    cu_q_lens: torch.Tensor,
-                                   num_seqs: int,
+                                   num_seqs: torch.Tensor,
                                    num_kv_pages_per_block: int,
                                    num_queries_per_block: int,
                                    use_kernel: bool,
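
For reference, a rough, runnable sketch of how the wrapper now consumes the tensor on each path; the helper names below are hypothetical stand-ins, not functions from this file. The eager reference path converts the tensor back to a Python int with `.item()`, while the kernel path forwards it straight into the TPU custom call.

import torch

def run_reference(num_seqs: int):
  # hypothetical stand-in for _ragged_paged_attention_nonkernel
  return f"reference path over {num_seqs} sequences"

def run_pallas_kernel(num_seqs: torch.Tensor):
  # hypothetical stand-in for the _xla_tpu_custom_call path
  return f"kernel path keeps num_seqs as a {num_seqs.dtype} tensor"

def dispatch(num_seqs: torch.Tensor, use_kernel: bool):
  if not use_kernel:
    # The pure-PyTorch reference loops over sequences in Python, so it needs a
    # concrete value; .item() reads the scalar out of the i32[1] tensor.
    return run_reference(num_seqs.item())
  # The kernel path leaves num_seqs on-device and passes it alongside
  # kv_lens/page_indices/cu_q_lens/seq_buf_idx, so the compiled program no
  # longer embeds its value as a constant.
  return run_pallas_kernel(num_seqs)

print(dispatch(torch.tensor([3], dtype=torch.int32), use_kernel=False))
print(dispatch(torch.tensor([3], dtype=torch.int32), use_kernel=True))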

torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py

Lines changed: 7 additions & 10 deletions
@@ -151,9 +151,6 @@ def check_inputs_shapes(
    raise ValueError(
        f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where"
        " `max_num_seqs` is `page_indices.shape[0]`.")
-  if max_num_seqs > max_num_batched_tokens:
-    raise ValueError(
-        f"{max_num_seqs=} must be less or equal to {max_num_batched_tokens=}.")
   if (kv_lens.dtype != jnp.int32 or page_indices.dtype != jnp.int32 or
       cu_q_lens.dtype != jnp.int32):
     raise ValueError(
@@ -253,7 +250,9 @@ def prefetch_first_kv_blk():
 
   def is_cur_q_blk_needed(q_states):
     done, cur_seq_idx, _ = q_states
-    return jnp.logical_and(done == 0, cur_seq_idx < num_seqs)
+    should_run = jnp.logical_and(q_len_start < cu_q_lens_ref[num_seqs],
+                                 cur_seq_idx < num_seqs)
+    return jnp.logical_and(done == 0, should_run)
 
   def compute_with_cur_q_blk(q_states):
     done, cur_seq_idx, cur_buf_idx = q_states
@@ -551,7 +550,7 @@ def ragged_paged_attention(
     kv_lens: jax.Array,  # i32[max_num_seqs]
     page_indices: jax.Array,  # i32[max_num_seqs, pages_per_seq]
     cu_q_lens: jax.Array,  # i32[max_num_seqs + 1]
-    num_seqs,  # i32
+    num_seqs,  # i32[1]
     *,
     sm_scale: float = 1.0,
     mask_value: float = DEFAULT_MASK_VALUE,
@@ -583,12 +582,12 @@ def ragged_paged_attention(
     The output of the attention.
   """
   check_inputs_shapes(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens)
-  _, num_q_heads, head_dim = q.shape
+  num_q, num_q_heads, head_dim = q.shape
   _, page_size, num_kv_heads, _ = k_pages.shape
   num_q_per_blk = num_queries_per_block
   num_kv_pages_per_blk = num_kv_pages_per_block
   num_q_heads_per_kv_head = num_q_heads // num_kv_heads
-  num_q_blks = ceil_div(cu_q_lens[num_seqs], num_q_per_blk)
+  num_q_blks = ceil_div(num_q, num_q_per_blk)
   num_q_heads_per_blk, num_kv_heads_per_blk = get_min_heads_per_blk(
       num_q_heads, num_kv_heads, q.dtype, k_pages.dtype)
   assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0
@@ -636,9 +635,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_):
       page_indices,
       cu_q_lens,
       jnp.array((0, 0), jnp.int32),  # seq_idx, buf_idx
-      # Mosaic only takes dynamic scalar as ref, so we wrap it.
-      jnp.array([num_seqs], jnp.int32),  # num_seqs
-  )
+      num_seqs)
   kernel = pl.pallas_call(
       functools.partial(
           ragged_paged_attention_kernel,
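
The Pallas-side change removes the last data-dependent quantity from the launch: `num_q_blks` used to be `ceil_div(cu_q_lens[num_seqs], num_queries_per_block)`, a value only known at run time, and is now `ceil_div(q.shape[0], num_queries_per_block)`, a static shape. Blocks beyond the live tokens are skipped inside the kernel by the new `q_len_start < cu_q_lens_ref[num_seqs]` check in `is_cur_q_blk_needed`. A small numeric sketch with made-up sizes:

def ceil_div(a, b):
  # local copy for the sketch; the kernel file defines its own ceil_div
  return (a + b - 1) // b

max_num_batched_tokens = 512  # q.shape[0], padded and static
num_queries_per_block = 16
live_tokens = 210             # cu_q_lens[num_seqs], varies per step

# After the fix: always 32 blocks, independent of live_tokens.
print(ceil_div(max_num_batched_tokens, num_queries_per_block))  # 32
# Before the fix the grid followed live_tokens (14 blocks here), so each new
# total token count meant a different grid and a fresh compilation.
print(ceil_div(live_tokens, num_queries_per_block))             # 14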
