@@ -1381,7 +1381,7 @@ def __init__(
                 "Only softmax scoring function is supported for non-grouped topk."
             )
 
-        moe = FusedMoEConfig(
+        self.moe_config: FusedMoEConfig = FusedMoEConfig(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
             hidden_dim=hidden_size,
@@ -1392,24 +1392,26 @@ def __init__(
             has_bias=has_bias,
             is_act_and_mul=is_act_and_mul,
         )
-        self.moe_config = moe
+
         self.moe_quant_config: FusedMoEQuantConfig | None = None
         self.quant_config = quant_config
 
+        def _get_quant_method() -> FusedMoEMethodBase:
+            """
+            Helper method to ensure self.quant_method is never None and
+            of the proper type.
+            """
+            quant_method = None
+            if self.quant_config is not None:
+                quant_method = self.quant_config.get_quant_method(self, prefix)
+            if quant_method is None:
+                quant_method = UnquantizedFusedMoEMethod(self.moe_config)
+            assert isinstance(quant_method, FusedMoEMethodBase)
+            return quant_method
+
         # Note: get_quant_method will look at the layer's local_num_experts
         # for heuristic purposes, so it must be initialized first.
-        quant_method: QuantizeMethodBase | None = None
-        quant_method = (
-            UnquantizedFusedMoEMethod(moe)
-            if quant_config is None
-            else quant_config.get_quant_method(self, prefix)
-        )
-        if quant_method is None:
-            quant_method = UnquantizedFusedMoEMethod(moe)
-
-        assert quant_method is not None
-        assert isinstance(quant_method, FusedMoEMethodBase)
-        self.quant_method = quant_method
+        self.quant_method: FusedMoEMethodBase = _get_quant_method()
 
         if not self.moe_config.is_act_and_mul:
             # Avoid circular import
@@ -1429,20 +1431,17 @@ def __init__(
                     "is_act_and_mul=False is supported only for CUDA for now"
                 )
 
-        if self.enable_eplb:
-            from vllm.model_executor.layers.quantization.fp8 import Fp8MoEMethod
-
-            if not isinstance(quant_method, (Fp8MoEMethod, UnquantizedFusedMoEMethod)):
-                # TODO: Add support for additional quantization methods.
-                # The implementation for other quantization methods does not
-                # contain essential differences, but the current quant API
-                # design causes duplicated work when extending to new
-                # quantization methods, so I'm leaving it for now.
-                # If you plan to add support for more quantization methods,
-                # please refer to the implementation in `Fp8MoEMethod`.
-                raise NotImplementedError(
-                    "EPLB is only supported for FP8 quantization for now."
-                )
+        if self.enable_eplb and not self.quant_method.supports_eplb:
+            # TODO: Add support for additional quantization methods.
+            # The implementation for other quantization methods does not
+            # contain essential differences, but the current quant API
+            # design causes duplicated work when extending to new
+            # quantization methods, so I'm leaving it for now.
+            # If you plan to add support for more quantization methods,
+            # please refer to the implementation in `Fp8MoEMethod`.
+            raise NotImplementedError(
+                "EPLB is only supported for FP8 quantization for now."
+            )
 
         moe_quant_params = {
             "num_experts": self.local_num_experts,
@@ -1471,19 +1470,24 @@ def __init__(
             logits_shape: tuple[int, ...]
 
             # Note here we use `num_experts` which is logical expert count
+            max_num_tokens = self.moe_config.max_num_tokens
             if vllm_config.parallel_config.enable_dbo:
-                states_shape = (2, moe.max_num_tokens, self.hidden_size)
-                logits_shape = (2, moe.max_num_tokens, num_experts)
+                states_shape = (2, max_num_tokens, self.hidden_size)
+                logits_shape = (2, max_num_tokens, num_experts)
             else:
-                states_shape = (moe.max_num_tokens, self.hidden_size)
-                logits_shape = (moe.max_num_tokens, num_experts)
+                states_shape = (max_num_tokens, self.hidden_size)
+                logits_shape = (max_num_tokens, num_experts)
 
             self.batched_hidden_states = torch.zeros(
-                states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
+                states_shape,
+                dtype=self.moe_config.in_dtype,
+                device=torch.cuda.current_device(),
             )
 
             self.batched_router_logits = torch.zeros(
-                logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
+                logits_shape,
+                dtype=self.moe_config.in_dtype,
+                device=torch.cuda.current_device(),
             )
 
             # Note: init_prepare_finalize should only be called by
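
For readers outside the vLLM codebase, the following standalone sketch illustrates the resolution order the new _get_quant_method closure encodes: prefer the method supplied by the layer's quantization config, and fall back to the unquantized implementation when there is no config or when the config returns None for this layer. The classes below are toy stand-ins with the same names, not vLLM's real definitions.

# Toy stand-ins assuming only the interfaces used by the constructor.
class FusedMoEMethodBase:
    pass

class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
    def __init__(self, moe_config):
        self.moe_config = moe_config

class ToyQuantConfig:
    def __init__(self, method: FusedMoEMethodBase | None):
        self._method = method

    def get_quant_method(self, layer, prefix: str) -> FusedMoEMethodBase | None:
        # A real quant config may return None for layers it does not quantize.
        return self._method

def resolve_quant_method(layer, quant_config, moe_config, prefix: str) -> FusedMoEMethodBase:
    # Same resolution order as the new helper: config first, then fallback.
    quant_method = None
    if quant_config is not None:
        quant_method = quant_config.get_quant_method(layer, prefix)
    if quant_method is None:
        quant_method = UnquantizedFusedMoEMethod(moe_config)
    assert isinstance(quant_method, FusedMoEMethodBase)
    return quant_method

# No quant config, or a config that skips this layer, both fall back:
assert isinstance(resolve_quant_method(None, None, {}, "model.layers.0.mlp"),
                  UnquantizedFusedMoEMethod)
assert isinstance(resolve_quant_method(None, ToyQuantConfig(None), {}, "model.layers.0.mlp"),
                  UnquantizedFusedMoEMethod)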
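
The EPLB gate now keys off a supports_eplb attribute on the quant method instead of an isinstance check against Fp8MoEMethod / UnquantizedFusedMoEMethod, so the capability lives next to each quant method implementation rather than in a growing tuple in the layer constructor. Below is a minimal sketch of the pattern this assumes; whether supports_eplb is a plain class attribute or a property in the real base class is an assumption here, and the class bodies plus the AwqLikeMoEMethod name are illustrative only.

class FusedMoEMethodBase:
    # Assumed default: quant methods opt in to EPLB support explicitly.
    supports_eplb: bool = False

class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
    supports_eplb = True

class Fp8MoEMethod(FusedMoEMethodBase):
    supports_eplb = True

class AwqLikeMoEMethod(FusedMoEMethodBase):
    pass  # inherits supports_eplb = False

def check_eplb(enable_eplb: bool, quant_method: FusedMoEMethodBase) -> None:
    # Mirrors the new gate in __init__: a single attribute lookup replaces
    # the per-class isinstance check, so a new quant method only needs to
    # set supports_eplb = True once it actually implements EPLB.
    if enable_eplb and not quant_method.supports_eplb:
        raise NotImplementedError(
            "EPLB is only supported for FP8 quantization for now."
        )

check_eplb(True, Fp8MoEMethod())        # ok
check_eplb(False, AwqLikeMoEMethod())   # ok: EPLB disabled
# check_eplb(True, AwqLikeMoEMethod())  # would raise NotImplementedError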