@@ -281,10 +281,6 @@ def get_fused_moe_quant_config(
     ) -> FusedMoEQuantConfig | None:
         raise NotImplementedError

-    @property
-    def using_modular_kernel(self) -> bool:
-        return False
-
     @property
     def supports_eplb(self) -> bool:
         return False
@@ -337,10 +333,6 @@ def __init__(
         self.load_state_dict(old_moe_method.state_dict())
         logger.debug("Swapping out %s", self.old_method_name)

-    @property
-    def using_modular_kernel(self) -> bool:
-        return True
-
     @property
     def supports_eplb(self) -> bool:
         return self._supports_eplb
@@ -1378,13 +1370,12 @@ def __init__(

    # Note: init_prepare_finalize should only be called by
    # prepare_communication_buffer_for_model.
+   # This is called after all weight loading and post-processing, so it
+   # should be safe to swap out the quant_method.
    def init_prepare_finalize(self) -> None:
        mk = self.quant_method.init_prepare_finalize(self)
        if mk is not None:
-           new_quant_method = FusedMoEModularMethod(self.quant_method, mk)
-           if isinstance(self.quant_method, torch.nn.Module):
-               self.set_submodule(self.quant_method.name, new_quant_method)
-           self.quant_method = new_quant_method
+           self.quant_method = FusedMoEModularMethod(self.quant_method, mk)

    @property
    def shared_experts(self) -> torch.nn.Module | None:
@@ -2114,7 +2105,7 @@ def must_reduce_shared_expert_outputs(self) -> bool:
        """
        assert self.quant_method is not None
        return (
-           self.quant_method.fused_experts is not None
+           isinstance(self.quant_method, FusedMoEModularMethod)
            and self.quant_method.fused_experts.output_is_reduced()
        )

@@ -2228,7 +2219,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
        # If there are shared experts but we are not using a modular kernel,
        # the shared experts must be called here
        if (
-           not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
+           not isinstance(self.quant_method, FusedMoEModularMethod)
            and self.shared_experts is not None
        ):
            shared_output = self.shared_experts(staged_hidden_states)
@@ -2333,14 +2324,14 @@ def forward_impl(
        if self.use_dp_chunking:
            return self.forward_impl_chunked(hidden_states, router_logits)

-       do_naive_dispatch_combine: bool = (
-           self.dp_size > 1 and not self.quant_method.using_modular_kernel
+       do_naive_dispatch_combine: bool = self.dp_size > 1 and not isinstance(
+           self.quant_method, FusedMoEModularMethod
        )

        # If there are shared experts but we are not using a modular kernel, the
        # shared experts must be called here
        if (
-           not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel)
+           not isinstance(self.quant_method, FusedMoEModularMethod)
            and self.shared_experts is not None
        ):
            shared_output = self.shared_experts(hidden_states)