@@ -35,6 +35,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
3535    # (q_quant_dtype, kv_quant_dtype, o_quant_dtype) 
3636    (None , None , None ),
3737    (None , FP8_DTYPE , None ),
38+     (FP8_DTYPE , FP8_DTYPE , None ),
3839    (FP8_DTYPE , FP8_DTYPE , FP8_DTYPE ),
3940    (FP8_DTYPE , FP8_DTYPE , FP4_DTYPE ),
4041]
@@ -44,6 +45,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
4445HEAD_SIZE  =  [128 ]
4546KV_LAYOUT  =  ["HND" ]  # currently only HND is supported 
4647BLOCK_SIZE  =  [16 ]
48+ WINDOW_LEFT  =  [- 1 , 127 ]  # "window_left" values for both the decode and prefill tests; -1 presumably disables the sliding window (TODO confirm)
4749SOFT_CAP  =  [None , 50.0 ]  # "soft_cap" values for the decode test; the prefill test parametrizes soft_cap with [None] instead
4850
4951NUM_BLOCKS  =  32768   # Large enough to test overflow in index calculation. 
@@ -57,6 +59,7 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
5759@pytest .mark .parametrize ("head_size" , HEAD_SIZE ) 
5860@pytest .mark .parametrize ("kv_layout" , KV_LAYOUT ) 
5961@pytest .mark .parametrize ("block_size" , BLOCK_SIZE ) 
62+ @pytest .mark .parametrize ("window_left" , WINDOW_LEFT ) 
6063@pytest .mark .parametrize ("soft_cap" , SOFT_CAP ) 
6164@torch .inference_mode  
6265def  test_flashinfer_trtllm_decode_with_baseline (
@@ -69,6 +72,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
6972    head_size : int ,
7073    kv_layout : str ,
7174    block_size : int ,
75+     window_left : int ,
7276    soft_cap : Optional [float ],
7377) ->  None :
7478    torch .set_default_device ("cuda" )
@@ -155,6 +159,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
155159                 sm_scale = sm_scale ,
156160                 q_data_type = dtype ,
157161                 kv_data_type = dtype ,
162+                  window_left = window_left ,
158163                 logits_soft_cap = soft_cap )
159164
160165    output  =  torch .empty (ref_query .shape , dtype = dtype )
@@ -188,6 +193,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
188193        max_seq_len = max_seq_len ,
189194        bmm1_scale = q_scale  *  k_scale  *  sm_scale ,
190195        bmm2_scale = v_scale  /  o_scale ,
196+         window_left = window_left ,
191197        o_sf_scale = o_sf_scale ,
192198        out = output_trtllm ,
193199    )
@@ -222,6 +228,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
222228@pytest .mark .parametrize ("head_size" , HEAD_SIZE ) 
223229@pytest .mark .parametrize ("kv_layout" , KV_LAYOUT ) 
224230@pytest .mark .parametrize ("block_size" , BLOCK_SIZE ) 
231+ @pytest .mark .parametrize ("window_left" , WINDOW_LEFT ) 
225232@pytest .mark .parametrize ("soft_cap" , [None ]) 
226233@torch .inference_mode  
227234def  test_flashinfer_trtllm_prefill_with_baseline (
@@ -234,6 +241,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
234241    head_size : int ,
235242    kv_layout : str ,
236243    block_size : int ,
244+     window_left : int ,
237245    soft_cap : Optional [float ],
238246) ->  None :
239247    torch .set_default_device ("cuda" )
@@ -334,6 +342,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
334342                 sm_scale = sm_scale ,
335343                 q_data_type = dtype ,
336344                 kv_data_type = dtype ,
345+                  window_left = window_left ,
337346                 logits_soft_cap = soft_cap )
338347
339348    output  =  torch .empty (ref_query .shape , dtype = dtype )
@@ -371,6 +380,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
371380        batch_size = batch_size ,
372381        cum_seq_lens_q = q_indptr ,
373382        cum_seq_lens_kv = kv_indptr ,
383+         window_left = window_left ,
374384        o_sf_scale = o_sf_scale ,
375385        out = output_trtllm ,
376386    )
@@ -390,6 +400,8 @@ def test_flashinfer_trtllm_prefill_with_baseline(
390400        rtol , atol  =  4e-1 , 1e0 
391401    elif  q_quant_dtype  ==  FP8_DTYPE  and  o_quant_dtype  ==  FP8_DTYPE :
392402        rtol , atol  =  5e-2 , 7e-2 
403+     elif  q_quant_dtype  ==  FP8_DTYPE  and  o_quant_dtype  ==  dtype :
404+         rtol , atol  =  4e-2 , 6e-2 
393405    else :
394406        rtol , atol  =  1e-2 , 1e-2 
395407
0 commit comments