
Commit 3b313b7

[CI] fix

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
1 parent 531544a commit 3b313b7

File tree: 6 files changed, +32 -32 lines


vllm_ascend/ascend_config.py

Lines changed: 1 addition & 2 deletions
@@ -58,15 +58,14 @@ def __init__(self, vllm_config):
                 vllm_config.parallel_config.tensor_parallel_size == 1
             ), "lmhead_tensor_parallel_size is only supported in the pure DP scenario"
             assert (
-                self.torchair_graph_config.enabled == True
+                self.torchair_graph_config.enabled
             ), "lmhead_tensor_parallel_size is only supported in graph mode"
 
         self.enable_shared_expert_dp = additional_config.get(
             "enable_shared_expert_dp", True
         ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
 
 
-
 class TorchairGraphConfig:
     """
     Configuration Object for torchair_graph_config from additional_config
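The tightened assertion above makes the preconditions of the LM-head tensor-parallel path explicit: pure data parallelism (tensor_parallel_size == 1) and torchair graph mode enabled. A minimal sketch of an additional_config that would satisfy both checks, assuming lmhead_tensor_parallel_size and the torchair_graph_config "enabled" flag are read from this dict as the surrounding code suggests; the values are illustrative, not taken from this commit.

# Illustrative sketch only: keys inferred from ascend_config.py, values assumed.
additional_config = {
    "lmhead_tensor_parallel_size": 8,            # triggers the asserts above
    "torchair_graph_config": {"enabled": True},  # graph mode must be on
}
# The asserts additionally require the pure-DP case, i.e.
# vllm_config.parallel_config.tensor_parallel_size == 1.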

vllm_ascend/models/deepseek_mtp.py

Lines changed: 9 additions & 7 deletions
@@ -25,20 +25,22 @@
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import get_sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import \
+    VocabParallelEmbedding
 from vllm.model_executor.models.deepseek_mtp import (
     DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
     SharedHead)
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from vllm_ascend.ops.vocab_parallel_embedding import (CustomLogitsProcessor,
+                                                      CustomParallelLMHead)
+
 from .deepseek_v2 import CustomDeepseekV2DecoderLayer
-from vllm_ascend.ops.vocab_parallel_embedding import CustomLogitsProcessor, CustomParallelLMHead
+
 
 class CustomDeepSeekShareHead(SharedHead):
 
@@ -49,9 +51,9 @@ def __init__(self,
         nn.Module.__init__(self)
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.head = CustomParallelLMHead(config.vocab_size,
-                     config.hidden_size,
-                     quant_config=quant_config,
-                     prefix=maybe_prefix(prefix, "head"))
+                                         config.hidden_size,
+                                         quant_config=quant_config,
+                                         prefix=maybe_prefix(prefix, "head"))
 
 
 class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):

vllm_ascend/ops/vocab_parallel_embedding.py

Lines changed: 0 additions & 1 deletion
@@ -269,7 +269,6 @@ def _get_logits(
         else:
             gathered_hidden_states = hidden_states
 
-        # Compute logits using quantized matrix multiplication
         local_logits = lm_head.quant_method.apply(lm_head,
                                                   gathered_hidden_states,
                                                   bias=embedding_bias)

vllm_ascend/utils.py

Lines changed: 1 addition & 3 deletions
@@ -512,6 +512,4 @@ def get_ascend_soc_version():
 
 
 def _enable_lmhead_tp() -> bool:
-    if get_ascend_config().lmhead_tensor_parallel_size is not None:
-        return True
-    return False
+    return get_ascend_config().lmhead_tensor_parallel_size is not None
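The rewritten helper returns the comparison directly and is behaviorally identical to the removed if/return pair. A trivial equivalence check, using a hypothetical stand-in for the object returned by get_ascend_config():

# Hypothetical stand-in config object, only for illustrating the refactor.
class _FakeAscendConfig:
    lmhead_tensor_parallel_size = None

def _enable_lmhead_tp_old(cfg) -> bool:
    if cfg.lmhead_tensor_parallel_size is not None:
        return True
    return False

def _enable_lmhead_tp_new(cfg) -> bool:
    return cfg.lmhead_tensor_parallel_size is not None

cfg = _FakeAscendConfig()
assert _enable_lmhead_tp_old(cfg) == _enable_lmhead_tp_new(cfg)  # both False
cfg.lmhead_tensor_parallel_size = 4
assert _enable_lmhead_tp_old(cfg) == _enable_lmhead_tp_new(cfg)  # both True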

vllm_ascend/worker/model_runner_v1.py

Lines changed: 17 additions & 16 deletions
@@ -1320,14 +1320,14 @@ def _process_reqs(
         if self.use_aux_hidden_state_outputs:
             hidden_states, aux_hidden_states = hidden_states
 
-        if _enable_lmhead_tp(): #
+        if _enable_lmhead_tp():
             if not with_prefill:
                 max_num_reqs_across_dp = padded_num_tokens_across_dp
             else:
                 max_num_reqs_across_dp = self.max_num_reqs
-            sample_indices = nn.functional.pad(
-                sample_indices,
-                (0, max_num_reqs_across_dp - sample_indices.shape[0]))
+            logits_indices = nn.functional.pad(
+                logits_indices,
+                (0, max_num_reqs_across_dp - logits_indices.shape[0]))
 
         return (attn_metadata, hidden_states, spec_decode_metadata, positions,
                 total_num_scheduled_tokens, logits_indices, aux_hidden_states,
@@ -1656,14 +1656,14 @@ def execute_model(
             # Sample the next token and get logprobs if needed.
             sampling_metadata = self.input_batch.sampling_metadata
             if spec_decode_metadata is None:
-                if _enable_lmhead_tp():
+                if _enable_lmhead_tp() and logits is not None:
                     logits = logits[:self.input_batch.num_reqs]
                 sampler_output = self.sampler(
                     logits=logits,
                     sampling_metadata=sampling_metadata,
                 )
             else:
-                if _enable_lmhead_tp():
+                if _enable_lmhead_tp() and logits is not None:
                     logits = logits[:len(spec_decode_metadata.logits_indices)]
                 # When indexing with a tensor (bonus_logits_indices), PyTorch
                 # creates a new tensor with separate storage from the original
@@ -1952,16 +1952,16 @@ def _dummy_run(
                     with_prefill, is_torchair_compile, input_ids, positions,
                     attn_metadata, num_tokens, intermediate_tensors,
                     inputs_embeds)
-
+
             if _enable_lmhead_tp() and not self.in_profile_run:
-                    if not with_prefill:
-                        max_num_reqs_across_dp = num_reqs
-                    else:
-                        max_num_reqs_across_dp = max_num_reqs
-                    dummy_indices = torch.zeros(max_num_reqs_across_dp,
-                                                device=hidden_states.device,
-                                                dtype=torch.int32)
-                    model.compute_logits(hidden_states[dummy_indices], None)
+                if not with_prefill:
+                    max_num_reqs_across_dp = num_reqs
+                else:
+                    max_num_reqs_across_dp = max_num_reqs
+                dummy_indices = torch.zeros(max_num_reqs_across_dp,
+                                            device=hidden_states.device,
+                                            dtype=torch.int32)
+                self.model.compute_logits(hidden_states[dummy_indices], None)
 
             if self.speculative_config and self.speculative_config.method == "deepseek_mtp":
                 assert isinstance(self.drafter, MtpProposer)
@@ -1979,7 +1979,8 @@ def _dummy_run(
                 dummy_indices = torch.zeros(max_num_reqs_across_dp,
                                             device=hidden_states.device,
                                             dtype=torch.int32)
-                model.compute_logits(hidden_states[dummy_indices], None)
+                self.model.compute_logits(hidden_states[dummy_indices],
+                                          None)
 
         return hidden_states
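The padding changes above keep the logits computation uniform across data-parallel ranks when the LM head is tensor-parallel: each rank right-pads its logits_indices up to a common length before compute_logits, then truncates the resulting logits back to its real request count before sampling. A standalone sketch of that pad-then-truncate pattern with illustrative sizes (not code from this commit):

import torch
import torch.nn as nn

num_reqs = 3                   # requests on this rank (assumed)
max_num_reqs_across_dp = 8     # padded size agreed across DP ranks (assumed)

logits_indices = torch.tensor([5, 11, 17])
hidden_states = torch.randn(32, 16)    # [num_tokens, hidden_size]
lm_head_weight = torch.randn(100, 16)  # [vocab_size, hidden_size]

# Right-pad the index tensor so every rank gathers the same number of rows;
# the padded entries gather row 0 and are discarded later.
logits_indices = nn.functional.pad(
    logits_indices, (0, max_num_reqs_across_dp - logits_indices.shape[0]))

# Uniform-shape logits computation (stand-in for model.compute_logits).
logits = hidden_states[logits_indices] @ lm_head_weight.t()

# Drop the padded rows before sampling, mirroring logits[:num_reqs] above.
logits = logits[:num_reqs]
assert logits.shape == (num_reqs, 100)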

vllm_ascend/worker/mtp_proposer_v1.py

Lines changed: 4 additions & 3 deletions
@@ -19,6 +19,7 @@
 from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
 from vllm_ascend.utils import ProfileExecuteDuration, _enable_lmhead_tp
 
+
 class MtpProposer:
 
     def __init__(
@@ -224,16 +225,16 @@ def propose(
                 previous_hidden_states=self.
                 hidden_states[:num_input_tokens],
                 kv_caches=self.runner.kv_caches[-1:])
-
+
         num_indices = last_token_indices.shape[0]
         if _enable_lmhead_tp():
             if not self.runner.with_prefill:
                 max_num_reqs_across_dp = num_input_tokens
             else:
                 max_num_reqs_across_dp = self.vllm_config.scheduler_config.max_num_seqs
             last_token_indices = nn.functional.pad(
-                 last_token_indices, (0, max_num_reqs_across_dp - num_indices))
-
+                last_token_indices, (0, max_num_reqs_across_dp - num_indices))
+
         sample_hidden_states = hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
         if _enable_lmhead_tp() and num_indices < logits.shape[0]:
