Skip to content

Commit 549a9fe

Browse files
committed
fix format
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent f22b693 commit 549a9fe

File tree

1 file changed

+4
-14
lines changed

1 file changed

+4
-14
lines changed

tests/kernels/test_block_fp8.py

Lines changed: 4 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,6 @@
1414
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
1515
per_token_group_quant_fp8, w8a8_block_fp8_matmul)
1616
from vllm.platforms import current_platform
17-
from vllm.utils import round_up
1817

1918
dg_available = False
2019
try:
@@ -362,17 +361,10 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
362361
M, K = a.shape
363362

364363
sorted_token_ids, m_indices, num_pad = moe_align_block_size(
365-
topk_ids, block_m, num_groups, None)
364+
topk_ids, block_m, num_groups, None, pad_sorted_ids=True)
366365

367366
num_tokens = topk * M
368367

369-
pad_size = (round_up(sorted_token_ids.numel(), block_m) -
370-
sorted_token_ids.numel())
371-
if pad_size > 0:
372-
sorted_token_ids = torch.nn.functional.pad(sorted_token_ids,
373-
(0, pad_size), "constant",
374-
num_tokens)
375-
376368
sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
377369
m_indices = torch.repeat_interleave(m_indices, block_m, dim=0)
378370
inv_perm = torch.argsort(sorted_token_ids)[:M * topk]
@@ -419,9 +411,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
419411
act_out = SiluAndMul().forward_native(inter_out)
420412
act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k)
421413

422-
out = torch.zeros(a_q.shape[0], K,
423-
dtype=torch.bfloat16,
424-
device=a.device)
414+
out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device)
425415

426416
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
427417
(act_out_q, act_out_s), (w2, w2_s), out, m_indices)
@@ -490,8 +480,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size,
490480
ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s,
491481
score, topk, block_size)
492482
else:
493-
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
494-
block_size)
483+
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
484+
topk, block_size)
495485

496486
out = fused_moe(a,
497487
w1,

0 commit comments

Comments
 (0)