@@ -49,6 +49,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
4949BLOCK_SIZE = [16 ]
5050WINDOW_LEFT = [- 1 , 127 ]
5151SOFT_CAP = [None , 50.0 ]
52+ HAS_SINKS = [True , False ]
5253
5354NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
5455
@@ -63,6 +64,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
6364@pytest .mark .parametrize ("block_size" , BLOCK_SIZE )
6465@pytest .mark .parametrize ("window_left" , WINDOW_LEFT )
6566@pytest .mark .parametrize ("soft_cap" , SOFT_CAP )
67+ @pytest .mark .parametrize ("has_sinks" , HAS_SINKS )
6668@torch .inference_mode
6769def test_flashinfer_trtllm_decode_with_baseline (
6870 dtype : torch .dtype ,
@@ -77,9 +79,10 @@ def test_flashinfer_trtllm_decode_with_baseline(
7779 block_size : int ,
7880 window_left : int ,
7981 soft_cap : Optional [float ],
82+ has_sinks : bool ,
8083) -> None :
8184 torch .set_default_device ("cuda" )
82- current_platform .seed_everything (0 )
85+ current_platform .seed_everything (42 )
8386
8487 q_quant_dtype , kv_quant_dtype , o_quant_dtype = quant_dtypes
8588 q_quant_dtype = q_quant_dtype or dtype
@@ -101,7 +104,16 @@ def test_flashinfer_trtllm_decode_with_baseline(
101104 else :
102105 raise ValueError (f"Invalid kv_layout: { kv_layout } " )
103106
104- query = torch .randn (batch_size , num_qo_heads , head_size , dtype = dtype )
107+ # max_q_len = 1
108+ q_lens = torch .ones ((batch_size ,), dtype = torch .int32 )
109+ q_indptr = torch .cat (
110+ [
111+ torch .tensor ([0 ], dtype = torch .int32 ),
112+ torch .cumsum (q_lens , dim = 0 , dtype = torch .int32 ),
113+ ]
114+ )
115+
116+ query = torch .randn (torch .sum (q_lens ).item (), num_qo_heads , head_size , dtype = dtype )
105117 if q_quant_dtype == FP8_DTYPE :
106118 query , q_scale = to_float8 (query )
107119 ref_query = query .to (dtype ) * q_scale
@@ -112,7 +124,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
112124 kv_lens = torch .randint (1 , max_kv_len , (batch_size ,), dtype = torch .int32 )
113125 kv_lens [- 1 ] = max_kv_len
114126
115- seq_lens = kv_lens
127+ seq_lens = kv_lens + q_lens
116128 max_seq_len = torch .max (seq_lens ).item ()
117129
118130 kv_cache = torch .randn (kv_cache_shape , dtype = dtype )
@@ -148,27 +160,36 @@ def test_flashinfer_trtllm_decode_with_baseline(
148160 workspace_buffer = torch .zeros (128 * 1024 * 1024 , dtype = torch .int8 )
149161
150162 # Baseline Decode
151- wrapper = flashinfer .BatchDecodeWithPagedKVCacheWrapper (
152- workspace_buffer , kv_layout , use_tensor_cores = True
153- )
163+ if has_sinks :
164+ sinks = torch .rand (num_qo_heads , dtype = torch .float32 ) * 5
165+ wrapper = flashinfer .BatchAttentionWithAttentionSinkWrapper (
166+ float_workspace_buffer = workspace_buffer , kv_layout = kv_layout , backend = "fa2"
167+ )
168+ else :
169+ sinks = None
170+ wrapper = flashinfer .BatchPrefillWithPagedKVCacheWrapper (
171+ float_workspace_buffer = workspace_buffer , kv_layout = kv_layout , backend = "fa2"
172+ )
173+
154174 wrapper .plan (
155- kv_indptr ,
156- kv_indices ,
157- kv_last_page_lens ,
158- num_qo_heads ,
159- num_kv_heads ,
160- head_size ,
161- block_size ,
162- "NONE" ,
175+ qo_indptr = q_indptr ,
176+ paged_kv_indptr = kv_indptr ,
177+ paged_kv_indices = kv_indices ,
178+ paged_kv_last_page_len = kv_last_page_lens ,
179+ num_qo_heads = num_qo_heads ,
180+ num_kv_heads = num_kv_heads ,
181+ head_dim_qk = head_size ,
182+ page_size = block_size ,
183+ causal = True ,
163184 sm_scale = sm_scale ,
164- q_data_type = dtype ,
165- kv_data_type = dtype ,
166185 window_left = window_left ,
167186 logits_soft_cap = soft_cap ,
187+ q_data_type = dtype ,
188+ kv_data_type = dtype ,
168189 )
169-
170190 output = torch .empty (ref_query .shape , dtype = dtype )
171- wrapper .run (ref_query , ref_kv_cache , out = output )
191+ wrapper .run (ref_query , ref_kv_cache , sinks , sm_scale , out = output )
192+
172193 o_scale = 1.0
173194 o_sf_scale_float = None
174195 if o_quant_dtype == FP8_DTYPE :
@@ -202,6 +223,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
202223 bmm1_scale = q_scale * k_scale * sm_scale ,
203224 bmm2_scale = v_scale / o_scale ,
204225 window_left = window_left ,
226+ sinks = sinks ,
205227 o_sf_scale = o_sf_scale_float ,
206228 out = output_trtllm ,
207229 )
@@ -217,11 +239,13 @@ def test_flashinfer_trtllm_decode_with_baseline(
217239 output_trtllm = output_trtllm .reshape (- 1 , query .shape [1 ], query .shape [2 ])
218240
219241 if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE :
220- rtol , atol = 3e-1 , 1e0
242+ rtol , atol = 7e-2 , 9e-2
221243 elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE :
222- rtol , atol = 5e -2 , 7e -2
223- else :
244+ rtol , atol = 2e -2 , 4e -2
245+ elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype :
224246 rtol , atol = 1e-2 , 2e-2
247+ else :
248+ rtol , atol = 1e-2 , 1e-2
225249
226250 (
227251 torch .testing .assert_close (output , output_trtllm , atol = atol , rtol = rtol ),
@@ -239,6 +263,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
239263@pytest .mark .parametrize ("block_size" , BLOCK_SIZE )
240264@pytest .mark .parametrize ("window_left" , WINDOW_LEFT )
241265@pytest .mark .parametrize ("soft_cap" , [None ])
266+ @pytest .mark .parametrize ("has_sinks" , HAS_SINKS )
242267@torch .inference_mode
243268def test_flashinfer_trtllm_prefill_with_baseline (
244269 dtype : torch .dtype ,
@@ -253,9 +278,10 @@ def test_flashinfer_trtllm_prefill_with_baseline(
253278 block_size : int ,
254279 window_left : int ,
255280 soft_cap : Optional [float ],
281+ has_sinks : bool ,
256282) -> None :
257283 torch .set_default_device ("cuda" )
258- current_platform .seed_everything (0 )
284+ current_platform .seed_everything (42 )
259285
260286 q_quant_dtype , kv_quant_dtype , o_quant_dtype = quant_dtypes
261287 q_quant_dtype = q_quant_dtype or dtype
@@ -297,7 +323,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
297323 q_scale = 1.0
298324 ref_query = query
299325
300- kv_lens = torch .randint (0 , max_kv_len , (batch_size ,), dtype = torch .int32 )
326+ kv_lens = torch .randint (1 , max_kv_len , (batch_size ,), dtype = torch .int32 )
301327 kv_lens [- 1 ] = max_kv_len
302328
303329 seq_lens = kv_lens + q_lens
@@ -336,28 +362,36 @@ def test_flashinfer_trtllm_prefill_with_baseline(
336362 workspace_buffer = torch .zeros (128 * 1024 * 1024 , dtype = torch .int8 )
337363
338364 # Baseline Prefill
339- wrapper = flashinfer .BatchPrefillWithPagedKVCacheWrapper (
340- workspace_buffer , kv_layout
341- )
365+ if has_sinks :
366+ sinks = torch .rand (num_qo_heads , dtype = torch .float32 ) * 5
367+ wrapper = flashinfer .BatchAttentionWithAttentionSinkWrapper (
368+ float_workspace_buffer = workspace_buffer , kv_layout = kv_layout , backend = "fa2"
369+ )
370+ else :
371+ sinks = None
372+ wrapper = flashinfer .BatchPrefillWithPagedKVCacheWrapper (
373+ float_workspace_buffer = workspace_buffer , kv_layout = kv_layout , backend = "fa2"
374+ )
375+
342376 wrapper .plan (
343- q_indptr ,
344- kv_indptr ,
345- kv_indices ,
346- kv_last_page_lens ,
347- num_qo_heads ,
348- num_kv_heads ,
349- head_size ,
350- block_size ,
377+ qo_indptr = q_indptr ,
378+ paged_kv_indptr = kv_indptr ,
379+ paged_kv_indices = kv_indices ,
380+ paged_kv_last_page_len = kv_last_page_lens ,
381+ num_qo_heads = num_qo_heads ,
382+ num_kv_heads = num_kv_heads ,
383+ head_dim_qk = head_size ,
384+ page_size = block_size ,
351385 causal = True ,
352386 sm_scale = sm_scale ,
353- q_data_type = dtype ,
354- kv_data_type = dtype ,
355387 window_left = window_left ,
356388 logits_soft_cap = soft_cap ,
389+ q_data_type = dtype ,
390+ kv_data_type = dtype ,
357391 )
358-
359392 output = torch .empty (ref_query .shape , dtype = dtype )
360- wrapper .run (ref_query , ref_kv_cache , out = output )
393+ wrapper .run (ref_query , ref_kv_cache , sinks , sm_scale , out = output )
394+
361395 o_scale = 1.0
362396 o_sf_scale_float = None
363397 if o_quant_dtype == FP8_DTYPE :
@@ -395,6 +429,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
395429 cum_seq_lens_q = q_indptr ,
396430 cum_seq_lens_kv = kv_indptr ,
397431 window_left = window_left ,
432+ sinks = sinks ,
398433 o_sf_scale = o_sf_scale_float ,
399434 out = output_trtllm ,
400435 )
@@ -410,11 +445,11 @@ def test_flashinfer_trtllm_prefill_with_baseline(
410445 output_trtllm = output_trtllm .reshape (- 1 , query .shape [1 ], query .shape [2 ])
411446
412447 if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE :
413- rtol , atol = 4e -1 , 1e0
448+ rtol , atol = 1e -1 , 2e-1
414449 elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE :
415- rtol , atol = 5e-2 , 7e-2
416- elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype :
417450 rtol , atol = 4e-2 , 6e-2
451+ elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype :
452+ rtol , atol = 2e-2 , 3e-2
418453 else :
419454 rtol , atol = 1e-2 , 1e-2
420455
0 commit comments