Commit aa79a76

fix flashinfer unit test
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
1 parent abd9c46 commit aa79a76

2 files changed: 14 additions & 15 deletions

tests/kernels/attention/test_flashinfer_trtllm_attention.py

Lines changed: 9 additions & 12 deletions

@@ -7,9 +7,8 @@
 import torch

 from tests.kernels.quantization.nvfp4_utils import (
-    FLOAT4_E2M1_MAX,
-    FLOAT8_E4M3_MAX,
     dequantize_nvfp4_to_dtype,
+    get_nvfp4_global_scale,
 )
 from vllm.platforms import current_platform
 from vllm.utils import round_up
@@ -171,13 +170,12 @@ def test_flashinfer_trtllm_decode_with_baseline(
     output = torch.empty(ref_query.shape, dtype=dtype)
     wrapper.run(ref_query, ref_kv_cache, out=output)
     o_scale = 1.0
-    o_sf_scale = None
+    o_sf_scale_float = None
     if o_quant_dtype == FP8_DTYPE:
         _, o_scale = to_float8(output)
     elif o_quant_dtype == FP4_DTYPE:
-        o_sf_scale = (
-            (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1)
-        ).to(torch.float32)
+        o_sf_scale = get_nvfp4_global_scale(output)
+        o_sf_scale_float = o_sf_scale.item()

     # TRTLLM Decode
     if o_quant_dtype == FP4_DTYPE:
@@ -204,7 +202,7 @@
         bmm1_scale=q_scale * k_scale * sm_scale,
         bmm2_scale=v_scale / o_scale,
         window_left=window_left,
-        o_sf_scale=o_sf_scale,
+        o_sf_scale=o_sf_scale_float,
         out=output_trtllm,
     )
     if o_quant_dtype == FP8_DTYPE:
@@ -361,13 +359,12 @@ def test_flashinfer_trtllm_prefill_with_baseline(
     output = torch.empty(ref_query.shape, dtype=dtype)
     wrapper.run(ref_query, ref_kv_cache, out=output)
     o_scale = 1.0
-    o_sf_scale = None
+    o_sf_scale_float = None
     if o_quant_dtype == FP8_DTYPE:
         _, o_scale = to_float8(output)
     elif o_quant_dtype == FP4_DTYPE:
-        o_sf_scale = (
-            (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(output.flatten(), dim=-1)
-        ).to(torch.float32)
+        o_sf_scale = get_nvfp4_global_scale(output)
+        o_sf_scale_float = o_sf_scale.item()

     # TRTLLM Prefill
     if o_quant_dtype == FP4_DTYPE:
@@ -398,7 +395,7 @@
         cum_seq_lens_q=q_indptr,
         cum_seq_lens_kv=kv_indptr,
         window_left=window_left,
-        o_sf_scale=o_sf_scale,
+        o_sf_scale=o_sf_scale_float,
         out=output_trtllm,
     )
     if o_quant_dtype == FP8_DTYPE:
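What the test fix amounts to: in the FP4 output path, the inline scale computation is replaced by the shared get_nvfp4_global_scale helper, and the kernel argument o_sf_scale now receives a host-side Python float (o_sf_scale.item()) rather than a 0-dim tensor; the tensor form is kept around, presumably for the dequantization reference path. A minimal sketch of the new pattern, assuming the standard E4M3/E2M1 format maxima (in the repo these constants live in nvfp4_utils) and a stand-in output tensor:

import torch

# Assumed format maxima; nvfp4_utils derives these from the dtypes.
FLOAT8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value
FLOAT4_E2M1_MAX = 6.0    # largest finite FP4 E2M1 value

def get_nvfp4_global_scale(a: torch.Tensor) -> torch.Tensor:
    # Maps the tensor's largest magnitude onto the combined
    # FP8-block-scale x FP4-value range.
    return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(torch.float32)

output = torch.randn(4, 8, 128)  # stand-in for the reference attention output

# As in the updated test: compute the scale once, keep the tensor form,
# and hand the kernel a plain Python float.
o_sf_scale = get_nvfp4_global_scale(output)
o_sf_scale_float = o_sf_scale.item()

assert isinstance(o_sf_scale_float, float)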

tests/kernels/quantization/nvfp4_utils.py

Lines changed: 5 additions & 3 deletions

@@ -66,9 +66,11 @@ def break_fp4_bytes(a, dtype):
     return values.reshape(m, n * 2).to(dtype=dtype)


+def get_nvfp4_global_scale(a: torch.Tensor):
+    return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(torch.float32)
+
+
 def quant_nvfp4_tensor(a: torch.Tensor):
-    a_global_scale = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(
-        torch.float32
-    )
+    a_global_scale = get_nvfp4_global_scale(a)
     a_quant, a_block_scale = scaled_fp4_quant(a, a_global_scale)
     return a_quant, a_block_scale, a_global_scale
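Beyond deduplication, note that the extracted helper uses the absolute max where the old test code used a signed torch.amax; the two disagree whenever the largest-magnitude element is negative, so the test's scale now matches quant_nvfp4_tensor's in that case too. A small sketch of the difference, again assuming the standard format maxima:

import torch

FLOAT8_E4M3_MAX = 448.0  # assumed E4M3 max
FLOAT4_E2M1_MAX = 6.0    # assumed E2M1 max

def get_nvfp4_global_scale(a: torch.Tensor) -> torch.Tensor:
    return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.abs(a).max().to(torch.float32)

# The largest-magnitude element is negative, so signed and absolute max differ.
t = torch.tensor([-8.0, 2.0, 4.0])

old_style = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(t.flatten(), dim=-1)
new_style = get_nvfp4_global_scale(t)

print(old_style)  # tensor(672.)  -> 2688 / 4, signed max ignores -8.0
print(new_style)  # tensor(336.)  -> 2688 / 8, matches quant_nvfp4_tensor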
