@@ -87,8 +87,9 @@ def _pagedattention_generate_qkv(
     q = torch.randn(batch_size, query_len, num_heads, head_dim, dtype=dtype)
     return q, k_pages, v_pages, page_indices

-  def _round_up_closest_multiple_of(self, x, base):
-    return (x + base - 1) // base * base
+  def _ceil_div(self, a, b):
+    assert b != 0
+    return (a + b - 1) // b

   def _ragged_pagedattention_generate_qkv(
       self,
@@ -97,64 +98,50 @@ def _ragged_pagedattention_generate_qkv(
       head_dim,
       page_size,
       num_pages,
-      dtype=torch.float32,
-      num_queries_per_block=None,
-      pad_num_q_tokens=False,
+      dtype,
+      *,
+      num_kv_pages_per_block=None,
+      max_num_batched_tokens=None,
+      max_num_seqs=16,
   ):
-    num_seqs = len(seq_lens)
-    # Make sure the q_len is no longer than the kv_len. For example,
-    # seq_lens = [(1, 1328), (5, 18), (506, 463)] is not a valid test case because
-    # the 3rd sequence has q_len(506) > kv_len(463).
-    for i in range(num_seqs):
-      cur_q_len = seq_lens[i][0]
-      cur_kv_len = seq_lens[i][1]
-      assert cur_q_len <= cur_kv_len, f"cur_q_len must be less than or equal to cur_kv_len. Got {cur_q_len} and {cur_kv_len}"
-
-    query_lens = [seq_len[0] for seq_len in seq_lens]
-    actual_num_q_tokens = sum(query_lens)
-    num_q_tokens = self._round_up_closest_multiple_of(
-        actual_num_q_tokens,
-        num_queries_per_block) if pad_num_q_tokens else actual_num_q_tokens
-    kv_lens = torch.tensor([seq_len[1] for seq_len in seq_lens],
-                           dtype=torch.int32)
-    num_q_heads = num_heads[0]
-    num_kv_heads = num_heads[1]
-    assert num_q_heads % num_kv_heads == 0, "num_q_heads % num_kv_heads !=0."
-    queries = torch.randn((num_q_tokens, num_q_heads, head_dim), dtype=dtype)
-    k_pages = torch.randn((num_kv_heads, num_pages, page_size, head_dim),
+    cu_q_lens = [0]
+    kv_lens = []
+    for q_len, kv_len in seq_lens:
+      assert q_len <= kv_len
+      cu_q_lens.append(cu_q_lens[-1] + q_len)
+      kv_lens.append(kv_len)
+
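+    # Never let the padded sizes fall below the actual token and sequence counts.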
+    if max_num_batched_tokens is None:
+      max_num_batched_tokens = cu_q_lens[-1]
+    else:
+      max_num_batched_tokens = max(cu_q_lens[-1], max_num_batched_tokens)
+    if max_num_seqs is None:
+      max_num_seqs = len(seq_lens)
+    else:
+      max_num_seqs = max(len(seq_lens), max_num_seqs)
+    max_kv_len = max(kv_lens)
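+    # Round pages_per_seq up so it divides evenly by num_kv_pages_per_block.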
+    pages_per_seq = self._ceil_div(max_kv_len, page_size)
+    pages_per_seq = (
+        self._ceil_div(pages_per_seq, num_kv_pages_per_block) *
+        num_kv_pages_per_block)
+
+    num_q_heads, num_kv_heads = num_heads
+    cu_q_lens = torch.tensor(cu_q_lens, dtype=torch.int32)
+    kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
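+    # Pad cu_q_lens and kv_lens out to the fixed max_num_seqs length.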
+    cu_q_lens = torch.nn.functional.pad(
+        cu_q_lens, (0, max_num_seqs + 1 - cu_q_lens.shape[0]), "constant", 0)
+    kv_lens = torch.nn.functional.pad(kv_lens,
+                                      (0, max_num_seqs - kv_lens.shape[0]),
+                                      "constant", 0)
+    q = torch.randn((max_num_batched_tokens, num_q_heads, head_dim),
+                    dtype=dtype)
+    k_pages = torch.randn((num_pages, page_size, num_kv_heads, head_dim),
                           dtype=dtype)
-    v_pages = torch.randn((num_kv_heads, num_pages, page_size, head_dim),
+    v_pages = torch.randn((num_pages, page_size, num_kv_heads, head_dim),
                           dtype=dtype)
-
-    # Create a kv_lens: i32[num_tokens]
-    kv_lens_with_paddings = [0] * num_q_tokens
-    for i in range(num_seqs):
-      kv_lens_with_paddings[i] = kv_lens[i]
-    kv_lens_ = torch.tensor(kv_lens_with_paddings, dtype=torch.int32)
-
-    # Create a page_indices i32[num_tokens, pages_per_sequence]
-    max_kv_len = max([seq_len[1] for seq_len in seq_lens])
-    max_num_pages_per_seq = (max_kv_len + page_size - 1) // page_size
-
-    # The reason why we need to pad max_num_pages_per_seq is that
-    # page_indices[1]=max_num_pages_per_seq and max_num_pages_per_seq%num_kv_pages_per_compute_block==0
-    max_num_pages_per_seq = 2**int(np.ceil(np.log2(max_num_pages_per_seq)))
-
-    # The assert below mimics the reality that each page get a unique index.
-    # But for testing, the assert could be omitted.
-    # assert max_num_pages_per_seq*num_q_tokens <= num_pages, f"assert failed: max_num_pages_per_seq*num_q_tokens < num_pages. Got {max_num_pages_per_seq*num_q_tokens} and {num_pages}"
     page_indices = torch.randint(
-        0, num_pages, (num_q_tokens, max_num_pages_per_seq), dtype=torch.int32)
-
-    # Create a cu_q_lens i32[num_tokens + 1]
-    q_lens_with_paddings = [0] * num_q_tokens
-    for i in range(num_seqs):
-      q_lens_with_paddings[i] = query_lens[i]
-    cu_q_lens = torch.cumsum(
-        torch.tensor([0] + q_lens_with_paddings, dtype=torch.int32),
-        dim=0,
-        dtype=torch.int32)
-    return queries, k_pages, v_pages, page_indices, cu_q_lens, kv_lens_
+        0, num_pages, (max_num_seqs, pages_per_seq), dtype=torch.int32)
+    return q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens

   @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
   def test_tpu_custom_call_pallas_add(self):
@@ -648,7 +635,7 @@ def test_paged_attention_wrapper(self):
                    "This test only works on TPUv4+.")
   def test_ragged_paged_attention_wrapper_without_dynamo(self):
     from torch_xla.experimental.custom_kernel import ragged_paged_attention
-    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention as jax_ragged_paged_attention
+    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention

     seq_lens = [
         (1, 1328),
@@ -663,18 +650,25 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
         (1, 17),
         (99, 123)
     ]  # last 3 physical q blocks [(q_len, kv_len),...]
-    num_heads = (4, 4)
+    num_heads = (32, 8)
     head_dim = 128
     dtype = torch.float32
     page_size = 16
     num_pages = 32768
     num_seqs = len(seq_lens)
-    num_kv_pages_per_block = 128
+    num_kv_pages_per_block = 16
     num_queries_per_block = 8
-    block_kv_size = 256

     q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens = self._ragged_pagedattention_generate_qkv(
-        seq_lens, num_heads, head_dim, page_size, num_pages, dtype=dtype)
+        seq_lens,
+        num_heads,
+        head_dim,
+        page_size,
+        num_pages,
+        dtype,
+        num_kv_pages_per_block=num_kv_pages_per_block,
+        max_num_batched_tokens=1024,
+        max_num_seqs=16)

     q_xla = q.to("xla")
     k_pages_xla = k_pages.to("xla")
@@ -693,7 +687,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
         num_seqs=num_seqs,
         num_kv_pages_per_block=num_kv_pages_per_block,
         num_queries_per_block=num_queries_per_block,
-        use_kernel=True)
+        use_kernel=True)[:cu_q_lens[num_seqs]]
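+    # Only the first cu_q_lens[num_seqs] output rows correspond to real query tokens; the rest is padding.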

     nonkernel_output = ragged_paged_attention(
         q_xla,
@@ -726,7 +720,7 @@ def test_ragged_paged_attention_wrapper_without_dynamo(self):
                 num_seqs=num_seqs,
                 num_kv_pages_per_block=num_kv_pages_per_block,
                 num_queries_per_block=num_queries_per_block,
-            )[1]))
+            )[:cu_q_lens[num_seqs]]))

     self.assertTrue(
         torch.allclose(
@@ -745,19 +739,25 @@ def _verify_ragged_paged_attention_with_dynamo(
       dtype,
       num_kv_pages_per_block,
       num_queries_per_block,
-      pad_num_q_tokens=False,
+      pad_tokens_and_seqs=False,
       sm_scale=1.0,
   ):
     num_seqs = len(seq_lens)
+    max_num_batched_tokens = None
+    max_num_seqs = None
+    if pad_tokens_and_seqs:
+      max_num_batched_tokens = 1024
+      max_num_seqs = 16
     q, k_pages, v_pages, page_indices, cu_q_lens, kv_lens = self._ragged_pagedattention_generate_qkv(
         seq_lens,
         num_heads,
         head_dim,
         page_size,
         num_pages,
-        dtype=dtype,
-        num_queries_per_block=num_queries_per_block,
-        pad_num_q_tokens=pad_num_q_tokens)
+        dtype,
+        num_kv_pages_per_block=num_kv_pages_per_block,
+        max_num_batched_tokens=max_num_batched_tokens,
+        max_num_seqs=max_num_seqs)

     q_xla = q.to("xla")
     k_pages_xla = k_pages.to("xla")
@@ -766,29 +766,7 @@ def _verify_ragged_paged_attention_with_dynamo(
     page_indices_xla = page_indices.to("xla")
     cu_q_lens_xla = cu_q_lens.to("xla")

-    def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
-                                       page_indices, cu_q_lens, num_seqs,
-                                       num_kv_pages_per_block,
-                                       num_queries_per_block, use_kernel,
-                                       sm_scale):
-      return torch.ops.xla.ragged_paged_attention(
-          q,
-          k_pages,
-          v_pages,
-          kv_lens,
-          page_indices,
-          cu_q_lens,
-          num_seqs,
-          num_kv_pages_per_block,
-          num_queries_per_block,
-          use_kernel=use_kernel,
-          sm_scale=sm_scale,
-      )
-
-    compiled_paged_attention = torch.compile(
-        ragged_paged_attention_wrapper, backend="openxla")
-
-    kernel_output = compiled_paged_attention(
+    kernel_output = torch.ops.xla.ragged_paged_attention(
         q_xla,
         k_pages_xla,
         v_pages_xla,
@@ -800,9 +778,9 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
         num_queries_per_block=num_queries_per_block,
         use_kernel=True,
         sm_scale=sm_scale,
-    )
+    )[:cu_q_lens[num_seqs]]

-    nonkernel_output = compiled_paged_attention(
+    nonkernel_output = torch.ops.xla.ragged_paged_attention(
         q_xla,
         k_pages_xla,
         v_pages_xla,
@@ -828,7 +806,7 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
     page_indices_jax = jnp.array(page_indices.numpy(), dtype=jnp.int32)
     cu_q_lens_jax = jnp.array(cu_q_lens.numpy(), dtype=jnp.int32)

-    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_kernel import ragged_paged_attention as jax_ragged_paged_attention
+    from torch_xla.experimental.pallas_kernels.ragged_paged_attention_v2 import ragged_paged_attention as jax_ragged_paged_attention
     jax_kernel_output = torch.from_numpy(
         np.array(
             jax_ragged_paged_attention(
@@ -842,34 +820,19 @@ def ragged_paged_attention_wrapper(q, k_pages, v_pages, kv_lens,
                 num_kv_pages_per_block=num_kv_pages_per_block,
                 num_queries_per_block=num_queries_per_block,
                 sm_scale=sm_scale,
-            )[1]))
+            )[:cu_q_lens[num_seqs]]))
     jax_kernel_output_cpu = jax_kernel_output.cpu()

-    if pad_num_q_tokens:
-      actual_num_q_tokens = cu_q_lens[num_seqs]
-      self.assertTrue(
-          torch.allclose(
-              kernel_output_cpu[:actual_num_q_tokens],
-              nonkernel_output_cpu[:actual_num_q_tokens],
-              atol=2e-2,
-              rtol=1e-2))
-      self.assertTrue(
-          torch.allclose(
-              kernel_output_cpu[:actual_num_q_tokens],
-              jax_kernel_output_cpu[:actual_num_q_tokens],
-              atol=2e-2,
-              rtol=1e-2))
-    else:
-      self.assertTrue(
-          torch.allclose(
-              kernel_output_cpu, nonkernel_output_cpu, atol=2e-2, rtol=1e-2))
-      self.assertTrue(
-          torch.allclose(
-              kernel_output_cpu, jax_kernel_output_cpu, atol=2e-2, rtol=1e-2))
+    self.assertTrue(
+        torch.allclose(
+            kernel_output_cpu, nonkernel_output_cpu, atol=2e-2, rtol=1e-2))
+    self.assertTrue(
+        torch.allclose(
+            kernel_output_cpu, jax_kernel_output_cpu, atol=2e-2, rtol=1e-2))

   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
-  def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self):
+  def test_ragged_paged_attention_wrapper_no_padding_with_dynamo(self):
     seq_lens = [
         (1, 1328),
         (5, 18),
@@ -883,7 +846,7 @@ def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self):
         (1, 17),
         (99, 123)
     ]  # last 3 physical q blocks [(q_len, kv_len),...]
-    num_heads = (4, 4)
+    num_heads = (32, 8)
     head_dim = 128
     dtype = torch.float32
     page_size = 16
@@ -897,7 +860,7 @@ def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self):
         page_size,
         num_pages,
         dtype,
-        num_kv_pages_per_block=128,
+        num_kv_pages_per_block=16,
         num_queries_per_block=8,
         sm_scale=sm_scale,
     )
@@ -908,12 +871,12 @@ def test_ragged_paged_attention_wrapper_no_query_padding_with_dynamo(self):
   )
   @unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 4,
                    "This test only works on TPUv4+.")
-  def test_ragged_paged_attention_wrapper_with_query_padding_with_dynamo(
+  def test_ragged_paged_attention_wrapper_with_padding_with_dynamo(
       self,
       seq_lens,
       num_queries_per_block,
   ):
-    num_heads = (4, 4)
+    num_heads = (32, 8)
     head_dim = 128
     dtype = torch.float32
     page_size = 16
@@ -927,9 +890,9 @@ def test_ragged_paged_attention_wrapper_with_query_padding_with_dynamo(
         page_size,
         num_pages,
         dtype,
-        num_kv_pages_per_block=128,
+        num_kv_pages_per_block=16,
         num_queries_per_block=num_queries_per_block,
-        pad_num_q_tokens=True,
+        pad_tokens_and_seqs=True,
         sm_scale=sm_scale,
     )
