diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 9fa346cca56d..1e604db31fa7 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1385,7 +1385,7 @@ def scaled_fp4_quant( rounded_m = round_up(m, 128) scale_n = n // block_size rounded_n = round_up(scale_n, 4) - output_scale = torch.empty( + output_scale = torch.zeros( (rounded_m, rounded_n // 4), device=device, dtype=torch.int32 ) diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 159d19bfad31..78d4a92dc1af 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -386,8 +386,6 @@ def flashinfer_scaled_fp4_mm( assert block_scale_a.ndim == 2 and block_scale_b.ndim == 2 assert a.stride(-1) == 1 and b.stride(-1) == 1 assert a.shape[1] == b.shape[1] - assert block_scale_a.shape[1] == a.shape[1] // 8 - assert block_scale_b.shape[1] == b.shape[1] // 8 if backend == "cutlass": block_scale_a = block_scale_a.view(torch.uint8)