@@ -888,7 +888,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
888888        return  loaded_params 
889889
890890
891- class  HunYuanV1Base (nn .Module , SupportsLoRA , SupportsPP ,  MixtureOfExperts ):
891+ class  HunyuanV1ModelBase (nn .Module , SupportsLoRA , SupportsPP ):
892892    packed_modules_mapping  =  {
893893        "qkv_proj" : [
894894            "q_proj" ,
@@ -930,6 +930,56 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
930930        else :
931931            self .lm_head  =  PPMissingLayer ()
932932
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run the wrapped decoder stack.

    Delegates straight to ``self.model`` and returns its output — the
    final hidden states, or (under pipeline parallelism, presumably)
    the ``IntermediateTensors`` to hand to the next rank.
    """
    return self.model(input_ids, positions, intermediate_tensors,
                      inputs_embeds)
943+ 
def compute_logits(
    self,
    hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
    """Project hidden states to vocabulary logits.

    Applies ``self.logits_processor`` with the language-model head
    (``self.lm_head``) to the given hidden states.
    """
    return self.logits_processor(self.lm_head, hidden_states)
950+ 
def make_empty_intermediate_tensors(
        self, batch_size: int, dtype: torch.dtype,
        device: torch.device) -> IntermediateTensors:
    """Allocate zero placeholders for pipeline-parallel communication.

    Both "hidden_states" and "residual" entries are zero tensors of
    shape ``(batch_size, hidden_size)`` on the requested dtype/device.
    """
    shape = (batch_size, self.config.hidden_size)
    return IntermediateTensors({
        name: torch.zeros(shape, dtype=dtype, device=device)
        for name in ("hidden_states", "residual")
    })
964+ 
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint weights via ``AutoWeightsLoader``.

    When the config ties word embeddings, "lm_head." weights are
    skipped since the head shares the embedding matrix.
    Returns the set of parameter names that were loaded.
    """
    skip = ["lm_head."] if self.config.tie_word_embeddings else None
    loader = AutoWeightsLoader(self, skip_prefixes=skip)
    return loader.load_weights(weights)
973+ 
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Look up token embeddings by delegating to the inner model."""
    return self.model.get_input_embeddings(input_ids)
976+ 
977+ 
978+ class  HunYuanMoEV1Base (HunyuanV1ModelBase , MixtureOfExperts ):
979+ 
980+     def  __init__ (self , * , vllm_config : VllmConfig , prefix : str  =  "" ):
981+         super ().__init__ (vllm_config = vllm_config , prefix = prefix )
982+ 
933983        # Set MoE hyperparameters 
934984        self .expert_weights  =  []
935985        self .num_expert_groups  =  1 
@@ -988,57 +1038,19 @@ def update_physical_experts_metadata(
9881038                moe .n_redundant_experts  =  self .num_redundant_experts 
9891039                moe .experts .update_expert_map ()
9901040
991-     def  get_input_embeddings (self ,  input_ids :  torch . Tensor ) ->  torch . Tensor :
992-         return  self .model .get_input_embeddings ( input_ids )
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
    """Return the expert weight-name mapping exposed by the inner model."""
    return self.model.get_expert_mapping()
9931043
994-     def  forward (
995-         self ,
996-         input_ids : torch .Tensor ,
997-         positions : torch .Tensor ,
998-         intermediate_tensors : Optional [IntermediateTensors ] =  None ,
999-         inputs_embeds : Optional [torch .Tensor ] =  None ,
1000-     ) ->  Union [torch .Tensor , IntermediateTensors ]:
1001-         model_output  =  self .model (input_ids , positions , intermediate_tensors ,
1002-                                   inputs_embeds )
1003-         return  model_output 
10041044
1005-     def  compute_logits (
1006-         self ,
1007-         hidden_states : torch .Tensor ,
1008-     ) ->  Optional [torch .Tensor ]:
1009-         logits  =  self .logits_processor (self .lm_head , hidden_states )
1010-         return  logits 
class HunYuanDenseV1Base(HunyuanV1ModelBase):
    """Base class for dense (non-MoE) HunYuan V1 model variants.

    All behavior comes from ``HunyuanV1ModelBase``; an ``__init__``
    override that only forwarded ``vllm_config``/``prefix`` to
    ``super().__init__`` was redundant and has been removed — the
    inherited constructor has the identical keyword-only signature.
    """
10371049
10381050
class HunYuanDenseV1ForCausalLM(HunYuanDenseV1Base):
    """Dense HunYuan V1 model with a causal language-modeling head.

    Concrete registry alias of ``HunYuanDenseV1Base``; adds no behavior.
    """
10411053
10421054
class HunYuanMoEV1ForCausalLM(HunYuanMoEV1Base):
    """Mixture-of-Experts HunYuan V1 model with a causal LM head.

    Concrete registry alias of ``HunYuanMoEV1Base``; adds no behavior.
    """
0 commit comments