Skip to content

Commit d00aa44

Browse files
committed
feat: Add support for MORI max tokens and V1 version (vllm-project#28)
1 parent e14933d commit d00aa44

File tree

1 file changed

+11
-3
lines changed
  • vllm/model_executor/layers/quantization

1 file changed

+11
-3
lines changed

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,10 @@
6060
is_deep_gemm_supported)
6161
from vllm.utils.flashinfer import has_flashinfer_moe
6262

63+
import os
64+
_VLLM_MORI_MAX_TOKENS = int(os.getenv("VLLM_MORI_MAX_TOKENS", "4096"))
65+
_USE_MORI_V1 = (int(os.getenv("_USE_MORI_V1", "1")) == 1)
66+
6367
if TYPE_CHECKING:
6468
from vllm.model_executor.models.utils import WeightsMapper
6569

@@ -443,12 +447,14 @@ def mori_op_init(quant_dtype, dtype, rank, world_size, hdim, E, topk, max_num_to
443447
scale_dim=hdim // 128,
444448
scale_type_size=torch.float32.itemsize,
445449
max_token_type_size=dtype.itemsize,
446-
max_num_inp_token_per_rank=4096,
450+
max_num_inp_token_per_rank=_VLLM_MORI_MAX_TOKENS,
447451
num_experts_per_rank=E // world_size,
448452
num_experts_per_token=topk,
449453
)
450454
else:
451455
# multi node
456+
if _USE_MORI_V1:
457+
print('Using mori v1')
452458
mori_config = mori.ops.EpDispatchCombineConfig(
453459
data_type=quant_dtype,
454460
rank=rank,
@@ -457,12 +463,14 @@ def mori_op_init(quant_dtype, dtype, rank, world_size, hdim, E, topk, max_num_to
457463
scale_dim=hdim // 128,
458464
scale_type_size=torch.float32.itemsize,
459465
max_token_type_size=dtype.itemsize,
460-
max_num_inp_token_per_rank=4096,
466+
max_num_inp_token_per_rank=_VLLM_MORI_MAX_TOKENS,
461467
num_experts_per_rank=E // world_size,
462468
num_experts_per_token=topk,
463469
warp_num_per_block=16,
464470
block_num=64,
465-
kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode,
471+
kernel_type=mori.ops.EpDispatchCombineKernelType.InterNodeV1 if _USE_MORI_V1 else mori.ops.EpDispatchCombineKernelType.InterNode,
472+
gpu_per_node=8,
473+
rdma_block_num=16 if _USE_MORI_V1 else 0,
466474
)
467475
mori_op = mori.ops.EpDispatchCombineOp(mori_config)
468476
return mori_op

0 commit comments

Comments
 (0)