 
 import torch
 from torch.nn import Module
-import torch.distributed as dist
-from torch.nn.parameter import Parameter, UninitializedParameter
-
-from vllm.distributed import (
-    divide,
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    tensor_model_parallel_all_reduce
-)
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding,
-    DEFAULT_VOCAB_PADDING_SIZE,
-    pad_vocab_size,
-    UnquantizedEmbeddingMethod,
-    ParallelLMHead
-)
+from torch.nn.parameter import Parameter
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.logits_processor import (
-    LogitsProcessor,
-    _apply_logits_processors,
-    _prune_hidden_states
-)
-from vllm.model_executor.parameter import BasevLLMParameter
-from vllm.model_executor.utils import set_weight_attrs
-from vllm.model_executor.sampling_metadata import SamplingMetadata
+    LogitsProcessor, _apply_logits_processors, _prune_hidden_states)
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig,
-    QuantizeMethodBase,
-    method_has_implemented_embedding
-)
+    QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod,
+    VocabParallelEmbedding, pad_vocab_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.utils import set_weight_attrs
 
 from vllm_ascend.distributed.parallel_state import get_lmheadtp_group
-from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.utils import _enable_lmhead_tp
 
+
 def get_masked_input_and_mask(
         input_: torch.Tensor, org_vocab_start_index: int,
         org_vocab_end_index: int, num_org_vocab_padding: int,
@@ -105,7 +89,6 @@ def vocab_parallel_embedding_forward(self, input_):
 
 
 class CustomParallelLMHead(ParallelLMHead):
-
10992 """Costom Parallelized LM head, added the feature of lmheadTP in pure dp scenario
 
     Output logits weight matrices used in the Sampler. The weight and bias
@@ -120,6 +103,7 @@ class CustomParallelLMHead(ParallelLMHead):
         org_num_embeddings: original vocabulary size (without LoRA).
         padding_size: padding size for the vocabulary.
     """
+
     def __init__(self,
                  num_embeddings: int,
                  embedding_dim: int,
@@ -128,7 +112,7 @@ def __init__(self,
                  org_num_embeddings: Optional[int] = None,
                  padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = ""):
         Module.__init__(self)
 
         if _enable_lmhead_tp():
@@ -137,7 +121,7 @@ def __init__(self,
         else:
             tp_rank = get_tensor_model_parallel_rank()
             self.tp_size = get_tensor_model_parallel_world_size()
-
+
         self.num_embeddings = num_embeddings
         self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
@@ -197,7 +181,7 @@ def __init__(self,
                                          self.num_embeddings_padded,
                                          params_dtype=params_dtype,
                                          weight_loader=self.weight_loader)
-
+
         self.quant_config = quant_config
         if bias:
             self.bias = Parameter(
@@ -209,25 +193,24 @@ def __init__(self,
             })
         else:
             self.register_parameter("bias", None)
-
+
+
 class CustomLogitsProcessor(LogitsProcessor):
     """Custom logits processor extending base LogitsProcessor functionality.
     Adds lmhead TP support in the pure DP scenario.
216201 """
217-
202+
218203 def __init__ (self ,
219204 vocab_size : int ,
220205 org_vocab_size : Optional [int ] = None ,
221206 scale : float = 1.0 ,
222207 logits_as_input : bool = False ,
223208 soft_cap : Optional [float ] = None ) -> None :
224- super ().__init__ (
225- vocab_size = vocab_size ,
226- org_vocab_size = org_vocab_size ,
227- scale = scale ,
228- logits_as_input = logits_as_input ,
229- soft_cap = soft_cap
230- )
209+ super ().__init__ (vocab_size = vocab_size ,
210+ org_vocab_size = org_vocab_size ,
211+ scale = scale ,
212+ logits_as_input = logits_as_input ,
213+ soft_cap = soft_cap )
231214
232215 def forward (
233216 self ,
@@ -258,15 +241,15 @@ def forward(
         if sampling_metadata is not None and \
                 sampling_metadata.seq_groups is not None:
             logits = _apply_logits_processors(logits, sampling_metadata)
-
+
         return logits
 
     def _get_logits(
-            self,
-            hidden_states: torch.Tensor,
-            lm_head: CustomParallelLMHead,
-            embedding_bias: Optional[torch.Tensor],
-        ) -> Optional[torch.Tensor]:
+        self,
+        hidden_states: torch.Tensor,
+        lm_head: CustomParallelLMHead,
+        embedding_bias: Optional[torch.Tensor],
+    ) -> Optional[torch.Tensor]:
         """
         Compute logits for next token prediction using parallel processing.
 
@@ -281,16 +264,15 @@ def _get_logits(
 
         if _enable_lmhead_tp():
             # Gather hidden states from all devices in tensor parallel group
-            gathered_hidden_states = get_lmheadtp_group().all_gather(hidden_states, dim=0)
+            gathered_hidden_states = get_lmheadtp_group().all_gather(
+                hidden_states, dim=0)
         else:
             gathered_hidden_states = hidden_states
 
         # Compute logits using quantized matrix multiplication
-        local_logits = lm_head.quant_method.apply(
-            lm_head,
-            gathered_hidden_states,
-            bias=embedding_bias
-        )
+        local_logits = lm_head.quant_method.apply(lm_head,
+                                                  gathered_hidden_states,
+                                                  bias=embedding_bias)
 
         if _enable_lmhead_tp():
             logits = get_lmheadtp_group().all_to_all(local_logits)
@@ -301,6 +283,5 @@ def _get_logits(
         # Remove paddings in vocab (if any)
         if logits is not None:
             logits = logits[..., :self.org_vocab_size]
-
+
         return logits
-
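
A note on the vocab sharding that `CustomParallelLMHead` inherits from vLLM's `VocabParallelEmbedding` path: the vocabulary is padded to a multiple of `DEFAULT_VOCAB_PADDING_SIZE` (64) via `pad_vocab_size` and then split evenly across `self.tp_size` ranks with `divide`, so switching `tp_size` to the lmhead TP group size directly changes the per-rank `lm_head` weight shape passed to `create_weights`. Below is a minimal worked example of that arithmetic; the vocab size and group size are illustrative, and the helper re-implements the round-up rule locally so the snippet stands alone.

```python
def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    # Same round-up-to-multiple rule as vLLM's pad_vocab_size helper.
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to


num_embeddings = 151_936   # illustrative vocab size (already a multiple of 64)
lmhead_tp_size = 8         # assumed lmhead TP group size in the pure-DP setup

num_embeddings_padded = pad_vocab_size(num_embeddings, 64)
num_embeddings_per_partition = num_embeddings_padded // lmhead_tp_size

print(num_embeddings_padded, num_embeddings_per_partition)  # 151936 18992
```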
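
To make the gather/compute/redistribute pattern in `_get_logits` easier to follow, here is a minimal single-process sketch of the shape flow when `_enable_lmhead_tp()` is active. The collectives are emulated with plain tensor ops, every size is made up, and the chunk layout assumed for `all_to_all` is an illustration rather than a statement about `get_lmheadtp_group()`'s implementation.

```python
import torch

# Illustrative sizes only; none of these values come from the change above.
tp_size, tokens_per_rank, hidden, vocab_shard = 4, 8, 16, 32

# Each lmhead-TP rank holds its own tokens plus one shard of the lm_head weight.
per_rank_hidden = [torch.randn(tokens_per_rank, hidden) for _ in range(tp_size)]
per_rank_lm_head = [torch.randn(vocab_shard, hidden) for _ in range(tp_size)]

# Step 1 - all_gather(hidden_states, dim=0): every rank sees every token.
gathered = torch.cat(per_rank_hidden, dim=0)        # [tp_size*tokens, hidden]

# Step 2 - local matmul against the rank's vocab shard (quant_method.apply in
# the real code): partial-vocab logits for all tokens, computed on every rank.
local_logits = [gathered @ w.t() for w in per_rank_lm_head]

# Step 3 - all_to_all: each rank keeps only its own tokens but collects the
# matching vocab shards from all peers, ending with full-vocab logits.
rank = 0
own_rows = slice(rank * tokens_per_rank, (rank + 1) * tokens_per_rank)
full_logits = torch.cat([l[own_rows] for l in local_logits], dim=-1)
assert full_logits.shape == (tokens_per_rank, tp_size * vocab_shard)
```

The trade-off this sketch illustrates is that the matmul cost is replicated across the group, but each DP rank only ever materializes full-vocab logits for its own tokens, which is what the `logits[..., :self.org_vocab_size]` slice then trims back to the unpadded vocabulary.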