Skip to content

Commit 3298b91

Browse files
committed
Fix padded FP4 scaling, and enable its usage in flashinfer_scaled_fp4_mm
Signed-off-by: Roi Koren <roik@nvidia.com>
1 parent 7c2ec0f commit 3298b91

File tree

2 files changed

+2
-4
lines changed

2 files changed

+2
-4
lines changed

vllm/_custom_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1374,7 +1374,7 @@ def scaled_fp4_quant(
13741374
)
13751375

13761376
# Two fp4 values will be packed into an uint8.
1377-
output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
1377+
output = torch.zeros((m, n // 2), device=device, dtype=torch.uint8)
13781378

13791379
# We use the rounded values to store the swizzled values. Due to the
13801380
# requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
@@ -1385,7 +1385,7 @@ def scaled_fp4_quant(
13851385
rounded_m = round_up(m, 128)
13861386
scale_n = n // block_size
13871387
rounded_n = round_up(scale_n, 4)
1388-
output_scale = torch.empty(
1388+
output_scale = torch.zeros(
13891389
(rounded_m, rounded_n // 4), device=device, dtype=torch.int32
13901390
)
13911391

vllm/utils/flashinfer.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,8 +384,6 @@ def flashinfer_scaled_fp4_mm(
384384
assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
385385
assert a.stride(-1) == 1 and b.stride(-1) == 1
386386
assert a.shape[1] == b.shape[1]
387-
assert block_scale_a.shape[1] == a.shape[1] // 8
388-
assert block_scale_b.shape[1] == b.shape[1] // 8
389387

390388
if backend == "cutlass":
391389
block_scale_a = block_scale_a.view(torch.uint8)

0 commit comments

Comments (0)