@@ -68,23 +68,27 @@ def run_bench(
6868 q_data_type = torch .bfloat16 ,
6969 kv_data_type = torch .bfloat16 ,
7070 )
71+ o = wrapper_old .run (q , kv_data )
7172 ms_old = do_bench (lambda : wrapper_old .run (q , kv_data ))
7273
73- if len (p_kv_lens ) > 0 :
74+ if len (p_kv_lens ) == 1 :
7475 q_d = q [: d_q_indptr [- 1 ]]
7576 kv_d = kv_data [: d_kv_indptr [- 1 ]].unbind (1 )
7677 q_p = q [d_q_indptr [- 1 ] :]
7778 k_p , v_p = kv_data [d_kv_indptr [- 1 ] :].unbind (1 )
7879 k_p , v_p = k_p .squeeze (1 ), v_p .squeeze (1 )
80+ kv_indices_d = torch .arange (
81+ 0 , d_kv_indptr [- 1 ], device = device , dtype = torch .int32
82+ )
7983
8084 last_page_len_d = (d_seq_lens_blocks - 1 ) % page_block_size + 1
8185 wrapper_pod = flashinfer .PODWithPagedKVCacheWrapper (
8286 workspace_buffer ,
8387 kv_layout = kv_layout ,
8488 )
8589 wrapper_pod .plan (
86- d_q_indptr .to (device ),
8790 d_kv_indptr .to (device ),
91+ kv_indices_d .to (device ),
8892 last_page_len = last_page_len_d ,
8993 num_qo_heads = num_qo_heads ,
9094 num_kv_heads = num_kv_heads ,
@@ -93,6 +97,19 @@ def run_bench(
9397 q_data_type = torch .bfloat16 ,
9498 kv_data_type = torch .bfloat16 ,
9599 )
100+ o_p , o_d = wrapper_pod .run (
101+ q_p ,
102+ k_p ,
103+ v_p ,
104+ q_d ,
105+ kv_data ,
106+ causal_p = causal ,
107+ )
108+ o_pod = torch .cat ([o_d , o_p ], dim = 0 )
109+ # Verify output matches
110+ torch .testing .assert_close (
111+ o , o_pod , rtol = 1e-3 , atol = 1e-3 , msg = "POD-Attention output mismatch!"
112+ )
96113 ms_pod = do_bench (
97114 lambda : wrapper_pod .run (
98115 q_p ,
@@ -106,7 +123,7 @@ def run_bench(
106123 )
107124
108125 print (f"Elapsed time (Batched Prefill): { ms_old :.2f} ms" )
109- if len (p_kv_lens ) > 0 :
126+ if len (p_kv_lens ) == 1 :
110127 print (f"Elapsed time (POD Attention): { ms_pod :.2f} ms" )
111128 total_bytes = (
112129 q .numel () * q .element_size () + kv_data .numel () * kv_data .element_size ()
@@ -116,7 +133,7 @@ def run_bench(
116133 bandwidth_old_gb_s = total_bytes / (ms_old * 1e-3 ) / (1024 ** 3 )
117134
118135 print (f"Memory bandwidth (Batched Prefill): { bandwidth_old_gb_s :.2f} GB/s" )
119- if len (p_kv_lens ) > 0 :
136+ if len (p_kv_lens ) == 1 :
120137 bandwidth_pod_gb_s = total_bytes / (ms_pod * 1e-3 ) / (1024 ** 3 )
121138 print (f"Memory bandwidth (POD Attention): { bandwidth_pod_gb_s :.2f} GB/s" )
122139
@@ -128,8 +145,8 @@ def run_bench(
128145 # Irregular sequence lengths for prefill and decode
129146 d_q_len_configs = [[1 ] * 122 , [1 ] * 128 , [1 ] * 242 , [1 ] * 256 ]
130147 d_kv_len_configs = [[600 ] * 122 , [10000 ] * 128 , [400 ] * 242 , [8192 ] * 256 ]
131- p_q_configs = [[17 ] * 8 , [], [17 ] * 16 , []]
132- p_kv_configs = [[10000 ] * 8 , [], [8192 ] * 16 , []]
148+ p_q_configs = [[17 ] * 1 , [10000 ], [17 ] * 1 , []]
149+ p_kv_configs = [[10000 ] * 1 , [10000 ], [8192 ] * 1 , []]
133150
134151 # construct random length testcases
135152 for _ in range (1 ):
0 commit comments