 
 import torch
 from torch.nn import Module
-import torch.distributed as dist
-from torch.nn.parameter import Parameter, UninitializedParameter
-
-from vllm.distributed import (
-    divide,
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    tensor_model_parallel_all_reduce
-)
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding,
-    DEFAULT_VOCAB_PADDING_SIZE,
-    pad_vocab_size,
-    UnquantizedEmbeddingMethod,
-    ParallelLMHead
-)
-from vllm.model_executor.layers.logits_processor import (
-    LogitsProcessor,
-    _apply_logits_processors,
-    _prune_hidden_states
-)
-from vllm.model_executor.parameter import BasevLLMParameter
-from vllm.model_executor.utils import set_weight_attrs, _enable_lmhead_tp
-from vllm.model_executor.sampling_metadata import SamplingMetadata
+from torch.nn.parameter import Parameter
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig,
-    QuantizeMethodBase,
-    method_has_implemented_embedding
-)
+    QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod,
+    VocabParallelEmbedding, pad_vocab_size)
+from vllm.model_executor.utils import set_weight_attrs
 
 from vllm_ascend.distributed.parallel_state import get_lmheadtp_group
-from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.utils import lmhead_tp_enable
 
 
 def get_masked_input_and_mask(
@@ -105,8 +87,7 @@ def vocab_parallel_embedding_forward(self, input_):
 
 
 class CustomParallelLMHead(ParallelLMHead):
-
-    """Costom Parallelized LM head, added the feature of lmheadTP in pure dp scenario
+    """Custom Parallelized LM head, added the feature of lmheadTP in pure dp scenario
 
     Output logits weight matrices used in the Sampler. The weight and bias
     tensors are padded to make sure they are divisible by the number of
@@ -120,6 +101,7 @@ class CustomParallelLMHead(ParallelLMHead):
         org_num_embeddings: original vocabulary size (without LoRA).
         padding_size: padding size for the vocabulary.
     """
+
     def __init__(self,
                  num_embeddings: int,
                  embedding_dim: int,
@@ -128,16 +110,16 @@ def __init__(self,
                  org_num_embeddings: Optional[int] = None,
                  padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = ""):
         Module.__init__(self)
 
-        if _enable_lmhead_tp():
+        if lmhead_tp_enable():
             tp_rank = get_lmheadtp_group().rank_in_group
             self.tp_size = get_lmheadtp_group().world_size
         else:
             tp_rank = get_tensor_model_parallel_rank()
             self.tp_size = get_tensor_model_parallel_world_size()
-
+
         self.num_embeddings = num_embeddings
         self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
@@ -197,7 +179,7 @@ def __init__(self,
                                           self.num_embeddings_padded,
                                           params_dtype=params_dtype,
                                           weight_loader=self.weight_loader)
-
+
         self.quant_config = quant_config
         if bias:
             self.bias = Parameter(
@@ -209,90 +191,32 @@ def __init__(self,
             })
         else:
             self.register_parameter("bias", None)
-
+
+
 class CustomLogitsProcessor(LogitsProcessor):
     """Custom logits processor extending base LogitsProcessor functionality.
     Added the feature of lmheadTP in pure dp scenario
     """
-
-    def __init__(self,
-                 vocab_size: int,
-                 org_vocab_size: Optional[int] = None,
-                 scale: float = 1.0,
-                 logits_as_input: bool = False,
-                 soft_cap: Optional[float] = None) -> None:
-        super().__init__(
-            vocab_size=vocab_size,
-            org_vocab_size=org_vocab_size,
-            scale=scale,
-            logits_as_input=logits_as_input,
-            soft_cap=soft_cap
-        )
 
-    def forward(
+    def _get_logits(
         self,
-        lm_head: CustomParallelLMHead,
         hidden_states: torch.Tensor,
-        sampling_metadata: Optional[SamplingMetadata] = None,
-        embedding_bias: Optional[torch.Tensor] = None,
+        lm_head: CustomParallelLMHead,
+        embedding_bias: Optional[torch.Tensor],
     ) -> Optional[torch.Tensor]:
-        if self.logits_as_input:
-            logits = hidden_states
-        else:
-            if sampling_metadata is not None:
-                hidden_states = _prune_hidden_states(hidden_states,
-                                                     sampling_metadata)
-
-            # Get the logits for the next tokens.
-            logits = self._get_logits(hidden_states, lm_head, embedding_bias)
-        if logits is not None:
-            if self.soft_cap is not None:
-                logits = logits / self.soft_cap
-                logits = torch.tanh(logits)
-                logits = logits * self.soft_cap
-
-            if self.scale != 1.0:
-                logits *= self.scale
 
-            # Apply logits processors (if any).
-            if sampling_metadata is not None and \
-                sampling_metadata.seq_groups is not None:
-                logits = _apply_logits_processors(logits, sampling_metadata)
-
-        return logits
-
-    def _get_logits(
-        self,
-        hidden_states: torch.Tensor,
-        lm_head: CustomParallelLMHead,
-        embedding_bias: Optional[torch.Tensor],
-    ) -> Optional[torch.Tensor]:
-        """
-        Compute logits for next token prediction using parallel processing.
-
-        Args:
-            hidden_states: Current hidden states from the model with shape [batch_size, hidden_size]
-            lm_head: Parallel embedding layer for vocabulary predictions
-            embedding_bias: Optional bias tensor to add to logits with shape [vocab_size]
-
-        Returns:
-            Logits tensor for next token prediction with shape [batch_size, vocab_size] or None
-        """
-
-        if _enable_lmhead_tp():
+        if lmhead_tp_enable():
             # Gather hidden states from all devices in tensor parallel group
-            gathered_hidden_states = get_lmheadtp_group().all_gather(hidden_states, dim=0)
+            gathered_hidden_states = get_lmheadtp_group().all_gather(
+                hidden_states, dim=0)
         else:
             gathered_hidden_states = hidden_states
 
-        # Compute logits using quantized matrix multiplication
-        local_logits = lm_head.quant_method.apply(
-            lm_head,
-            gathered_hidden_states,
-            bias=embedding_bias
-        )
+        local_logits = lm_head.quant_method.apply(lm_head,
+                                                  gathered_hidden_states,
+                                                  bias=embedding_bias)
 
-        if _enable_lmhead_tp():
+        if lmhead_tp_enable():
             logits = get_lmheadtp_group().all_to_all(local_logits)
         else:
             # Gather logits for tensor parallel
@@ -301,6 +225,5 @@ def _get_logits(
         # Remove paddings in vocab (if any)
         if logits is not None:
             logits = logits[..., :self.org_vocab_size]
-
+
         return logits
-
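
For context on the new `_get_logits` path: with lmhead TP enabled, each rank holds only a shard of the vocabulary, so hidden states are first all-gathered across the lmhead TP group, every rank computes logits for its own vocab shard over all gathered tokens, and an all-to-all then hands each rank full-vocab logits for its original tokens. Below is a minimal sketch of that communication pattern using plain torch.distributed collectives; the group handle, tensor names, and the dense matmul standing in for lm_head.quant_method.apply are illustrative assumptions, not the vllm-ascend implementation.

# Minimal sketch of the lmhead-TP logits pattern (illustrative only; the real
# code goes through get_lmheadtp_group() and lm_head.quant_method.apply).
import torch
import torch.distributed as dist


def lmhead_tp_logits(hidden_states: torch.Tensor,
                     vocab_shard_weight: torch.Tensor,
                     group: dist.ProcessGroup) -> torch.Tensor:
    # hidden_states: [num_tokens, hidden]; vocab_shard_weight: [vocab_shard, hidden]
    tp = dist.get_world_size(group)

    # 1) All-gather hidden states so every rank sees every rank's tokens.
    gathered = [torch.empty_like(hidden_states) for _ in range(tp)]
    dist.all_gather(gathered, hidden_states, group=group)
    gathered = torch.cat(gathered, dim=0)  # [tp * num_tokens, hidden]

    # 2) Each rank computes logits only for its local vocab shard
    #    (stand-in for the quantized lm_head matmul).
    local_logits = gathered @ vocab_shard_weight.t()  # [tp * num_tokens, vocab_shard]

    # 3) All-to-all: send each source rank its own tokens' logits, receive this
    #    rank's tokens' logits for every vocab shard, then concat over vocab.
    send = list(local_logits.chunk(tp, dim=0))
    recv = [torch.empty_like(send[0]) for _ in range(tp)]
    dist.all_to_all(recv, send, group=group)
    return torch.cat(recv, dim=-1)  # [num_tokens, tp * vocab_shard]

When lmhead TP is disabled, the code falls back to the usual tensor-parallel path: no gather of hidden states and a gather of logits across the regular TP group instead of the all-to-all.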