@@ -505,7 +505,7 @@ def forward(
505505 # native FusedMoE. here we need to design a better FusedMoE
506506 # (maybe using AscendFusedMoE) to enable these different
507507 # communication schema.
508- final_hidden_states = self .experts .quant_method (
508+ final_hidden_states = self .experts .quant_method . apply (
509509 layer = self .experts ,
510510 x = hidden_states ,
511511 router_logits = router_logits ,
@@ -937,6 +937,8 @@ def sample(
937937 return next_tokens
938938
939939 def load_weights (self , weights : Iterable [Tuple [str , torch .Tensor ]]):
940+ tp_size = get_tp_group ().world_size
941+ tp_rank = get_tp_group ().rank_in_group
940942 stacked_params_mapping = [
941943 # (param_name, shard_name, shard_id)
942944 ("qkv_proj" , "q_proj" , "q" ),
@@ -972,6 +974,51 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
972974 if "module" in name :
973975 continue
974976
977+ if name .endswith ('kv_cache_offset' ):
978+ continue
979+
980+ if name .endswith ("k_proj.kv_cache_scale" ):
981+ remapped_kv_scale_name = name .replace (
982+ "k_proj.kv_cache_scale" , "attn.key_antiquant_scale" )
983+ if remapped_kv_scale_name not in params_dict :
984+ logger .warning_once (
985+ "Found kv scale in the checkpoint "
986+ f"(e.g. { name } ), but not found the expected "
987+ f"name in the model "
988+ f"(e.g. { remapped_kv_scale_name } ). "
989+ "kv-scale is not loaded." )
990+ continue
991+ else :
992+ name = remapped_kv_scale_name
993+ param = params_dict [name ]
994+ loaded_weight = torch .tensor_split (loaded_weight ,
995+ tp_size ,
996+ dim = 0 )[tp_rank ]
997+ weight_loader = getattr (param , "weight_loader" ,
998+ default_weight_loader )
999+ weight_loader (param , loaded_weight )
1000+
1001+ if name .endswith ("v_proj.kv_cache_scale" ):
1002+ remapped_kv_scale_name = name .replace (
1003+ "v_proj.kv_cache_scale" , "attn.value_antiquant_scale" )
1004+ if remapped_kv_scale_name not in params_dict :
1005+ logger .warning_once (
1006+ "Found kv scale in the checkpoint "
1007+ f"(e.g. { name } ), but not found the expected "
1008+ f"name in the model "
1009+ f"(e.g. { remapped_kv_scale_name } ). "
1010+ "kv-scale is not loaded." )
1011+ continue
1012+ else :
1013+ name = remapped_kv_scale_name
1014+ param = params_dict [name ]
1015+ loaded_weight = torch .tensor_split (loaded_weight ,
1016+ tp_size ,
1017+ dim = 0 )[tp_rank ]
1018+ weight_loader = getattr (param , "weight_loader" ,
1019+ default_weight_loader )
1020+ weight_loader (param , loaded_weight )
1021+
9751022 for (param_name , weight_name , shard_id ) in stacked_params_mapping :
9761023 # Skip non-stacked layers and experts (experts handled below).
9771024 if weight_name not in name :