@@ -359,6 +359,15 @@ def forward(
         hidden_states = self.norm(hidden_states)
         return hidden_states

+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        return FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts)
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -370,16 +379,9 @@ def load_weights(self, weights: Iterable[tuple[str,
             ("gate_up_proj", "up_proj", 1),
         ]

-        # Params for weights, fp8 weight scales, fp8 activation scales
-        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
-            ckpt_gate_proj_name="gate_proj",
-            ckpt_down_proj_name="down_proj",
-            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts)
-
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        expert_params_mapping = self.get_expert_mapping()
         for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
@@ -511,3 +513,6 @@ def sample(
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)
+
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()