@@ -60,6 +60,10 @@
     is_deep_gemm_supported)
 from vllm.utils.flashinfer import has_flashinfer_moe
 
+import os
+_VLLM_MORI_MAX_TOKENS = int(os.getenv("VLLM_MORI_MAX_TOKENS", "4096"))
+_USE_MORI_V1 = int(os.getenv("_USE_MORI_V1", "1")) == 1
+
 if TYPE_CHECKING:
     from vllm.model_executor.models.utils import WeightsMapper
 
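Two notes on the flags added above. `VLLM_MORI_MAX_TOKENS` replaces a previously hardcoded per-rank cap of 4096, and `_USE_MORI_V1` is opt-out: the V1 inter-node kernel is used unless the variable is set to an integer other than `1`. The `int()`-based parse will raise `ValueError` at import time for non-numeric values such as `_USE_MORI_V1=true`; a more forgiving variant could look like the sketch below (`_env_flag` is an illustrative helper, not part of the patch):

```python
import os

def _env_flag(name: str, default: str = "1") -> bool:
    """Parse a 0/1 environment flag without raising on junk values."""
    raw = os.getenv(name, default).strip()
    # Unlike int(raw) == 1, a typo such as "true" simply disables the
    # flag instead of crashing the import.
    return raw == "1"

USE_MORI_V1 = _env_flag("_USE_MORI_V1")  # defaults to True
MORI_MAX_TOKENS = int(os.getenv("VLLM_MORI_MAX_TOKENS", "4096"))
```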
@@ -443,12 +447,14 @@ def mori_op_init(quant_dtype, dtype, rank, world_size, hdim, E, topk, max_num_to
             scale_dim=hdim // 128,
             scale_type_size=torch.float32.itemsize,
             max_token_type_size=dtype.itemsize,
-            max_num_inp_token_per_rank=4096,
+            max_num_inp_token_per_rank=_VLLM_MORI_MAX_TOKENS,
             num_experts_per_rank=E // world_size,
             num_experts_per_token=topk,
         )
     else:
         # multi node
+        if _USE_MORI_V1:
+            print("Using mori v1")
         mori_config = mori.ops.EpDispatchCombineConfig(
             data_type=quant_dtype,
             rank=rank,
@@ -457,12 +463,14 @@ def mori_op_init(quant_dtype, dtype, rank, world_size, hdim, E, topk, max_num_to
             scale_dim=hdim // 128,
             scale_type_size=torch.float32.itemsize,
             max_token_type_size=dtype.itemsize,
-            max_num_inp_token_per_rank=4096,
+            max_num_inp_token_per_rank=_VLLM_MORI_MAX_TOKENS,
             num_experts_per_rank=E // world_size,
             num_experts_per_token=topk,
             warp_num_per_block=16,
             block_num=64,
-            kernel_type=mori.ops.EpDispatchCombineKernelType.InterNode,
+            kernel_type=mori.ops.EpDispatchCombineKernelType.InterNodeV1 if _USE_MORI_V1 else mori.ops.EpDispatchCombineKernelType.InterNode,
+            gpu_per_node=8,
+            rdma_block_num=16 if _USE_MORI_V1 else 0,
         )
     mori_op = mori.ops.EpDispatchCombineOp(mori_config)
     return mori_op
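To make the multi-node branch easier to follow without a ROCm/mori install, here is a self-contained mirror of the new kernel selection. `KernelType` and `select_internode_config` are stand-ins for illustration, not mori's API (the real enum is `mori.ops.EpDispatchCombineKernelType`). The point to notice is that V1 is paired with a nonzero RDMA block budget (`rdma_block_num=16`), while the legacy `InterNode` path keeps it at 0:

```python
from enum import Enum, auto

class KernelType(Enum):
    # Stand-in for mori.ops.EpDispatchCombineKernelType.
    InterNode = auto()
    InterNodeV1 = auto()

def select_internode_config(use_v1: bool) -> dict:
    """Mirror the multi-node branch: pick the kernel variant and the
    RDMA block budget that goes with it."""
    return {
        "kernel_type": KernelType.InterNodeV1 if use_v1 else KernelType.InterNode,
        "gpu_per_node": 8,                      # assumes 8-GPU nodes
        "rdma_block_num": 16 if use_v1 else 0,  # V1 reserves RDMA blocks
    }

assert select_internode_config(True)["rdma_block_num"] == 16
assert select_internode_config(False)["kernel_type"] is KernelType.InterNode
```

One design note: `gpu_per_node=8` is hardcoded in the patch, which matches common 8-GPU nodes but would be worth deriving from the parallel topology if heterogeneous node sizes are expected.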