@@ -6,9 +6,8 @@
 import pytest
 import torch
 
-from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX,
-                                                    FLOAT8_E4M3_MAX,
-                                                    dequantize_nvfp4_to_dtype)
+from tests.kernels.quantization.nvfp4_utils import (dequantize_nvfp4_to_dtype,
+                                                    get_nvfp4_global_scale)
 from vllm.platforms import current_platform
 from vllm.utils import round_up
 
@@ -47,6 +46,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
 BLOCK_SIZE = [16]
 WINDOW_LEFT = [-1, 127]
 SOFT_CAP = [None, 50.0]
+HAS_SINKS = [True, False]
 
 NUM_BLOCKS = 32768  # Large enough to test overflow in index calculation.
 
@@ -61,6 +61,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
 @pytest.mark.parametrize("block_size", BLOCK_SIZE)
 @pytest.mark.parametrize("window_left", WINDOW_LEFT)
 @pytest.mark.parametrize("soft_cap", SOFT_CAP)
+@pytest.mark.parametrize("has_sinks", HAS_SINKS)
 @torch.inference_mode
 def test_flashinfer_trtllm_decode_with_baseline(
     dtype: torch.dtype,
@@ -74,9 +75,10 @@ def test_flashinfer_trtllm_decode_with_baseline(
     block_size: int,
     window_left: int,
     soft_cap: Optional[float],
+    has_sinks: bool,
 ) -> None:
     torch.set_default_device("cuda")
-    current_platform.seed_everything(0)
+    current_platform.seed_everything(42)
 
     q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
     q_quant_dtype = q_quant_dtype or dtype
@@ -98,7 +100,17 @@ def test_flashinfer_trtllm_decode_with_baseline(
     else:
         raise ValueError(f"Invalid kv_layout: {kv_layout}")
 
-    query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
+    # max_q_len = 1
+    q_lens = torch.ones((batch_size, ), dtype=torch.int32)
+    q_indptr = torch.cat([
+        torch.tensor([0], dtype=torch.int32),
+        torch.cumsum(q_lens, dim=0, dtype=torch.int32),
+    ])
+
+    query = torch.randn(torch.sum(q_lens).item(),
+                        num_qo_heads,
+                        head_size,
+                        dtype=dtype)
     if q_quant_dtype == FP8_DTYPE:
         query, q_scale = to_float8(query)
         ref_query = query.to(dtype) * q_scale
@@ -109,7 +121,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
     kv_lens = torch.randint(1, max_kv_len, (batch_size, ), dtype=torch.int32)
     kv_lens[-1] = max_kv_len
 
-    seq_lens = kv_lens
+    seq_lens = kv_lens + q_lens
     max_seq_len = torch.max(seq_lens).item()
 
     kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
@@ -146,31 +158,42 @@ def test_flashinfer_trtllm_decode_with_baseline(
     workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
 
     # Baseline Decode
-    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
-        workspace_buffer, kv_layout, use_tensor_cores=True)
-    wrapper.plan(kv_indptr,
-                 kv_indices,
-                 kv_last_page_lens,
-                 num_qo_heads,
-                 num_kv_heads,
-                 head_size,
-                 block_size,
-                 "NONE",
+    if has_sinks:
+        sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5
+        wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper(
+            float_workspace_buffer=workspace_buffer,
+            kv_layout=kv_layout,
+            backend="fa2")
+    else:
+        sinks = None
+        wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+            float_workspace_buffer=workspace_buffer,
+            kv_layout=kv_layout,
+            backend="fa2")
+
+    wrapper.plan(qo_indptr=q_indptr,
+                 paged_kv_indptr=kv_indptr,
+                 paged_kv_indices=kv_indices,
+                 paged_kv_last_page_len=kv_last_page_lens,
+                 num_qo_heads=num_qo_heads,
+                 num_kv_heads=num_kv_heads,
+                 head_dim_qk=head_size,
+                 page_size=block_size,
+                 causal=True,
                  sm_scale=sm_scale,
-                 q_data_type=dtype,
-                 kv_data_type=dtype,
                  window_left=window_left,
-                 logits_soft_cap=soft_cap)
-
+                 logits_soft_cap=soft_cap,
+                 q_data_type=dtype,
+                 kv_data_type=dtype)
     output = torch.empty(ref_query.shape, dtype=dtype)
-    wrapper.run(ref_query, ref_kv_cache, out=output)
+    wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output)
+
     o_scale = 1.0
     o_sf_scale = None
     if o_quant_dtype == FP8_DTYPE:
         _, o_scale = to_float8(output)
     elif o_quant_dtype == FP4_DTYPE:
-        o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
-                      torch.amax(output.flatten(), dim=-1)).to(torch.float32)
+        o_sf_scale = get_nvfp4_global_scale(output)
 
     # TRTLLM Decode
     if o_quant_dtype == FP4_DTYPE:
@@ -194,6 +217,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
         bmm1_scale=q_scale * k_scale * sm_scale,
         bmm2_scale=v_scale / o_scale,
         window_left=window_left,
+        sinks=sinks,
         o_sf_scale=o_sf_scale,
         out=output_trtllm,
     )
@@ -210,11 +234,13 @@ def test_flashinfer_trtllm_decode_with_baseline(
                                                   query.shape[2])
 
     if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
-        rtol, atol = 3e-1, 1e0
+        rtol, atol = 7e-2, 9e-2
     elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
-        rtol, atol = 5e-2, 7e-2
-    else:
+        rtol, atol = 2e-2, 4e-2
+    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
         rtol, atol = 1e-2, 2e-2
+    else:
+        rtol, atol = 1e-2, 1e-2
 
     torch.testing.assert_close(output, output_trtllm, atol=atol, rtol=rtol), \
         f"{torch.max(torch.abs(output - output_trtllm))}"
@@ -230,6 +256,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
 @pytest.mark.parametrize("block_size", BLOCK_SIZE)
 @pytest.mark.parametrize("window_left", WINDOW_LEFT)
 @pytest.mark.parametrize("soft_cap", [None])
+@pytest.mark.parametrize("has_sinks", HAS_SINKS)
 @torch.inference_mode
 def test_flashinfer_trtllm_prefill_with_baseline(
     dtype: torch.dtype,
@@ -243,9 +270,10 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     block_size: int,
     window_left: int,
     soft_cap: Optional[float],
+    has_sinks: bool,
 ) -> None:
     torch.set_default_device("cuda")
-    current_platform.seed_everything(0)
+    current_platform.seed_everything(42)
 
     q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
     q_quant_dtype = q_quant_dtype or dtype
@@ -288,7 +316,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         q_scale = 1.0
         ref_query = query
 
-    kv_lens = torch.randint(0, max_kv_len, (batch_size, ), dtype=torch.int32)
+    kv_lens = torch.randint(1, max_kv_len, (batch_size, ), dtype=torch.int32)
     kv_lens[-1] = max_kv_len
 
     seq_lens = kv_lens + q_lens
@@ -328,32 +356,42 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
 
     # Baseline Prefill
-    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
-        workspace_buffer, kv_layout)
-    wrapper.plan(q_indptr,
-                 kv_indptr,
-                 kv_indices,
-                 kv_last_page_lens,
-                 num_qo_heads,
-                 num_kv_heads,
-                 head_size,
-                 block_size,
+    if has_sinks:
+        sinks = torch.rand(num_qo_heads, dtype=torch.float32) * 5
+        wrapper = flashinfer.BatchAttentionWithAttentionSinkWrapper(
+            float_workspace_buffer=workspace_buffer,
+            kv_layout=kv_layout,
+            backend="fa2")
+    else:
+        sinks = None
+        wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
+            float_workspace_buffer=workspace_buffer,
+            kv_layout=kv_layout,
+            backend="fa2")
+
+    wrapper.plan(qo_indptr=q_indptr,
+                 paged_kv_indptr=kv_indptr,
+                 paged_kv_indices=kv_indices,
+                 paged_kv_last_page_len=kv_last_page_lens,
+                 num_qo_heads=num_qo_heads,
+                 num_kv_heads=num_kv_heads,
+                 head_dim_qk=head_size,
+                 page_size=block_size,
                  causal=True,
                  sm_scale=sm_scale,
-                 q_data_type=dtype,
-                 kv_data_type=dtype,
                  window_left=window_left,
-                 logits_soft_cap=soft_cap)
-
+                 logits_soft_cap=soft_cap,
+                 q_data_type=dtype,
+                 kv_data_type=dtype)
     output = torch.empty(ref_query.shape, dtype=dtype)
-    wrapper.run(ref_query, ref_kv_cache, out=output)
+    wrapper.run(ref_query, ref_kv_cache, sinks, sm_scale, out=output)
+
     o_scale = 1.0
     o_sf_scale = None
     if o_quant_dtype == FP8_DTYPE:
         _, o_scale = to_float8(output)
     elif o_quant_dtype == FP4_DTYPE:
-        o_sf_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) /
-                      torch.amax(output.flatten(), dim=-1)).to(torch.float32)
+        o_sf_scale = get_nvfp4_global_scale(output)
 
     # TRTLLM Prefill
     if o_quant_dtype == FP4_DTYPE:
@@ -381,6 +419,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
         cum_seq_lens_q=q_indptr,
         cum_seq_lens_kv=kv_indptr,
         window_left=window_left,
+        sinks=sinks,
         o_sf_scale=o_sf_scale,
         out=output_trtllm,
     )
@@ -397,11 +436,11 @@ def test_flashinfer_trtllm_prefill_with_baseline(
                                                   query.shape[2])
 
     if q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP4_DTYPE:
-        rtol, atol = 4e-1, 1e0
+        rtol, atol = 1e-1, 2e-1
     elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
-        rtol, atol = 5e-2, 7e-2
-    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
         rtol, atol = 4e-2, 6e-2
+    elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
+        rtol, atol = 2e-2, 3e-2
     else:
         rtol, atol = 1e-2, 1e-2
 
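For context on what the new `has_sinks` path exercises: an attention sink is a learned per-head logit that joins the softmax normalization without contributing a value, which damps attention to weakly matching tokens. Below is a minimal single-query sketch of that reference math; `attention_with_sinks` is a hypothetical helper for illustration, not part of this diff or of the flashinfer API.

import torch

def attention_with_sinks(q, k, v, sinks, sm_scale):
    """Single-query attention with per-head sink logits (sketch).

    q: [heads, dim], k/v: [heads, seq, dim], sinks: [heads]
    """
    logits = torch.einsum("hd,hsd->hs", q, k) * sm_scale
    # The sink enters the softmax denominator as one extra logit per head...
    logits = torch.cat([logits, sinks.unsqueeze(-1)], dim=-1)
    # ...but its probability mass maps to no value, so it is dropped here.
    probs = torch.softmax(logits, dim=-1)[..., :-1]
    return torch.einsum("hs,hsd->hd", probs, v)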