 
 import torch
 from torch.nn import Module
-import torch.distributed as dist
-from torch.nn.parameter import Parameter, UninitializedParameter
-
-from vllm.distributed import (
-    divide,
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    tensor_model_parallel_all_reduce
-)
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding,
-    DEFAULT_VOCAB_PADDING_SIZE,
-    pad_vocab_size,
-    UnquantizedEmbeddingMethod,
-    ParallelLMHead
-)
+from torch.nn.parameter import Parameter
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.logits_processor import (
-    LogitsProcessor,
-    _apply_logits_processors,
-    _prune_hidden_states
-)
-from vllm.model_executor.parameter import BasevLLMParameter
-from vllm.model_executor.utils import set_weight_attrs, _enable_lmhead_tp
-from vllm.model_executor.sampling_metadata import SamplingMetadata
+    LogitsProcessor, _apply_logits_processors, _prune_hidden_states)
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig,
-    QuantizeMethodBase,
-    method_has_implemented_embedding
-)
+    QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod,
+    VocabParallelEmbedding, pad_vocab_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.utils import set_weight_attrs
 
 from vllm_ascend.distributed.parallel_state import get_lmheadtp_group
-from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.utils import lmhead_tp_enable
 
 
 def get_masked_input_and_mask(
@@ -105,8 +89,7 @@ def vocab_parallel_embedding_forward(self, input_):
 
 
 class CustomParallelLMHead(ParallelLMHead):
-
-    """Costom Parallelized LM head, added the feature of lmheadTP in pure dp scenario
+    """Custom Parallelized LM head, added the feature of lmheadTP in pure dp scenario
 
     Output logits weight matrices used in the Sampler. The weight and bias
     tensors are padded to make sure they are divisible by the number of
@@ -120,6 +103,7 @@ class CustomParallelLMHead(ParallelLMHead):
         org_num_embeddings: original vocabulary size (without LoRA).
         padding_size: padding size for the vocabulary.
     """
+
     def __init__(self,
                  num_embeddings: int,
                  embedding_dim: int,
@@ -128,16 +112,16 @@ def __init__(self,
                  org_num_embeddings: Optional[int] = None,
                  padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = ""):
         Module.__init__(self)
 
-        if _enable_lmhead_tp():
+        if lmhead_tp_enable():
             tp_rank = get_lmheadtp_group().rank_in_group
             self.tp_size = get_lmheadtp_group().world_size
         else:
             tp_rank = get_tensor_model_parallel_rank()
             self.tp_size = get_tensor_model_parallel_world_size()
-
+
         self.num_embeddings = num_embeddings
         self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
@@ -197,7 +181,7 @@ def __init__(self,
                                           self.num_embeddings_padded,
                                           params_dtype=params_dtype,
                                           weight_loader=self.weight_loader)
-
+
         self.quant_config = quant_config
         if bias:
             self.bias = Parameter(
@@ -209,90 +193,31 @@ def __init__(self,
             })
         else:
             self.register_parameter("bias", None)
-
+
+
 class CustomLogitsProcessor(LogitsProcessor):
     """Custom logits processor extending base LogitsProcessor functionality.
     Added the feature of lmheadTP in pure dp scenario
     """
-
-    def __init__(self,
-                 vocab_size: int,
-                 org_vocab_size: Optional[int] = None,
-                 scale: float = 1.0,
-                 logits_as_input: bool = False,
-                 soft_cap: Optional[float] = None) -> None:
-        super().__init__(
-            vocab_size=vocab_size,
-            org_vocab_size=org_vocab_size,
-            scale=scale,
-            logits_as_input=logits_as_input,
-            soft_cap=soft_cap
-        )
-
-    def forward(
+    def _get_logits(
         self,
-        lm_head: CustomParallelLMHead,
         hidden_states: torch.Tensor,
-        sampling_metadata: Optional[SamplingMetadata] = None,
-        embedding_bias: Optional[torch.Tensor] = None,
+        lm_head: CustomParallelLMHead,
+        embedding_bias: Optional[torch.Tensor],
     ) -> Optional[torch.Tensor]:
-        if self.logits_as_input:
-            logits = hidden_states
-        else:
-            if sampling_metadata is not None:
-                hidden_states = _prune_hidden_states(hidden_states,
-                                                     sampling_metadata)
-
-            # Get the logits for the next tokens.
-            logits = self._get_logits(hidden_states, lm_head, embedding_bias)
-        if logits is not None:
-            if self.soft_cap is not None:
-                logits = logits / self.soft_cap
-                logits = torch.tanh(logits)
-                logits = logits * self.soft_cap
-
-            if self.scale != 1.0:
-                logits *= self.scale
 
-            # Apply logits processors (if any).
-            if sampling_metadata is not None and \
-                    sampling_metadata.seq_groups is not None:
-                logits = _apply_logits_processors(logits, sampling_metadata)
-
-        return logits
-
-    def _get_logits(
-        self,
-        hidden_states: torch.Tensor,
-        lm_head: CustomParallelLMHead,
-        embedding_bias: Optional[torch.Tensor],
-    ) -> Optional[torch.Tensor]:
-        """
-        Compute logits for next token prediction using parallel processing.
-
-        Args:
-            hidden_states: Current hidden states from the model with shape [batch_size, hidden_size]
-            lm_head: Parallel embedding layer for vocabulary predictions
-            embedding_bias: Optional bias tensor to add to logits with shape [vocab_size]
-
-        Returns:
-            Logits tensor for next token prediction with shape [batch_size, vocab_size] or None
-        """
-
-        if _enable_lmhead_tp():
+        if lmhead_tp_enable():
             # Gather hidden states from all devices in tensor parallel group
-            gathered_hidden_states = get_lmheadtp_group().all_gather(hidden_states, dim=0)
+            gathered_hidden_states = get_lmheadtp_group().all_gather(
+                hidden_states, dim=0)
         else:
             gathered_hidden_states = hidden_states
 
-        # Compute logits using quantized matrix multiplication
-        local_logits = lm_head.quant_method.apply(
-            lm_head,
-            gathered_hidden_states,
-            bias=embedding_bias
-        )
+        local_logits = lm_head.quant_method.apply(lm_head,
+                                                  gathered_hidden_states,
+                                                  bias=embedding_bias)
 
-        if _enable_lmhead_tp():
+        if lmhead_tp_enable():
             logits = get_lmheadtp_group().all_to_all(local_logits)
         else:
             # Gather logits for tensor parallel
@@ -301,6 +226,5 @@ def _get_logits(
         # Remove paddings in vocab (if any)
         if logits is not None:
             logits = logits[..., :self.org_vocab_size]
-
+
         return logits
-
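
For reference, here is a minimal single-process sketch of the data flow this change gives `CustomLogitsProcessor._get_logits` when lmhead TP is enabled: hidden states from all ranks are all-gathered, each rank applies its own vocab shard of the LM head weight, and an all_to_all hands every rank back full-vocab logits for its local tokens only. The collectives are stood in by plain tensor ops and all sizes are made-up placeholders, so this illustrates the shapes rather than the vllm-ascend API.

import torch

world_size = 4       # assumed size of the lmhead tensor-parallel group
batch_per_rank = 2   # tokens each rank holds locally
hidden_size = 8
vocab_per_rank = 16  # vocab shard owned by each rank

# Per-rank inputs: hidden states [batch_per_rank, hidden_size] and one
# column shard of the LM head weight [vocab_per_rank, hidden_size].
hidden_per_rank = [torch.randn(batch_per_rank, hidden_size) for _ in range(world_size)]
weight_shards = [torch.randn(vocab_per_rank, hidden_size) for _ in range(world_size)]

# Step 1 -- all_gather(hidden_states, dim=0): every rank sees all tokens.
gathered = torch.cat(hidden_per_rank, dim=0)  # [world_size * batch_per_rank, hidden_size]

# Step 2 -- quant_method.apply stand-in: matmul against the local vocab shard.
local_logits = [gathered @ w.t() for w in weight_shards]  # each [W*B, vocab_per_rank]

# Step 3 -- all_to_all: each rank keeps only its own tokens but collects every
# rank's vocab shard for them, ending up with full-vocab logits locally.
full_logits_rank0 = torch.cat([l[:batch_per_rank] for l in local_logits], dim=-1)
print(full_logits_rank0.shape)  # torch.Size([2, 64])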