@@ -1363,7 +1363,7 @@ def __init__(
13631363 "Only softmax scoring function is supported for non-grouped topk."
13641364 )
13651365
1366- moe = FusedMoEConfig (
1366+ self . moe_config : FusedMoEConfig = FusedMoEConfig (
13671367 num_experts = self .global_num_experts ,
13681368 experts_per_token = top_k ,
13691369 hidden_dim = hidden_size ,
@@ -1373,39 +1373,38 @@ def __init__(
             max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
             has_bias=has_bias,
         )
-        self.moe_config = moe
+
         self.moe_quant_config: FusedMoEQuantConfig | None = None
         self.quant_config = quant_config
 
+        def _get_quant_method() -> FusedMoEMethodBase:
+            """
+            Helper method to ensure self.quant_method is never None and
+            of the proper type.
+            """
+            quant_method = None
+            if self.quant_config is not None:
+                quant_method = self.quant_config.get_quant_method(self, prefix)
+            if quant_method is None:
+                quant_method = UnquantizedFusedMoEMethod(self.moe_config)
+            assert isinstance(quant_method, FusedMoEMethodBase)
+            return quant_method
+
         # Note: get_quant_method will look at the layer's local_num_experts
         # for heuristic purposes, so it must be initialized first.
-        quant_method: QuantizeMethodBase | None = None
-        quant_method = (
-            UnquantizedFusedMoEMethod(moe)
-            if quant_config is None
-            else quant_config.get_quant_method(self, prefix)
-        )
-        if quant_method is None:
-            quant_method = UnquantizedFusedMoEMethod(moe)
-
-        assert quant_method is not None
-        assert isinstance(quant_method, FusedMoEMethodBase)
-        self.quant_method = quant_method
-
-        if self.enable_eplb:
-            from vllm.model_executor.layers.quantization.fp8 import Fp8MoEMethod
-
-            if not isinstance(quant_method, (Fp8MoEMethod, UnquantizedFusedMoEMethod)):
-                # TODO: Add support for additional quantization methods.
-                # The implementation for other quantization methods does not
-                # contain essential differences, but the current quant API
-                # design causes duplicated work when extending to new
-                # quantization methods, so I'm leaving it for now.
-                # If you plan to add support for more quantization methods,
-                # please refer to the implementation in `Fp8MoEMethod`.
-                raise NotImplementedError(
-                    "EPLB is only supported for FP8 quantization for now."
-                )
+        self.quant_method: FusedMoEMethodBase = _get_quant_method()
+
+        if self.enable_eplb and not self.quant_method.supports_eplb:
+            # TODO: Add support for additional quantization methods.
+            # The implementation for other quantization methods does not
+            # contain essential differences, but the current quant API
+            # design causes duplicated work when extending to new
+            # quantization methods, so I'm leaving it for now.
+            # If you plan to add support for more quantization methods,
+            # please refer to the implementation in `Fp8MoEMethod`.
+            raise NotImplementedError(
+                "EPLB is only supported for FP8 quantization for now."
+            )
 
         moe_quant_params = {
             "num_experts": self.local_num_experts,
@@ -1433,19 +1432,24 @@ def __init__(
             logits_shape: tuple[int, ...]
 
             # Note here we use `num_experts` which is logical expert count
+            max_num_tokens = self.moe_config.max_num_tokens
             if vllm_config.parallel_config.enable_dbo:
-                states_shape = (2, moe.max_num_tokens, self.hidden_size)
-                logits_shape = (2, moe.max_num_tokens, num_experts)
+                states_shape = (2, max_num_tokens, self.hidden_size)
+                logits_shape = (2, max_num_tokens, num_experts)
             else:
-                states_shape = (moe.max_num_tokens, self.hidden_size)
-                logits_shape = (moe.max_num_tokens, num_experts)
+                states_shape = (max_num_tokens, self.hidden_size)
+                logits_shape = (max_num_tokens, num_experts)
 
             self.batched_hidden_states = torch.zeros(
-                states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
+                states_shape,
+                dtype=self.moe_config.in_dtype,
+                device=torch.cuda.current_device(),
             )
 
             self.batched_router_logits = torch.zeros(
-                logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
+                logits_shape,
+                dtype=self.moe_config.in_dtype,
+                device=torch.cuda.current_device(),
             )
 
         # Note: init_prepare_finalize should only be called by
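For reference, a self-contained sketch of the buffer-shape logic in the last hunk: dual-batch overlap (DBO) adds a leading dimension of 2, and `max_num_tokens` comes from `VLLM_MOE_DP_CHUNK_SIZE` via the `FusedMoEConfig`. The helper name, default dtype, and CPU device below are illustrative assumptions; the real code allocates on `torch.cuda.current_device()` inside `FusedMoE.__init__`:

```python
# Sketch only: reproduces the shape logic of the batched staging buffers.
import torch


def allocate_moe_buffers(
    enable_dbo: bool,
    max_num_tokens: int,
    hidden_size: int,
    num_experts: int,
    in_dtype: torch.dtype = torch.bfloat16,
    device: str = "cpu",
) -> tuple[torch.Tensor, torch.Tensor]:
    # DBO keeps two micro-batches in flight, so each staging buffer gets an
    # extra leading dimension of 2.
    lead = (2,) if enable_dbo else ()
    hidden_states = torch.zeros(
        *lead, max_num_tokens, hidden_size, dtype=in_dtype, device=device
    )
    router_logits = torch.zeros(
        *lead, max_num_tokens, num_experts, dtype=in_dtype, device=device
    )
    return hidden_states, router_logits


# Example: with DBO enabled, a 1024-token chunk, hidden size 4096, and 64
# logical experts, the buffers are (2, 1024, 4096) and (2, 1024, 64).
states, logits = allocate_moe_buffers(True, 1024, 4096, 64)
assert states.shape == (2, 1024, 4096) and logits.shape == (2, 1024, 64)
```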