@@ -111,7 +111,6 @@ def __init__(self, moe: FusedMoEConfig):
         super().__init__()
         self.moe = moe
         self.moe_quant_config: FusedMoEQuantConfig | None = None
-        self.fused_experts: FusedMoEModularKernel | None = None
         self.topk_indices_dtype = None
 
     @abstractmethod
@@ -254,9 +253,6 @@ def init_prepare_finalize(
             "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
         )
         assert self.topk_indices_dtype is None
-        assert self.fused_experts is None, (
-            f"Attempt to override experts for {id(self)}!"
-        )
         self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
         experts = self.select_gemm_impl(prepare_finalize, layer)
         return FusedMoEModularKernel(
@@ -287,7 +283,11 @@ def get_fused_moe_quant_config(
 
     @property
     def using_modular_kernel(self) -> bool:
-        return self.fused_experts is not None
+        return False
+
+    @property
+    def supports_eplb(self) -> bool:
+        return False
 
     @abstractmethod
     def apply(
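For context: with `fused_experts` gone from the base class, `using_modular_kernel` can no longer be inferred from a stored kernel, so it and the new `supports_eplb` become plain `False` defaults that the modular-kernel wrapper in the next hunk overrides. A minimal standalone sketch of that pattern, using illustrative stand-in class names rather than the real vLLM types:

```python
# Sketch only: stand-in names, not the vLLM classes from this diff.
class MethodBase:
    """Base method: no stored fused_experts kernel, so both flags default to False."""

    @property
    def using_modular_kernel(self) -> bool:
        return False

    @property
    def supports_eplb(self) -> bool:
        return False


class ModularWrapper(MethodBase):
    """Wrapper that swaps in a modular kernel and inherits EPLB support from the wrapped method."""

    def __init__(self, old_method: MethodBase) -> None:
        self._supports_eplb = old_method.supports_eplb

    @property
    def using_modular_kernel(self) -> bool:
        return True

    @property
    def supports_eplb(self) -> bool:
        return self._supports_eplb


wrapped = ModularWrapper(MethodBase())
assert wrapped.using_modular_kernel and not wrapped.supports_eplb
```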
@@ -330,10 +330,21 @@ def __init__(
         self.moe_quant_config = old_moe_method.moe_quant_config
         self.fused_experts = fused_experts
         self.topk_indices_dtype = old_moe_method.topk_indices_dtype
-
+        self.disable_expert_map = not fused_experts.supports_expert_map()
+        self.old_method_name = old_moe_method.__class__.__name__
+        self._supports_eplb = old_moe_method.supports_eplb
         if isinstance(old_moe_method, torch.nn.Module):
             self.load_state_dict(old_moe_method.state_dict())
-        logger.debug("Swapping out %s", old_moe_method.__class__.__name__)
+        logger.debug("Swapping out %s", self.old_method_name)
+
+    @property
+    def using_modular_kernel(self) -> bool:
+        return True
+
+    @property
+    def supports_eplb(self) -> bool:
+        return self._supports_eplb
 
     def create_weights(
         self,
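The `disable_expert_map` flag recorded above is what later lets `apply()` drop the expert map when the selected kernel cannot honor it. A small sketch of that gate in isolation; the helper name and the plain-list expert map are assumptions for illustration, not vLLM APIs:

```python
# Sketch only: mirrors `expert_map=None if self.disable_expert_map else expert_map`.
def gate_expert_map(
    expert_map: list[int] | None, kernel_supports_expert_map: bool
) -> list[int] | None:
    disable_expert_map = not kernel_supports_expert_map
    return None if disable_expert_map else expert_map


assert gate_expert_map([0, -1, 1], kernel_supports_expert_map=False) is None
assert gate_expert_map([0, -1, 1], kernel_supports_expert_map=True) == [0, -1, 1]
```

Passing `None` simply means no expert-map remapping is applied for kernels that do not support it.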
@@ -374,12 +385,21 @@ def apply(
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
-        assert self.fused_experts is not None
-
         # Is getattr needed?
         zero_expert_num = getattr(layer, "zero_expert_num", 0)
         zero_expert_type = getattr(layer, "zero_expert_type", None)
 
+        if enable_eplb:
+            if self.supports_eplb:
+                assert expert_load_view is not None
+                assert logical_to_physical_map is not None
+                assert logical_replica_count is not None
+                assert isinstance(layer, FusedMoE)
+            else:
+                raise NotImplementedError(
+                    f"EPLB is not supported for {self.old_method_name}"
+                )
+
         select_result = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -415,7 +435,7 @@ def apply(
             activation=activation,
             global_num_experts=global_num_experts,
             apply_router_weight_on_input=apply_router_weight_on_input,
-            expert_map=expert_map,
+            expert_map=None if self.disable_expert_map else expert_map,
         )
 
         if zero_expert_num != 0 and zero_expert_type is not None:
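The EPLB guard added in this hunk fails fast when expert-parallel load balancing is requested but the wrapped method cannot provide it, and otherwise insists the load-balancing tensors are actually present. A sketch of that control flow pulled out as a hypothetical helper (in the diff the checks live inline in `apply()`, which additionally asserts the layer is a `FusedMoE`):

```python
from typing import Any


# Sketch only: hypothetical helper, not part of the diff.
def check_eplb_args(
    enable_eplb: bool,
    supports_eplb: bool,
    old_method_name: str,
    expert_load_view: Any,
    logical_to_physical_map: Any,
    logical_replica_count: Any,
) -> None:
    if not enable_eplb:
        return
    if not supports_eplb:
        # Requesting EPLB through a wrapped method that cannot honor it is an error.
        raise NotImplementedError(f"EPLB is not supported for {old_method_name}")
    # EPLB is supported: the rebalancing inputs must actually be provided.
    assert expert_load_view is not None
    assert logical_to_physical_map is not None
    assert logical_replica_count is not None
```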
@@ -750,7 +770,6 @@ def forward_cuda(
         )
 
         if self.rocm_aiter_moe_enabled:
-            assert self.fused_experts is None
             result = self.rocm_aiter_fused_experts(
                 hidden_states=x,
                 w1=layer.w13_weight,
@@ -771,23 +790,7 @@ def forward_cuda(
                 activation=activation,
                 apply_router_weight_on_input=apply_router_weight_on_input,
             )
-        elif self.fused_experts is not None:
-            if self.moe.has_bias:
-                raise ValueError("FusedMoEModularKernel does not support bias.")
-            result = self.fused_experts(
-                hidden_states=x,
-                w1=layer.w13_weight,
-                w2=layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=True,
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                global_num_experts=global_num_experts,
-                expert_map=expert_map,
-            )
         else:
-            assert fused_experts is not None
             result = fused_experts(
                 hidden_states=x,
                 w1=layer.w13_weight,