 from vllm.attention import Attention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.distributed.utils import divide
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                 ReplicatedLinear,
@@ -128,10 +127,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
 
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
+        model_config = vllm_config.model_config
+        parallel_config = vllm_config.parallel_config
 
         self.config = config
-        self.vocab_size = config.vocab_size
-        self.unpadded_vocab_size = config.vocab_size
+        self.vocab_size = model_config.get_vocab_size()
+        self.unpadded_vocab_size = model_config.get_vocab_size()
 
         self.model: PreTrainedModel = AutoModel.from_config(
             self.config,
@@ -145,15 +146,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.apply_base_model_tp_plan(self.model)
 
         # Attention modifications (assumes 1 attention op per hidden layer)
-        tp_size = get_tensor_model_parallel_world_size()
+        num_heads = model_config.get_num_attention_heads(parallel_config)
+        head_size = model_config.get_head_size()
+        num_kv_heads = model_config.get_num_kv_heads(parallel_config)
         self.attention_instances = [
             Attention(
-                num_heads=divide(config.num_attention_heads, tp_size),
-                head_size=config.head_dim,
+                num_heads=num_heads,
+                head_size=head_size,
                 # NOTE: We use Llama scale as default, if it's set by
                 # Transformers, it's updated in vllm_flash_attention_forward
-                scale=config.head_dim**-0.5,
-                num_kv_heads=divide(config.num_key_value_heads, tp_size),
+                scale=head_size**-0.5,
+                num_kv_heads=num_kv_heads,
                 cache_config=cache_config,
                 quant_config=self.quant_config,
                 prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
@@ -163,7 +166,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.replace_vocab_embed_class(self.model)
 
         # ForCausalLM modifications
-        self.lm_head = ParallelLMHead(config.vocab_size,
+        self.lm_head = ParallelLMHead(self.vocab_size,
                                       config.hidden_size,
                                       quant_config=self.quant_config,
                                       prefix=maybe_prefix(prefix, "lm_head"))
@@ -172,7 +175,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
 
         logit_scale = getattr(config, "logit_scale", 1.0)
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
-                                                config.vocab_size, logit_scale)
+                                                self.vocab_size, logit_scale)
         self.sampler = get_sampler()
 
     def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
@@ -203,12 +206,12 @@ def replace_vocab_embed_class(self, module: nn.Module):
         new_module = VocabParallelEmbedding(
             self.vocab_size,
             self.config.hidden_size,
-            org_num_embeddings=self.config.vocab_size,
+            org_num_embeddings=self.vocab_size,
             quant_config=None,
         )
         log_replacement("input embedding", self.model.get_input_embeddings(),
                         new_module)
-        self.model.set_input_embeddings(new_module)
+        module.set_input_embeddings(new_module)
 
     def forward(
         self,
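
Context for the head-count change above: the old code divided the raw Hugging Face config values by the tensor-parallel world size inline via divide(), while the new code reads per-rank values from the ModelConfig accessors shown in the diff (get_num_attention_heads, get_num_kv_heads, get_head_size). Below is a minimal standalone sketch of the arithmetic this refactor assumes those accessors perform; it is illustrative only, not vLLM's actual implementation, and the max(1, ...) clamp for replicated KV heads is an assumption.

def per_rank_attention_shapes(num_attention_heads: int,
                              num_key_value_heads: int,
                              head_dim: int,
                              tp_size: int) -> tuple[int, int, int]:
    """Per-tensor-parallel-rank (num_heads, num_kv_heads, head_size).

    Illustrative only: mirrors the values the removed divide() calls
    produced; the real accessors live on vLLM's ModelConfig.
    """
    # Query heads must split evenly across tensor-parallel ranks.
    assert num_attention_heads % tp_size == 0
    num_heads = num_attention_heads // tp_size
    # Assumed behaviour: KV heads are replicated (never dropped to zero)
    # when there are fewer of them than tensor-parallel ranks.
    num_kv_heads = max(1, num_key_value_heads // tp_size)
    # head_dim is a per-head quantity, so sharding leaves it unchanged.
    return num_heads, num_kv_heads, head_dim

# Example: a Llama-3-8B-style config (32 heads, 8 KV heads, head_dim 128)
# on 2 GPUs yields 16 query heads and 4 KV heads per rank.
print(per_rank_attention_shapes(32, 8, 128, 2))  # -> (16, 4, 128)

Because the accessors already fold in the parallel split, the constructor no longer needs divide or an explicit get_tensor_model_parallel_world_size() call; the unchanged default scale=head_size**-0.5 is the standard 1/sqrt(d_head) dot-product factor, now computed from the accessor value instead of config.head_dim.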