
Commit 780ff6f

[CI] Fix
Signed-off-by: zzhx1 <zzh_201018@outlook.com>
1 parent d54c2f1 commit 780ff6f

File tree

6 files changed: +66 additions, -81 deletions

vllm_ascend/ascend_config.py

Lines changed: 6 additions & 4 deletions
@@ -50,11 +50,13 @@ def __init__(self, vllm_config):
         self.lmhead_tensor_parallel_size = additional_config.get(
             "lmhead_tensor_parallel_size", None)
         if self.lmhead_tensor_parallel_size is not None:
-            logger.info(f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario")
-            assert(
+            logger.info(
+                f"Enable lmhead_tensor_parallel_size={self.lmhead_tensor_parallel_size} in pure DP scenario"
+            )
+            assert (
                 vllm_config.parallel_config.tensor_parallel_size == 1
-            ),"lmhead_tensor_parallel_size is only supported in the pure DP scenario"
-            assert(
+            ), "lmhead_tensor_parallel_size is only supported in the pure DP scenario"
+            assert (
                 self.torchair_graph_config.enabled == True
             ), "lmhead_tensor_parallel_size is only supported in graph mode"
 
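For context, the two assertions above are the only gate on this option: it needs `tensor_parallel_size == 1` (pure DP) and torchair graph mode enabled. The snippet below is a minimal, self-contained restatement of that check, assuming an `additional_config` dict of the shape vllm-ascend accepts (the `torchair_graph_config` key layout is my assumption, and the values are illustrative):

```python
# Illustrative sketch only, not the committed code: restates the validation that
# AscendConfig.__init__ performs on lmhead_tensor_parallel_size.
additional_config = {
    "lmhead_tensor_parallel_size": 2,            # hypothetical value
    "torchair_graph_config": {"enabled": True},  # assumed key layout
}

def check_lmhead_tp(additional_config: dict, tensor_parallel_size: int) -> None:
    lmhead_tp = additional_config.get("lmhead_tensor_parallel_size", None)
    if lmhead_tp is None:
        return
    # Same two constraints as the diff above.
    assert tensor_parallel_size == 1, \
        "lmhead_tensor_parallel_size is only supported in the pure DP scenario"
    assert additional_config["torchair_graph_config"]["enabled"], \
        "lmhead_tensor_parallel_size is only supported in graph mode"

check_lmhead_tp(additional_config, tensor_parallel_size=1)
```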

vllm_ascend/distributed/parallel_state.py

Lines changed: 10 additions & 6 deletions
@@ -9,15 +9,18 @@
 _MC2: Optional[GroupCoordinator] = None
 _LMTP: Optional[GroupCoordinator] = None
 
+
 def get_mc2_group() -> GroupCoordinator:
     assert _MC2 is not None, ("mc2 group is not initialized")
     return _MC2
 
+
 def get_lmheadtp_group() -> GroupCoordinator:
     assert _LMTP is not None, (
         "lm head tensor parallel group is not initialized")
     return _LMTP
 
+
 def model_parallel_initialized():
     return (_MC2 is not None)
 
@@ -43,22 +46,23 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
                                       get_world_group().local_rank,
                                       backend,
                                       group_name="mc2")
-
+
     lmhead_tensor_parallel_size = parallel_config.lmhead_tensor_parallel_size
     if lmhead_tensor_parallel_size is not None:
         group_ranks = []
         global _LMTP
         num_lmhead_tensor_parallel_groups: int = (world_size //
-                                                lmhead_tensor_parallel_size)
+                                                  lmhead_tensor_parallel_size)
         for i in range(num_lmhead_tensor_parallel_groups):
             ranks = list(
                 range(i * lmhead_tensor_parallel_size,
-                     (i + 1) * lmhead_tensor_parallel_size))
+                      (i + 1) * lmhead_tensor_parallel_size))
             group_ranks.append(ranks)
         _LMTP = init_model_parallel_group(group_ranks,
-                                        get_world_group().local_rank,
-                                        backend,
-                                        group_name="lmheadtp")
+                                          get_world_group().local_rank,
+                                          backend,
+                                          group_name="lmheadtp")
+
 
 def destroy_ascend_model_parallel():
     global _MC2
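To make the group math above concrete, here is a stand-alone sketch of how `init_ascend_model_parallel` derives the `lmheadtp` group ranks: the world is cut into `world_size // lmhead_tensor_parallel_size` groups of consecutive ranks (the sizes below are illustrative):

```python
# Worked example of the rank grouping in init_ascend_model_parallel (illustrative sizes).
world_size = 8
lmhead_tensor_parallel_size = 4

num_lmhead_tensor_parallel_groups = world_size // lmhead_tensor_parallel_size  # 2
group_ranks = [
    list(range(i * lmhead_tensor_parallel_size,
               (i + 1) * lmhead_tensor_parallel_size))
    for i in range(num_lmhead_tensor_parallel_groups)
]
print(group_ranks)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```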

vllm_ascend/models/deepseek_v2.py

Lines changed: 6 additions & 6 deletions
@@ -49,12 +49,11 @@
                                                ReplicatedLinear,
                                                RowParallelLinear,
                                                UnquantizedLinearMethod)
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import get_sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.deepseek_v2 import \
@@ -68,14 +67,15 @@
                                      PPMissingLayer, is_pp_missing_parameter,
                                      make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 from vllm.sequence import IntermediateTensors
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
+from vllm_ascend.ops.vocab_parallel_embedding import (CustomLogitsProcessor,
+                                                      CustomParallelLMHead)
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import dispose_tensor, npu_prefetch
-from vllm_ascend.ops.vocab_parallel_embedding import CustomParallelLMHead, CustomLogitsProcessor
+
 
 class CustomDeepseekV2SiluAndMul(SiluAndMul):
 
@@ -872,7 +872,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                                               config.hidden_size,
                                               quant_config=quant_config,
                                               prefix=maybe_prefix(
-                                                  prefix, "lm_head"))
+                                                   prefix, "lm_head"))
         else:
             self.lm_head = PPMissingLayer()
         self.logits_processor = CustomLogitsProcessor(config.vocab_size)

vllm_ascend/ops/vocab_parallel_embedding.py

Lines changed: 36 additions & 55 deletions
@@ -19,40 +19,24 @@
 
 import torch
 from torch.nn import Module
-import torch.distributed as dist
-from torch.nn.parameter import Parameter, UninitializedParameter
-
-from vllm.distributed import (
-    divide,
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-    tensor_model_parallel_all_reduce
-)
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding,
-    DEFAULT_VOCAB_PADDING_SIZE,
-    pad_vocab_size,
-    UnquantizedEmbeddingMethod,
-    ParallelLMHead
-)
+from torch.nn.parameter import Parameter
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
 from vllm.model_executor.layers.logits_processor import (
-    LogitsProcessor,
-    _apply_logits_processors,
-    _prune_hidden_states
-)
-from vllm.model_executor.parameter import BasevLLMParameter
-from vllm.model_executor.utils import set_weight_attrs
-from vllm.model_executor.sampling_metadata import SamplingMetadata
+    LogitsProcessor, _apply_logits_processors, _prune_hidden_states)
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig,
-    QuantizeMethodBase,
-    method_has_implemented_embedding
-)
+    QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, UnquantizedEmbeddingMethod,
+    VocabParallelEmbedding, pad_vocab_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.model_executor.utils import set_weight_attrs
 
 from vllm_ascend.distributed.parallel_state import get_lmheadtp_group
-from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.utils import _enable_lmhead_tp
 
+
 def get_masked_input_and_mask(
         input_: torch.Tensor, org_vocab_start_index: int,
         org_vocab_end_index: int, num_org_vocab_padding: int,
@@ -105,7 +89,6 @@ def vocab_parallel_embedding_forward(self, input_):
 
 
 class CustomParallelLMHead(ParallelLMHead):
-
     """Costom Parallelized LM head, added the feature of lmheadTP in pure dp scenario
 
     Output logits weight matrices used in the Sampler. The weight and bias
@@ -120,6 +103,7 @@ class CustomParallelLMHead(ParallelLMHead):
         org_num_embeddings: original vocabulary size (without LoRA).
         padding_size: padding size for the vocabulary.
     """
+
     def __init__(self,
                  num_embeddings: int,
                  embedding_dim: int,
@@ -128,7 +112,7 @@ def __init__(self,
                  org_num_embeddings: Optional[int] = None,
                  padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
+                 prefix: str = ""):
         Module.__init__(self)
 
         if _enable_lmhead_tp():
@@ -137,7 +121,7 @@ def __init__(self,
         else:
             tp_rank = get_tensor_model_parallel_rank()
         self.tp_size = get_tensor_model_parallel_world_size()
-
+
         self.num_embeddings = num_embeddings
         self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
@@ -197,7 +181,7 @@ def __init__(self,
             self.num_embeddings_padded,
             params_dtype=params_dtype,
             weight_loader=self.weight_loader)
-
+
         self.quant_config = quant_config
         if bias:
             self.bias = Parameter(
@@ -209,25 +193,24 @@ def __init__(self,
                 })
         else:
             self.register_parameter("bias", None)
-
+
+
 class CustomLogitsProcessor(LogitsProcessor):
     """Custom logits processor extending base LogitsProcessor functionality.
     Added the feature of lmheadTP in pure dp scenario
     """
-
+
     def __init__(self,
                  vocab_size: int,
                  org_vocab_size: Optional[int] = None,
                  scale: float = 1.0,
                  logits_as_input: bool = False,
                  soft_cap: Optional[float] = None) -> None:
-        super().__init__(
-            vocab_size=vocab_size,
-            org_vocab_size=org_vocab_size,
-            scale=scale,
-            logits_as_input=logits_as_input,
-            soft_cap=soft_cap
-        )
+        super().__init__(vocab_size=vocab_size,
+                         org_vocab_size=org_vocab_size,
+                         scale=scale,
+                         logits_as_input=logits_as_input,
+                         soft_cap=soft_cap)
 
     def forward(
         self,
@@ -258,15 +241,15 @@ def forward(
         if sampling_metadata is not None and \
             sampling_metadata.seq_groups is not None:
             logits = _apply_logits_processors(logits, sampling_metadata)
-
+
         return logits
 
     def _get_logits(
-            self,
-            hidden_states: torch.Tensor,
-            lm_head: CustomParallelLMHead,
-            embedding_bias: Optional[torch.Tensor],
-        ) -> Optional[torch.Tensor]:
+        self,
+        hidden_states: torch.Tensor,
+        lm_head: CustomParallelLMHead,
+        embedding_bias: Optional[torch.Tensor],
+    ) -> Optional[torch.Tensor]:
         """
         Compute logits for next token prediction using parallel processing.
 
@@ -281,16 +264,15 @@ def _get_logits(
 
         if _enable_lmhead_tp():
             # Gather hidden states from all devices in tensor parallel group
-            gathered_hidden_states = get_lmheadtp_group().all_gather(hidden_states, dim=0)
+            gathered_hidden_states = get_lmheadtp_group().all_gather(
+                hidden_states, dim=0)
         else:
             gathered_hidden_states = hidden_states
 
         # Compute logits using quantized matrix multiplication
-        local_logits = lm_head.quant_method.apply(
-            lm_head,
-            gathered_hidden_states,
-            bias=embedding_bias
-        )
+        local_logits = lm_head.quant_method.apply(lm_head,
+                                                  gathered_hidden_states,
+                                                  bias=embedding_bias)
 
         if _enable_lmhead_tp():
             logits = get_lmheadtp_group().all_to_all(local_logits)
@@ -301,6 +283,5 @@ def _get_logits(
         # Remove paddings in vocab (if any)
         if logits is not None:
             logits = logits[..., :self.org_vocab_size]
-
+
         return logits
-
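The reformatted `_get_logits` keeps the same lmhead-TP data flow: all-gather the hidden states across the `lmheadtp` group, multiply against the local vocab shard, then all-to-all so each rank ends up with full-vocab logits for its own tokens. The single-process sketch below imitates that movement with plain tensor ops standing in for the collectives; the shapes and names are illustrative, not the committed code:

```python
import torch

# Illustrative sizes: 2 lmhead-TP ranks, 3 tokens per rank, vocab sharded row-wise.
tp_size, tokens_per_rank, hidden, vocab = 2, 3, 8, 16
shard = vocab // tp_size

hidden_per_rank = [torch.randn(tokens_per_rank, hidden) for _ in range(tp_size)]
lm_head_shards = [torch.randn(shard, hidden) for _ in range(tp_size)]

# all_gather(hidden_states, dim=0): every rank now sees every rank's tokens.
gathered = torch.cat(hidden_per_rank, dim=0)               # [tp_size * tokens, hidden]

# Each rank applies lm_head.quant_method.apply(...) against its own vocab shard.
local_logits = [gathered @ w.t() for w in lm_head_shards]  # each [tp_size * tokens, shard]

# all_to_all(local_logits): rank 0 keeps the rows for its own tokens but collects
# every rank's vocab shard, recovering full-vocab logits for those tokens.
full_logits_rank0 = torch.cat(
    [chunk[:tokens_per_rank] for chunk in local_logits], dim=-1)
assert full_logits_rank0.shape == (tokens_per_rank, vocab)
```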

vllm_ascend/platform.py

Lines changed: 1 addition & 2 deletions
@@ -134,8 +134,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if parallel_config:
             # assign lmhead tensor parallel size
             parallel_config.lmhead_tensor_parallel_size = (
-                ascend_config.lmhead_tensor_parallel_size
-            )
+                ascend_config.lmhead_tensor_parallel_size)
 
         if model_config is None:
             logger.warning("Model config is missing. This may indicate "

vllm_ascend/worker/model_runner_v1.py

Lines changed: 7 additions & 8 deletions
@@ -85,14 +85,13 @@
 from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
                                         write_kv_cache_bytes_to_file)
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
-                               ProfileExecuteDuration, is_310p,
-                               maybe_converting_weight_acl_format,
-                               vllm_version_is, _enable_lmhead_tp)
+                               ProfileExecuteDuration, _enable_lmhead_tp,
+                               is_310p, maybe_converting_weight_acl_format,
+                               vllm_version_is)
 from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
 from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
 
-
 if not vllm_version_is("0.10.0"):
     from vllm.tasks import GenerationTask, SupportedTask
     from vllm.v1.worker.kv_connector_model_runner_mixin import \
@@ -1334,8 +1333,8 @@ def _process_reqs(
         aux_hidden_states = None
         if self.use_aux_hidden_state_outputs:
             hidden_states, aux_hidden_states = hidden_states
-
-        if _enable_lmhead_tp():  #
+
+        if _enable_lmhead_tp():  #
             if not with_prefill:
                 max_num_reqs_across_dp = padded_num_tokens_across_dp
             else:
@@ -1998,7 +1997,7 @@ def _dummy_run(
             if self.use_spec_decode and isinstance(
                     self.drafter, EagleProposer):
                 self.drafter.dummy_run(num_tokens)
-
+
             if _enable_lmhead_tp() and not self.in_profile_run:
                 if not with_prefill:
                     max_num_reqs_across_dp = num_reqs
@@ -2008,7 +2007,7 @@ def _dummy_run(
                     device=hidden_states.device,
                     dtype=torch.int32)
                 model.compute_logits(hidden_states[dummy_indices], None)
-
+
         if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
             assert isinstance(self.drafter, MtpProposer)
             self.drafter.dummy_run(
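The `_dummy_run` change follows the usual pattern around collectives: with lmhead TP on, `_get_logits` issues an all-gather and an all-to-all, so a DP rank with no live requests presumably still has to call `compute_logits` on padded dummy rows so its peers are not left waiting. A hedged sketch of the shapes involved (names follow the diff, values are illustrative):

```python
import torch

# Sketch only: the dummy logits call shape in _dummy_run when lmhead TP is enabled.
hidden_states = torch.zeros(16, 128)   # padded dummy hidden states from the dummy forward
max_num_reqs_across_dp = 4             # num_reqs, or padded_num_tokens_across_dp with prefill
dummy_indices = torch.zeros(max_num_reqs_across_dp, dtype=torch.int32)

# Stand-in for model.compute_logits(hidden_states[dummy_indices], None);
# .long() is only needed for this CPU sketch.
dummy_rows = hidden_states[dummy_indices.long()]
assert dummy_rows.shape == (max_num_reqs_across_dp, 128)
```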
