@@ -900,11 +900,11 @@ def validate_ragged_paged_attention_inputs(
     kv_lens,  # i32[max_num_seqs]
     page_indices,  # i32[max_num_seqs, pages_per_seq]
     cu_q_lens,  # i32[max_num_seqs + 1]
-    num_seqs,  # i32
+    num_seqs,  # i32[1]
 ):
-  max_num_batched_tokens, num_q_heads, head_dim = q.shape
-  _, page_size, num_kv_heads, head_dim_k = k_pages.shape
-  max_num_seqs, pages_per_seq = page_indices.shape
+  _, num_q_heads, head_dim = q.shape
+  _, _, num_kv_heads, head_dim_k = k_pages.shape
+  max_num_seqs, _ = page_indices.shape
   if k_pages.shape != v_pages.shape:
     raise ValueError(
         f"{k_pages.shape=} and {v_pages.shape=} must have the same shape.")
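
This patch switches `num_seqs` from a host int to an i32[1] tensor and drops the shape unpacking the validator no longer needs at trace time. For orientation, here is a minimal sketch of inputs matching the shape annotations above; the concrete sizes are illustrative assumptions, not values from the patch:

```python
import torch

# Illustrative sizes (assumptions, not from the patch).
max_num_batched_tokens, max_num_seqs, pages_per_seq = 512, 8, 16
num_pages, page_size, num_q_heads, num_kv_heads, head_dim = 64, 16, 8, 2, 128

q = torch.randn(max_num_batched_tokens, num_q_heads, head_dim)
k_pages = torch.randn(num_pages, page_size, num_kv_heads, head_dim)
v_pages = torch.randn_like(k_pages)
kv_lens = torch.zeros(max_num_seqs, dtype=torch.int32)
page_indices = torch.zeros(max_num_seqs, pages_per_seq, dtype=torch.int32)
cu_q_lens = torch.zeros(max_num_seqs + 1, dtype=torch.int32)
num_seqs = torch.tensor([1], dtype=torch.int32)  # i32[1] after this patch; was a plain int
```
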
@@ -918,9 +918,6 @@ def validate_ragged_paged_attention_inputs(
     raise ValueError(
         f"Expected {cu_q_lens.shape=} to be ({max_num_seqs + 1},) where"
         " `max_num_seqs` is `page_indices.shape[0]`.")
-  if max_num_seqs > max_num_batched_tokens:
-    raise ValueError(
-        f"{max_num_seqs=} must be less or equal to {max_num_batched_tokens=}.")
   if (kv_lens.dtype != torch.int32 or page_indices.dtype != torch.int32 or
       cu_q_lens.dtype != torch.int32):
     raise ValueError(
@@ -931,24 +928,24 @@ def validate_ragged_paged_attention_inputs(
     raise ValueError(f"{num_q_heads=} must be divisible by {num_kv_heads=}")
 
   # Must check below on runtime!
-  if num_seqs > max_num_seqs:
-    raise ValueError(f"{num_seqs=} must be less or equal to {max_num_seqs=}")
-  max_kv_len = torch.max(kv_lens)
-  min_pages_per_seq = ceil_div(max_kv_len, page_size)
-  if pages_per_seq < min_pages_per_seq:
-    raise ValueError(
-        f"{pages_per_seq=} must be greater or equal to"
-        f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}.")
-  if cu_q_lens[num_seqs] > max_num_batched_tokens:
-    raise ValueError(
-        f"Total q tokens {cu_q_lens[num_seqs]} must be less or equal to"
-        f" {max_num_batched_tokens=}.")
-  for i in range(num_seqs):
-    q_len = cu_q_lens[i + 1] - cu_q_lens[i]
-    kv_len = kv_lens[i]
-    if q_len > kv_len:
-      raise ValueError(
-          f"{q_len=} must be less or equal to {kv_len=} at sequence {i}.")
+  # if num_seqs > max_num_seqs:
+  #   raise ValueError(f"{num_seqs=} must be less or equal to {max_num_seqs=}")
+  # max_kv_len = torch.max(kv_lens)
+  # min_pages_per_seq = ceil_div(max_kv_len, page_size)
+  # if pages_per_seq < min_pages_per_seq:
+  #   raise ValueError(
+  #       f"{pages_per_seq=} must be greater or equal to"
+  #       f" {min_pages_per_seq=} given {max_kv_len=} and {page_size=}.")
+  # if cu_q_lens[num_seqs] > max_num_batched_tokens:
+  #   raise ValueError(
+  #       f"Total q tokens {cu_q_lens[num_seqs]} must be less or equal to"
+  #       f" {max_num_batched_tokens=}.")
+  # for i in range(num_seqs):
+  #   q_len = cu_q_lens[i + 1] - cu_q_lens[i]
+  #   kv_len = kv_lens[i]
+  #   if q_len > kv_len:
+  #     raise ValueError(
+  #         f"{q_len=} must be less or equal to {kv_len=} at sequence {i}.")
 
 
 def _ragged_paged_attention_nonkernel(
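
The value-dependent checks are commented out rather than deleted: with `num_seqs`, `kv_lens`, and `cu_q_lens` now living on the device, evaluating them here would force blocking device-to-host reads on every call. If those invariants are needed while debugging, a hypothetical opt-in helper (my sketch, not part of the patch) could reproduce them with explicit syncs:

```python
def _debug_check_ragged_inputs(kv_lens, cu_q_lens, num_seqs, max_num_seqs):
  # Hypothetical debug-only helper: .item()/.cpu() block until the device
  # values are ready, which is exactly what the hot path wants to avoid.
  n = int(num_seqs.item())
  if n > max_num_seqs:
    raise ValueError(f"{n=} must be less or equal to {max_num_seqs=}")
  kv, cu = kv_lens.cpu(), cu_q_lens.cpu()
  for i in range(n):
    q_len = int(cu[i + 1] - cu[i])
    if q_len > int(kv[i]):
      raise ValueError(
          f"{q_len=} must be less or equal to kv_len={int(kv[i])} at sequence {i}.")
```
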
@@ -1001,7 +998,7 @@ def ragged_paged_attention(
     kv_lens,  # i32[max_num_seqs]
     page_indices,  # i32[max_num_seqs, pages_per_seq]
     cu_q_lens,  # i32[max_num_seqs + 1]
-    num_seqs,  # i32
+    num_seqs,  # i32[1]
     *,
     sm_scale=1.0,
     mask_value=None,
@@ -1022,7 +1019,7 @@ def ragged_paged_attention(
10221019 kv_lens ,
10231020 page_indices ,
10241021 cu_q_lens ,
1025- num_seqs ,
1022+ num_seqs . item () ,
10261023 sm_scale = sm_scale ,
10271024 mask_value = mask_value ,
10281025 )
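
On the non-kernel fallback, the reference implementation still loops in Python, so the i32[1] tensor is materialized with `.item()`. A small illustration of what that call does (and costs):

```python
import torch

num_seqs = torch.tensor([4], dtype=torch.int32)  # i32[1], as on the fallback path
n = num_seqs.item()  # blocking device-to-host read; returns the Python int 4
assert isinstance(n, int)
```
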
@@ -1054,17 +1051,14 @@ def ragged_paged_attention(
       ],
   )
 
-  num_q_blks = ceil_div(cu_q_lens[num_seqs], num_queries_per_block)
   seq_buf_idx = torch.tensor([0, 0], dtype=torch.int32).to("xla")
-  num_seqs_ref = torch.tensor([num_seqs], dtype=torch.int32).to("xla")
   output = torch_xla._XLAC._xla_tpu_custom_call(
       [
-          num_q_blks,
           kv_lens,
           page_indices,
           cu_q_lens,
           seq_buf_idx,
-          num_seqs_ref,
+          num_seqs,
           q,
           k_pages,
           v_pages,
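
Since `num_seqs` now arrives as an i32[1] tensor already on the XLA device, the wrapper no longer builds `num_seqs_ref` (or the host-derived `num_q_blks`) per call; the caller's tensor is forwarded straight to the TPU custom call. A hedged call-site sketch, reusing the illustrative names from the first example; in practice every input tensor must be on the XLA device:

```python
# Hypothetical call site (names assumed): create num_seqs once as an i32[1]
# device tensor so no per-step host->device copy happens inside the wrapper.
num_seqs = torch.tensor([1], dtype=torch.int32).to("xla")
out = ragged_paged_attention(
    q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens, num_seqs,
    sm_scale=1.0)
```
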
@@ -1733,7 +1727,7 @@ def multi_queries_paged_attention_non_xla(q: torch.Tensor,
 
 XLA_LIB.define(
     "ragged_paged_attention(Tensor q, Tensor k_pages, Tensor v_pages, Tensor kv_lens, Tensor page_indices, "
-    "Tensor cu_q_lens, int num_seqs, int num_kv_pages_per_block, int num_queries_per_block, bool use_kernel, "
+    "Tensor cu_q_lens, Tensor num_seqs, int num_kv_pages_per_block, int num_queries_per_block, bool use_kernel, "
     "float sm_scale=1.0, float? mask_value=None, int? vmem_limit_bytes=None) -> Tensor",
 )
 
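
With the schema change, the dispatcher type-checks `num_seqs` as a `Tensor`, so passing a Python int now fails schema validation. Assuming `XLA_LIB` registers these ops under the `xla` namespace, as the existing definitions in this file suggest, the call looks like:

```python
out = torch.ops.xla.ragged_paged_attention(
    q, k_pages, v_pages, kv_lens, page_indices, cu_q_lens,
    num_seqs,  # now torch.Tensor (i32[1]); an int here fails schema checking
    num_kv_pages_per_block, num_queries_per_block, use_kernel,
    sm_scale=1.0)
```
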
@@ -1746,7 +1740,7 @@ def ragged_paged_attention_xla(
     kv_lens: torch.Tensor,
     page_indices: torch.Tensor,
     cu_q_lens: torch.Tensor,
-    num_seqs: int,
+    num_seqs: torch.Tensor,
     num_kv_pages_per_block: int,
     num_queries_per_block: int,
     use_kernel: bool,
@@ -1777,7 +1771,7 @@ def ragged_paged_attention_non_xla(q: torch.Tensor,
                                    kv_lens: torch.Tensor,
                                    page_indices: torch.Tensor,
                                    cu_q_lens: torch.Tensor,
-                                   num_seqs: int,
+                                   num_seqs: torch.Tensor,
                                    num_kv_pages_per_block: int,
                                    num_queries_per_block: int,
                                    use_kernel: bool,
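
The non-XLA registration changes to match the schema as well. Because dynamo traces this path with fake tensors, the fallback should avoid reading the value of `num_seqs`; a minimal sketch under that assumption (my illustration, not the file's actual implementation):

```python
import torch

def ragged_paged_attention_non_xla_sketch(q, *args, **kwargs):
  # Sketch only: never call num_seqs.item() here, or tracing with fake
  # tensors breaks; the output shares q's [tokens, num_q_heads, head_dim] shape.
  return torch.empty_like(q)
```
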