
Commit 4069db3

[Bugfix] Enable padded FP4 quantization (#25947)
Signed-off-by: Roi Koren <roik@nvidia.com>
1 parent 0d37450 commit 4069db3

File tree: 2 files changed (+1, -3 lines)


vllm/_custom_ops.py

Lines changed: 1 addition & 1 deletion
@@ -1384,7 +1384,7 @@ def scaled_fp4_quant(
     rounded_m = round_up(m, 128)
     scale_n = n // block_size
     rounded_n = round_up(scale_n, 4)
-    output_scale = torch.empty(
+    output_scale = torch.zeros(
         (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
     )
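The one-line change above swaps torch.empty for torch.zeros when allocating the padded scale buffer. A minimal sketch of the shape arithmetic in plain PyTorch (not the vLLM kernel itself): the 128/4 rounding constants come from the diff, while the m, n, and block_size values below are made-up examples. The reading here, suggested by the shapes, is that the quantization kernel only writes the first m rows and scale_n scale columns, so any padded tail must be deterministically zero rather than uninitialized memory.

import torch

def round_up(x: int, multiple: int) -> int:
    return (x + multiple - 1) // multiple * multiple

m, n, block_size = 100, 256, 16      # hypothetical input shape
scale_n = n // block_size            # one FP8 scale per 16-element block -> 16
rounded_m = round_up(m, 128)         # row padding, as in the diff -> 128
rounded_n = round_up(scale_n, 4)     # scale-column padding, as in the diff -> 16

# torch.zeros (the fix) guarantees rows m..rounded_m and any padded scale
# columns read back as zero; torch.empty left them as arbitrary memory.
# The // 4 packs four scale bytes into each int32 element.
output_scale = torch.zeros((rounded_m, rounded_n // 4), dtype=torch.int32)
print(output_scale.shape)            # torch.Size([128, 4])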

vllm/utils/flashinfer.py

Lines changed: 0 additions & 2 deletions
@@ -386,8 +386,6 @@ def flashinfer_scaled_fp4_mm(
     assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2
     assert a.stride(-1) == 1 and b.stride(-1) == 1
     assert a.shape[1] == b.shape[1]
-    assert block_scale_a.shape[1] == a.shape[1] // 8
-    assert block_scale_b.shape[1] == b.shape[1] // 8

     if backend == "cutlass":
         block_scale_a = block_scale_a.view(torch.uint8)
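The two removed asserts pinned the scale width to exactly a.shape[1] // 8, which only holds for unpadded scales: a holds packed FP4 (two values per byte), so a.shape[1] // 8 equals n // 16, the unpadded number of per-block scales. A rough sketch of this reading, with a hypothetical n that is not a multiple of 64, showing why the strict equality breaks once scaled_fp4_quant pads the scale columns to a multiple of 4:

def round_up(x: int, multiple: int) -> int:
    return (x + multiple - 1) // multiple * multiple

n = 80                                               # hypothetical K dimension
a_cols = n // 2                                      # packed FP4 columns -> 40
unpadded_scale_cols = n // 16                        # per-block scales -> 5
padded_scale_cols = round_up(unpadded_scale_cols, 4) # after padding -> 8

assert a_cols // 8 == unpadded_scale_cols   # the old check's assumption
assert padded_scale_cols != a_cols // 8     # fails once padding kicks in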
