@@ -237,8 +237,6 @@ def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
         else:
             return None

-    # Note: init_prepare_finalize should only be called by
-    # prepare_communication_buffer_for_model.
     def init_prepare_finalize(
         self, layer: torch.nn.Module
     ) -> FusedMoEModularKernel | None:
@@ -266,8 +264,8 @@ def init_prepare_finalize(
                 experts,
                 layer.shared_experts,
             )
-
-        return None
+        else:
+            return None

     def select_gemm_impl(
         self,
@@ -321,7 +319,9 @@ def apply(
 @CustomOp.register("modular_fused_moe")
 class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
     def __init__(
-        self, old_moe_method: FusedMoEMethodBase, fused_experts: FusedMoEModularKernel
+        self,
+        old_moe_method: FusedMoEMethodBase,
+        fused_experts: FusedMoEModularKernel,
     ):
         super().__init__(old_moe_method.moe)
         # Find better way to copy attributes
@@ -374,6 +374,8 @@ def apply(
         logical_to_physical_map: torch.Tensor | None = None,
         logical_replica_count: torch.Tensor | None = None,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        assert self.fused_experts is not None
+
         # Is getattr needed?
         zero_expert_num = getattr(layer, "zero_expert_num", 0)
         zero_expert_type = getattr(layer, "zero_expert_type", None)
@@ -1372,6 +1374,8 @@ def __init__(
             logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
         )

+    # Note: init_prepare_finalize should only be called by
+    # prepare_communication_buffer_for_model.
     def init_prepare_finalize(self) -> None:
         mk = self.quant_method.init_prepare_finalize(self)
         if mk is not None:
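
For context (not part of the patch): the relocated note and the last hunk suggest a flow where `prepare_communication_buffer_for_model` triggers the layer's `init_prepare_finalize`, which asks the quant method for a `FusedMoEModularKernel` and, when one is produced, swaps in a `FusedMoEModularMethod` wrapper. The snippet below is a minimal, self-contained sketch of that wrapping pattern using hypothetical stand-in classes (`ModularKernel`, `BaseMethod`, `ModularMethod`, `MoELayer`, `use_all2all` are simplified placeholders, not vLLM's actual API).

```python
from __future__ import annotations


class ModularKernel:
    """Stand-in for FusedMoEModularKernel."""


class BaseMethod:
    def init_prepare_finalize(self, layer: MoELayer) -> ModularKernel | None:
        # Return a modular kernel only when a prepare/finalize backend is
        # available; otherwise fall through to the `else: return None` branch,
        # mirroring the structure in the patch.
        return ModularKernel() if layer.use_all2all else None


class ModularMethod(BaseMethod):
    """Stand-in for FusedMoEModularMethod: wraps an existing method."""

    def __init__(self, old_method: BaseMethod, fused_experts: ModularKernel):
        self.old_method = old_method
        self.fused_experts = fused_experts


class MoELayer:
    def __init__(self, use_all2all: bool):
        self.use_all2all = use_all2all
        self.quant_method: BaseMethod = BaseMethod()

    # Note: only called by prepare_communication_buffer_for_model
    # (mirrors the comment relocated by the patch).
    def init_prepare_finalize(self) -> None:
        mk = self.quant_method.init_prepare_finalize(self)
        if mk is not None:
            self.quant_method = ModularMethod(self.quant_method, mk)


layer = MoELayer(use_all2all=True)
layer.init_prepare_finalize()
assert isinstance(layer.quant_method, ModularMethod)
```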