@@ -35,6 +35,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
3535 # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
3636 (None , None , None ),
3737 (None , FP8_DTYPE , None ),
38+ (FP8_DTYPE , FP8_DTYPE , None ),
3839 (FP8_DTYPE , FP8_DTYPE , FP8_DTYPE ),
3940 (FP8_DTYPE , FP8_DTYPE , FP4_DTYPE ),
4041]
@@ -44,6 +45,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
4445HEAD_SIZE = [128 ]
4546KV_LAYOUT = ["HND" ] # currently only HND is supported
4647BLOCK_SIZE = [16 ]
48+ WINDOW_LEFT = [- 1 , 127 ]
4749SOFT_CAP = [None , 50.0 ]
4850
4951NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
@@ -57,6 +59,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
5759@pytest .mark .parametrize ("head_size" , HEAD_SIZE )
5860@pytest .mark .parametrize ("kv_layout" , KV_LAYOUT )
5961@pytest .mark .parametrize ("block_size" , BLOCK_SIZE )
62+ @pytest .mark .parametrize ("window_left" , WINDOW_LEFT )
6063@pytest .mark .parametrize ("soft_cap" , SOFT_CAP )
6164@torch .inference_mode
6265def test_flashinfer_trtllm_decode_with_baseline (
@@ -69,6 +72,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
6972 head_size : int ,
7073 kv_layout : str ,
7174 block_size : int ,
75+ window_left : int ,
7276 soft_cap : Optional [float ],
7377) -> None :
7478 torch .set_default_device ("cuda" )
@@ -155,6 +159,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
155159 sm_scale = sm_scale ,
156160 q_data_type = dtype ,
157161 kv_data_type = dtype ,
162+ window_left = window_left ,
158163 logits_soft_cap = soft_cap )
159164
160165 output = torch .empty (ref_query .shape , dtype = dtype )
@@ -188,6 +193,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
188193 max_seq_len = max_seq_len ,
189194 bmm1_scale = q_scale * k_scale * sm_scale ,
190195 bmm2_scale = v_scale / o_scale ,
196+ window_left = window_left ,
191197 o_sf_scale = o_sf_scale ,
192198 out = output_trtllm ,
193199 )
@@ -222,6 +228,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
222228@pytest .mark .parametrize ("head_size" , HEAD_SIZE )
223229@pytest .mark .parametrize ("kv_layout" , KV_LAYOUT )
224230@pytest .mark .parametrize ("block_size" , BLOCK_SIZE )
231+ @pytest .mark .parametrize ("window_left" , WINDOW_LEFT )
225232@pytest .mark .parametrize ("soft_cap" , [None ])
226233@torch .inference_mode
227234def test_flashinfer_trtllm_prefill_with_baseline (
@@ -234,6 +241,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
234241 head_size : int ,
235242 kv_layout : str ,
236243 block_size : int ,
244+ window_left : int ,
237245 soft_cap : Optional [float ],
238246) -> None :
239247 torch .set_default_device ("cuda" )
@@ -334,6 +342,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
334342 sm_scale = sm_scale ,
335343 q_data_type = dtype ,
336344 kv_data_type = dtype ,
345+ window_left = window_left ,
337346 logits_soft_cap = soft_cap )
338347
339348 output = torch .empty (ref_query .shape , dtype = dtype )
@@ -371,6 +380,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
371380 batch_size = batch_size ,
372381 cum_seq_lens_q = q_indptr ,
373382 cum_seq_lens_kv = kv_indptr ,
383+ window_left = window_left ,
374384 o_sf_scale = o_sf_scale ,
375385 out = output_trtllm ,
376386 )
@@ -390,6 +400,8 @@ def test_flashinfer_trtllm_prefill_with_baseline(
390400 rtol , atol = 4e-1 , 1e0
391401 elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE :
392402 rtol , atol = 5e-2 , 7e-2
403+ elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype :
404+ rtol , atol = 4e-2 , 6e-2
393405 else :
394406 rtol , atol = 1e-2 , 1e-2
395407
0 commit comments