@@ -14,7 +14,6 @@
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     per_token_group_quant_fp8, w8a8_block_fp8_matmul)
 from vllm.platforms import current_platform
-from vllm.utils import round_up

 dg_available = False
 try:
@@ -362,17 +361,10 @@ def test_moe_permute(a, a_s, topk_ids, num_groups, topk, block_m):
     M, K = a.shape

     sorted_token_ids, m_indices, num_pad = moe_align_block_size(
-        topk_ids, block_m, num_groups, None)
+        topk_ids, block_m, num_groups, None, pad_sorted_ids=True)

     num_tokens = topk * M

-    pad_size = (round_up(sorted_token_ids.numel(), block_m) -
-                sorted_token_ids.numel())
-    if pad_size > 0:
-        sorted_token_ids = torch.nn.functional.pad(sorted_token_ids,
-                                                   (0, pad_size), "constant",
-                                                   num_tokens)
-
     sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1)
     m_indices = torch.repeat_interleave(m_indices, block_m, dim=0)
     inv_perm = torch.argsort(sorted_token_ids)[:M * topk]
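
Note on the hunk above: pad_sorted_ids=True asks moe_align_block_size to return sorted_token_ids already padded to a multiple of block_m, replacing the manual padding deleted here. For reference, a minimal sketch of that deleted logic, assuming round_up(x, m) rounds x up to the nearest multiple of m as in vllm.utils (the helper name below is hypothetical, not part of the diff):

import torch

def pad_sorted_token_ids(sorted_token_ids: torch.Tensor, block_m: int,
                         num_tokens: int) -> torch.Tensor:
    # Hypothetical standalone helper mirroring the lines removed above:
    # pad the flattened id tensor up to a multiple of block_m, using
    # num_tokens (one past the largest valid token id) as the fill value.
    numel = sorted_token_ids.numel()
    pad_size = (numel + block_m - 1) // block_m * block_m - numel
    if pad_size > 0:
        sorted_token_ids = torch.nn.functional.pad(sorted_token_ids,
                                                   (0, pad_size), "constant",
                                                   num_tokens)
    return sorted_token_ids

The padded entries are then clamped to num_tokens - 1 on the next line of the test, so they index a valid (if duplicated) token rather than reading out of bounds.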
@@ -419,9 +411,7 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
     act_out = SiluAndMul().forward_native(inter_out)
     act_out_q, act_out_s = per_token_group_quant_fp8(act_out, block_k)

-    out = torch.zeros(a_q.shape[0], K,
-                      dtype=torch.bfloat16,
-                      device=a.device)
+    out = torch.zeros(a_q.shape[0], K, dtype=torch.bfloat16, device=a.device)

     deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
         (act_out_q, act_out_s), (w2, w2_s), out, m_indices)
@@ -490,8 +480,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, block_size,
         ref_out = deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s,
                                                score, topk, block_size)
     else:
-        ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk,
-                                           block_size)
+        ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
+                                           topk, block_size)

     out = fused_moe(a,
                     w1,