@@ -888,7 +888,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         return loaded_params
 
 
-class HunYuanV1Base(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts):
+class HunyuanV1ModelBase(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -930,6 +930,56 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         else:
             self.lm_head = PPMissingLayer()
 
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        model_output = self.model(input_ids, positions, intermediate_tensors,
+                                  inputs_embeds)
+        return model_output
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states)
+        return logits
+
+    def make_empty_intermediate_tensors(
+            self, batch_size: int, dtype: torch.dtype,
+            device: torch.device) -> IntermediateTensors:
+        return IntermediateTensors({
+            "hidden_states":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+            "residual":
+            torch.zeros((batch_size, self.config.hidden_size),
+                        dtype=dtype,
+                        device=device),
+        })
+
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=(["lm_head."]
+                           if self.config.tie_word_embeddings else None),
+        )
+        return loader.load_weights(weights)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+
+class HunYuanMoEV1Base(HunyuanV1ModelBase, MixtureOfExperts):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+
         # Set MoE hyperparameters
         self.expert_weights = []
         self.num_expert_groups = 1
@@ -988,57 +1038,19 @@ def update_physical_experts_metadata(
             moe.n_redundant_experts = self.num_redundant_experts
             moe.experts.update_expert_map()
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
+    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
+        return self.model.get_expert_mapping()
 
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        model_output = self.model(input_ids, positions, intermediate_tensors,
-                                  inputs_embeds)
-        return model_output
 
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states)
-        return logits
+class HunYuanDenseV1Base(HunyuanV1ModelBase):
 
-    def make_empty_intermediate_tensors(
-            self, batch_size: int, dtype: torch.dtype,
-            device: torch.device) -> IntermediateTensors:
-        return IntermediateTensors({
-            "hidden_states":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-            "residual":
-            torch.zeros((batch_size, self.config.hidden_size),
-                        dtype=dtype,
-                        device=device),
-        })
-
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["lm_head."]
-                           if self.config.tie_word_embeddings else None),
-        )
-        return loader.load_weights(weights)
-
-    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        return self.model.get_expert_mapping()
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
 
 
-class HunYuanDenseV1ForCausalLM(HunYuanV1Base):
+class HunYuanDenseV1ForCausalLM(HunYuanDenseV1Base):
     pass
 
 
-class HunYuanMoEV1ForCausalLM(HunYuanV1Base):
-    pass
+class HunYuanMoEV1ForCausalLM(HunYuanMoEV1Base):
+    pass