
Commit f733f88

Fix padded FP4 scaling, and enable its usage in flashinfer_scaled_fp4_mm

Signed-off-by: Roi Koren <roik@nvidia.com>
1 parent 1ad3aca

2 files changed (+2, -4 lines)


vllm/_custom_ops.py

Lines changed: 2 additions & 2 deletions
@@ -1128,7 +1128,7 @@ def scaled_fp4_quant(
         f'input.dtype needs to be fp16 or bf16 but got {input.dtype}.')

     # Two fp4 values will be packed into an uint8.
-    output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
+    output = torch.zeros((m, n // 2), device=device, dtype=torch.uint8)

     # We use the rounded values to store the swizzled values. Due to the
     # requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
@@ -1139,7 +1139,7 @@ def scaled_fp4_quant(
     rounded_m = round_up(m, 128)
     scale_n = n // block_size
     rounded_n = round_up(scale_n, 4)
-    output_scale = torch.empty((rounded_m, rounded_n // 4),
+    output_scale = torch.zeros((rounded_m, rounded_n // 4),
                                device=device,
                                dtype=torch.int32)
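Why zeros instead of empty here: the scale tensor is allocated at the padded swizzled size (rounded_m x rounded_n), but quantization only fills the real m x scale_n region, so with torch.empty the padding holds arbitrary memory that downstream kernels may read. Below is a minimal sketch of the padding arithmetic, using hypothetical m, n, and block_size values chosen purely for illustration:

import torch

def round_up(x: int, y: int) -> int:
    # Smallest multiple of y that is >= x (same helper vLLM uses).
    return (x + y - 1) // y * y

# Hypothetical shapes: neither m nor the scale width lands on the
# 128x4 tile required by the Tensor Core swizzled-scale layout.
m, n, block_size = 100, 96, 16

scale_n = n // block_size         # 6 real scale columns
rounded_m = round_up(m, 128)      # 128 rows allocated
rounded_n = round_up(scale_n, 4)  # 8 columns allocated

# torch.empty would leave the (rounded_m - m) extra rows and
# (rounded_n - scale_n) extra columns uninitialized; torch.zeros
# guarantees the padding reads back as 0.
output_scale = torch.zeros((rounded_m, rounded_n // 4), dtype=torch.int32)
print(output_scale.shape)  # torch.Size([128, 2])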

vllm/utils/flashinfer.py

Lines changed: 0 additions & 2 deletions
@@ -368,8 +368,6 @@ def flashinfer_scaled_fp4_mm(a: torch.Tensor, b: torch.Tensor,
     assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
     assert a.stride(-1) == 1 and b.stride(-1) == 1
     assert a.shape[1] == b.shape[1]
-    assert block_scale_a.shape[1] == a.shape[1] // 8
-    assert block_scale_b.shape[1] == b.shape[1] // 8

     if backend == "cutlass":
         block_scale_a = block_scale_a.view(torch.uint8)
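The dropped asserts assumed exactly a.shape[1] // 8 scale columns, which no longer holds once the scales carry padding from scaled_fp4_quant. A short sketch of that arithmetic, assuming the packed-uint8 FP4 layout and block_size of 16 used in scaled_fp4_quant above, with a hypothetical width chosen for illustration:

packed_k = 100             # hypothetical a.shape[1]: uint8 columns, 2 FP4 each
unpacked_k = 2 * packed_k  # 200 FP4 elements per row
block_size = 16            # one FP8 scale per 16 FP4 values

scale_cols = unpacked_k // block_size    # 12 == packed_k // 8
padded_cols = (scale_cols + 3) // 4 * 4  # 16 after rounding up to the 4-wide tile

# The removed asserts demanded block_scale_a.shape[1] == packed_k // 8
# exactly, so padded scales (16 columns here, not 12) would be rejected
# even though they are valid; hence the strict check is dropped.
print(scale_cols, padded_cols)  # 12 16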
