@@ -198,6 +198,26 @@ def _setup(self):
         self.kv_bmm_quantizer = TensorQuantizer()
         self.pe_bmm_quantizer = TensorQuantizer()
 
+class CalibMoe(deekseep_model.MoE):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._setup()
+
+    def _setup(self):
+        # Save the gate's original routing config so it can be restored after the calibration pass.
+        self._original_topk = self.gate.topk
+        self._original_topk_groups = self.gate.topk_groups
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Forward all tokens to all experts so every expert's quantizers see
+        # calibration data; the output of this pass is discarded.
+        self.gate.topk = self.n_routed_experts
+        self.gate.topk_groups = self.gate.n_groups
+        super().forward(x)
+        # Restore the original topk and topk_groups, then return the normally routed output.
+        self.gate.topk = self._original_topk
+        self.gate.topk_groups = self._original_topk_groups
+
+        return super().forward(x)
+
 mtq.register(
     original_cls=deekseep_model.RowParallelLinear,
     quantized_cls=QuantRowParallelLinear,
@@ -208,6 +228,7 @@ def _setup(self):
 )
 mtq.register(original_cls=deekseep_model.Linear, quantized_cls=QuantLinear)
 mtq.register(original_cls=deekseep_model.MLA, quantized_cls=QuantMLA)
+mtq.register(original_cls=deekseep_model.MoE, quantized_cls=CalibMoe)
 
 
 def load_deepseek_model(model_config: str, model_path: str, batch_size: int):
@@ -319,6 +340,13 @@ def state_dict_filter(state_dict):
         os.path.join(output_path, f"amax_dict_rank{rank}-mp{world_size}.pt"),
     )
 
+    # if rank == 0:
+    #     with open("expert_activation_counts.txt", "w") as f:
+    #         for name, module in model.named_modules():
+    #             if isinstance(module, deekseep_model.MoE):
+    #                 counts = module.activated_expert_counts()
+    #                 f.writelines(f"{name}: {count}\n" for count in counts)
+
     quant_config = get_quant_config(model.named_modules())
 
     if enable_fp8_kvcache:
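
Note (not part of the diff): a minimal sketch of how the registered CalibMoe gets exercised, assuming the surrounding script drives calibration through ModelOpt's mtq.quantize with a forward loop. The names calib_dataloader and the FP8 config choice below are illustrative assumptions, not the script's actual settings.

    import modelopt.torch.quantization as mtq

    def forward_loop(model):
        # Each calibration batch passes through CalibMoe, which temporarily
        # routes every token to every expert so all expert quantizers
        # collect amax statistics before normal routing is restored.
        for batch in calib_dataloader:  # hypothetical calibration dataloader
            model(batch)

    model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)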