From 7dcc878847b88241e38b6ec5902d38c64ed66d53 Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Wed, 5 Mar 2025 19:29:11 +0000 Subject: [PATCH 1/3] Fix ragged paged attention v2 to resolve the recompilation issue --- test/test_pallas.py | 18 +++++++++++------- torch_xla/experimental/custom_kernel.py | 15 ++++++--------- .../ragged_paged_attention_v2.py | 15 ++++++++------- 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/test/test_pallas.py b/test/test_pallas.py index 8173e3b2713c..d8d4eaf5b370 100644 --- a/test/test_pallas.py +++ b/test/test_pallas.py @@ -676,6 +676,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self): kv_lens_xla = kv_lens.to("xla") page_indices_xla = page_indices.to("xla") cu_q_lens_xla = cu_q_lens.to("xla") + num_seqs_xla = torch.tensor([num_seqs], dtype=torch.int32).to("xla") output = ragged_paged_attention( q_xla, @@ -684,7 +685,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self): kv_lens_xla, page_indices_xla, cu_q_lens_xla, - num_seqs=num_seqs, + num_seqs=num_seqs_xla, num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, use_kernel=True)[:cu_q_lens[num_seqs]] @@ -696,7 +697,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self): kv_lens_xla, page_indices_xla, cu_q_lens_xla, - num_seqs=num_seqs, + num_seqs=num_seqs_xla, num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, use_kernel=False) @@ -707,6 +708,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self): kv_lens_jax = jnp.array(kv_lens.numpy(), dtype=jnp.int32) page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32) cu_q_lens_jax = jnp.array(cu_q_lens.numpy(), dtype=jnp.int32) + num_seqs_jax = jnp.array([num_seqs], dtype=jnp.int32) expected_output = torch.from_numpy( np.array( @@ -717,7 +719,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self): kv_lens_jax, page_indices_jax, cu_q_lens_jax, - num_seqs=num_seqs, + num_seqs=num_seqs_jax, num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, )[:cu_q_lens[num_seqs]])) @@ -765,6 +767,7 @@ def _verify_ragged_paged_attention_with_dynamo( kv_lens_xla = kv_lens.to("xla") page_indices_xla = page_indices.to("xla") cu_q_lens_xla = cu_q_lens.to("xla") + num_seqs_xla = torch.tensor([num_seqs], dtype=torch.int32).to("xla") kernel_output = torch.ops.xla.ragged_paged_attention( q_xla, @@ -773,7 +776,7 @@ def _verify_ragged_paged_attention_with_dynamo( kv_lens_xla, page_indices_xla, cu_q_lens_xla, - num_seqs=num_seqs, + num_seqs=num_seqs_xla, num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, use_kernel=True, @@ -787,7 +790,7 @@ def _verify_ragged_paged_attention_with_dynamo( kv_lens_xla, page_indices_xla, cu_q_lens_xla, - num_seqs=num_seqs, + num_seqs=num_seqs_xla, num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, use_kernel=False, @@ -805,6 +808,7 @@ def _verify_ragged_paged_attention_with_dynamo( kv_lens_jax = jnp.array(kv_lens.numpy(), dtype=jnp.int32) page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32) cu_q_lens_jax = jnp.array(cu_q_lens.numpy(), dtype=jnp.int32) + num_seqs_jax = jnp.array([num_seqs], dtype=jnp.int32) from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention jax_kernel_output = torch.from_numpy( @@ -816,7 +820,7 @@ def _verify_ragged_paged_attention_with_dynamo( kv_lens_jax, 
page_indices_jax, cu_q_lens_jax, - num_seqs=num_seqs, + num_seqs=num_seqs_jax, num_kv_pages_per_block=num_kv_pages_per_block, num_queries_per_block=num_queries_per_block, sm_scale=sm_scale, @@ -867,7 +871,7 @@ def test_ragged_paged_attention_wrapper_no_padding_with_dynamo(self): @parameterized.product( seq_lens=[[(1, 1328), (5, 18), (500, 563)]], - num_queries_per_block=[16, 64, 128], + num_queries_per_block=[16, 32], ) @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4, "This test only works on TPUv4+.") diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 04e1db666e8f..6d520f939c1b 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -1001,7 +1001,7 @@ def ragged_paged_attention( kv_lens, # i32[max_num_seqs] page_indices, # i32[max_num_seqs, pages_per_seq] cu_q_lens, # i32[max_num_seqs + 1] - num_seqs, # i32 + num_seqs, # i32[1] *, sm_scale=1.0, mask_value=None, @@ -1022,7 +1022,7 @@ def ragged_paged_attention( kv_lens, page_indices, cu_q_lens, - num_seqs, + num_seqs.item(), sm_scale=sm_scale, mask_value=mask_value, ) @@ -1054,17 +1054,14 @@ def ragged_paged_attention( ], ) - num_q_blks = ceil_div(cu_q_lens[num_seqs], num_queries_per_block) seq_buf_idx = torch.tensor([0, 0], dtype=torch.int32).to("xla") - num_seqs_ref = torch.tensor([num_seqs], dtype=torch.int32).to("xla") output = torch_xla._XLAC._xla_tpu_custom_call( [ - num_q_blks, kv_lens, page_indices, cu_q_lens, seq_buf_idx, - num_seqs_ref, + num_seqs, q, k_pages, v_pages, @@ -1733,7 +1730,7 @@ def multi_queries_paged_attention_non_xla(q: torch.Tensor, XLA_LIB.define( "ragged_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor kv_lens, Tensor page_indices, " - "Tensor cu_q_lens, int num_seqs, int num_kv_pages_per_block, int num_queries_per_block, bool use_kernel, " + "Tensor cu_q_lens, Tensor num_seqs, int num_kv_pages_per_block, int num_queries_per_block, bool use_kernel, " "float sm_scale=1.0, float? mask_value=None, int? 
vmem_limit_bytes=None) -> Tensor", ) @@ -1746,7 +1743,7 @@ def ragged_paged_attention_xla( kv_lens: torch.Tensor, page_indices: torch.Tensor, cu_q_lens: torch.Tensor, - num_seqs: int, + num_seqs: torch.Tensor, num_kv_pages_per_block: int, num_queries_per_block: int, use_kernel: bool, @@ -1777,7 +1774,7 @@ def ragged_paged_attention_non_xla(q: torch.Tensor, kv_lens: torch.Tensor, page_indices: torch.Tensor, cu_q_lens: torch.Tensor, - num_seqs: int, + num_seqs: torch.Tensor, num_kv_pages_per_block: int, num_queries_per_block: int, use_kernel: bool, diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py index d1db0a79430a..07eda897b880 100644 --- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py @@ -253,7 +253,9 @@ def prefetch_first_kv_blk(): def is_cur_q_blk_needed(q_states): done, cur_seq_idx, _ = q_states - return jnp.logical_and(done == 0, cur_seq_idx < num_seqs) + should_run = jnp.logical_and(q_len_start < cu_q_lens_ref[num_seqs], + cur_seq_idx < num_seqs) + return jnp.logical_and(done == 0, should_run) def compute_with_cur_q_blk(q_states): done, cur_seq_idx, cur_buf_idx = q_states @@ -551,7 +553,7 @@ def ragged_paged_attention( kv_lens: jax.Array, # i32[max_num_seqs] page_indices: jax.Array, # i32[max_num_seqs, pages_per_seq] cu_q_lens: jax.Array, # i32[max_num_seqs + 1] - num_seqs, # i32 + num_seqs, # i32[1] *, sm_scale: float = 1.0, mask_value: float = DEFAULT_MASK_VALUE, @@ -582,13 +584,13 @@ def ragged_paged_attention( Returns: The output of the attention. """ - check_inputs_shapes(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens) - _, num_q_heads, head_dim = q.shape + # check_inputs_shapes(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens) + num_q, num_q_heads, head_dim = q.shape _, page_size, num_kv_heads, _ = k_pages.shape num_q_per_blk = num_queries_per_block num_kv_pages_per_blk = num_kv_pages_per_block num_q_heads_per_kv_head = num_q_heads // num_kv_heads - num_q_blks = ceil_div(cu_q_lens[num_seqs], num_q_per_blk) + num_q_blks = ceil_div(num_q, num_q_per_blk) num_q_heads_per_blk, num_kv_heads_per_blk = get_min_heads_per_blk( num_q_heads, num_kv_heads, q.dtype, k_pages.dtype) assert num_q_heads_per_blk % num_q_heads_per_kv_head == 0 @@ -636,8 +638,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): page_indices, cu_q_lens, jnp.array((0, 0), jnp.int32), # seq_idx, buf_idx - # Mosaic only takes dynamic scalar as ref, so we wrap it. 
- jnp.array([num_seqs], jnp.int32), # num_seqs + num_seqs ) kernel = pl.pallas_call( functools.partial( From 67068c40f7d58b4b73e7437ca8956b74504ee73b Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Wed, 5 Mar 2025 19:41:01 +0000 Subject: [PATCH 2/3] fix lint --- .../experimental/pallas_kernels/ragged_paged_attention_v2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py index 07eda897b880..89968560eefa 100644 --- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py @@ -638,8 +638,7 @@ def q_index_map(heads_blk_idx, q_blk_idx, *_): page_indices, cu_q_lens, jnp.array((0, 0), jnp.int32), # seq_idx, buf_idx - num_seqs - ) + num_seqs) kernel = pl.pallas_call( functools.partial( ragged_paged_attention_kernel, From e99636c37c93abaaf416fc3460302fd99cbf7956 Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Wed, 5 Mar 2025 21:35:15 +0000 Subject: [PATCH 3/3] fix complication and runtime check --- torch_xla/experimental/custom_kernel.py | 47 +++++++++---------- .../ragged_paged_attention_v2.py | 5 +- 2 files changed, 23 insertions(+), 29 deletions(-) diff --git a/torch_xla/experimental/custom_kernel.py b/torch_xla/experimental/custom_kernel.py index 6d520f939c1b..45cd5525dc39 100644 --- a/torch_xla/experimental/custom_kernel.py +++ b/torch_xla/experimental/custom_kernel.py @@ -900,11 +900,11 @@ def validate_ragged_paged_attention_inputs( kv_lens, # i32[max_num_seqs] page_indices, # i32[max_num_seqs, pages_per_seq] cu_q_lens, # i32[max_num_seqs + 1] - num_seqs, # i32 + num_seqs, # i32[1] ): - max_num_batched_tokens, num_q_heads, head_dim = q.shape - _, page_size, num_kv_heads, head_dim_k = k_pages.shape - max_num_seqs, pages_per_seq = page_indices.shape + _, num_q_heads, head_dim = q.shape + _, _, num_kv_heads, head_dim_k = k_pages.shape + max_num_seqs, _ = page_indices.shape if k_pages.shape != v_pages.shape: raise ValueError( f"{k_pages.shape=} and {v_pages.shape=} must have the same shape.") @@ -918,9 +918,6 @@ def validate_ragged_paged_attention_inputs( raise ValueError( f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where" " `max_num_seqs` is `page_indices.shape[0]`.") - if max_num_seqs > max_num_batched_tokens: - raise ValueError( - f"{max_num_seqs=} must be less or equal to {max_num_batched_tokens=}.") if (kv_lens.dtype != torch.int32 or page_indices.dtype != torch.int32 or cu_q_lens.dtype != torch.int32): raise ValueError( @@ -931,24 +928,24 @@ def validate_ragged_paged_attention_inputs( raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}") # Must check below on runtime! 
- if num_seqs > max_num_seqs: - raise ValueError(f"{num_seqs=} must be less or equal to {max_num_seqs=}") - max_kv_len = torch.max(kv_lens) - min_pages_per_seq = ceil_div(max_kv_len, page_size) - if pages_per_seq < min_pages_per_seq: - raise ValueError( - f"{pages_per_seq=} must be greater or equal to" - f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}.") - if cu_q_lens[num_seqs] > max_num_batched_tokens: - raise ValueError( - f"Total q tokens {cu_q_lens[num_seqs]} must be less or equal to" - f" {max_num_batched_tokens=}.") - for i in range(num_seqs): - q_len = cu_q_lens[i + 1] - cu_q_lens[i] - kv_len = kv_lens[i] - if q_len > kv_len: - raise ValueError( - f"{q_len=} must be less or equal to {kv_len=} at sequence {i}.") + # if num_seqs > max_num_seqs: + # raise ValueError(f"{num_seqs=} must be less or equal to {max_num_seqs=}") + # max_kv_len = torch.max(kv_lens) + # min_pages_per_seq = ceil_div(max_kv_len, page_size) + # if pages_per_seq < min_pages_per_seq: + # raise ValueError( + # f"{pages_per_seq=} must be greater or equal to" + # f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}.") + # if cu_q_lens[num_seqs] > max_num_batched_tokens: + # raise ValueError( + # f"Total q tokens {cu_q_lens[num_seqs]} must be less or equal to" + # f" {max_num_batched_tokens=}.") + # for i in range(num_seqs): + # q_len = cu_q_lens[i + 1] - cu_q_lens[i] + # kv_len = kv_lens[i] + # if q_len > kv_len: + # raise ValueError( + # f"{q_len=} must be less or equal to {kv_len=} at sequence {i}.") def _ragged_paged_attention_nonkernel( diff --git a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py index 89968560eefa..ef002d5d3003 100644 --- a/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py +++ b/torch_xla/experimental/pallas_kernels/ragged_paged_attention_v2.py @@ -151,9 +151,6 @@ def check_inputs_shapes( raise ValueError( f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where" " `max_num_seqs` is `page_indices.shape[0]`.") - if max_num_seqs > max_num_batched_tokens: - raise ValueError( - f"{max_num_seqs=} must be less or equal to {max_num_batched_tokens=}.") if (kv_lens.dtype != jnp.int32 or page_indices.dtype != jnp.int32 or cu_q_lens.dtype != jnp.int32): raise ValueError( @@ -584,7 +581,7 @@ def ragged_paged_attention( Returns: The output of the attention. """ - # check_inputs_shapes(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens) + check_inputs_shapes(q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens) num_q, num_q_heads, head_dim = q.shape _, page_size, num_kv_heads, _ = k_pages.shape num_q_per_blk = num_queries_per_block
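
Reviewer note (not part of the patch): a minimal usage sketch of the updated wrapper. The shapes and block sizes below are illustrative assumptions, not values taken from the patch; the point being shown is that `num_seqs` is now passed as an i32[1] XLA tensor (as in the updated tests), so changing the number of active sequences between steps no longer bakes a new Python constant into the traced graph and triggers recompilation.

    # Sketch only: illustrative shapes, assumes torch_xla with a TPU device available.
    import torch
    import torch_xla  # registers the "xla" device

    from torch_xla.experimental.custom_kernel import ragged_paged_attention

    num_q_heads, num_kv_heads, head_dim = 8, 2, 128
    max_num_seqs, pages_per_seq, page_size, num_pages = 4, 64, 16, 1024
    max_num_batched_tokens = 512

    q = torch.randn(max_num_batched_tokens, num_q_heads, head_dim,
                    dtype=torch.bfloat16).to("xla")
    k_pages = torch.randn(num_pages, page_size, num_kv_heads, head_dim,
                          dtype=torch.bfloat16).to("xla")
    v_pages = torch.randn(num_pages, page_size, num_kv_heads, head_dim,
                          dtype=torch.bfloat16).to("xla")
    kv_lens = torch.tensor([128, 256, 0, 0], dtype=torch.int32).to("xla")
    page_indices = torch.randint(
        0, num_pages, (max_num_seqs, pages_per_seq), dtype=torch.int32).to("xla")
    cu_q_lens = torch.tensor([0, 4, 12, 12, 12], dtype=torch.int32).to("xla")

    # Only two of the four sequence slots are active this step. Passing the
    # count as an i32[1] tensor (instead of a Python int) keeps the traced
    # graph stable when the sequence count changes on the next step.
    num_seqs = torch.tensor([2], dtype=torch.int32).to("xla")

    out = ragged_paged_attention(
        q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens,
        num_seqs=num_seqs,
        num_kv_pages_per_block=16,   # illustrative block sizes
        num_queries_per_block=16,
        use_kernel=True)
    # Rows past the total number of query tokens (cu_q_lens[num_seqs]) are padding,
    # mirroring the [:cu_q_lens[num_seqs]] slicing done in the tests above.

As the patch notes, the per-sequence validation that reads `num_seqs`, `kv_lens`, or `cu_q_lens` values is commented out in the wrapper because those are now traced tensors; such checks have to happen at runtime (e.g. in the caller, on concrete values) rather than at trace time.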