Commit 46b8b8f

[bugfix] main-sd-bugfix
Signed-off-by: mengwei805 <mengwei25@huawei.com>
1 parent 5442b46 commit 46b8b8f

File tree: 8 files changed, +17 -30 lines


vllm_ascend/attention/attention.py

Lines changed: 2 additions & 1 deletion
@@ -113,7 +113,8 @@ def get_splitfuse_attn_mask(
         self.update_attn_cache(max_seq_len, dtype, device)
         # FIXME: Currently the mask value of chunked-prefill situation and Prefill-Only situation
         # is not the same. Fix this in the future when kernel is ready.
-        if self.attn_mask_cache[0][1] > 0:
+        if self.attn_mask_cache.numel(
+        ) > 1 and self.attn_mask_cache[0][1] > 0:
             attn_mask = self.get_attn_mask(  # type: ignore
                 max_seq_len, dtype, device)
             attn_mask *= -10000

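The added `numel()` check appears to guard against a mask cache that holds only a single element, where the old `attn_mask_cache[0][1]` lookup would raise an IndexError before the mask is ever rescaled. A minimal sketch of that failure mode, using hypothetical cache tensors (shapes chosen for illustration, not the real cache layout):

    import torch

    full_cache = torch.triu(torch.ones(4, 4), diagonal=1)  # numel() == 16
    tiny_cache = torch.zeros(1, 1)                          # numel() == 1

    def needs_rescale(attn_mask_cache: torch.Tensor) -> bool:
        # Old check: attn_mask_cache[0][1] > 0 -> IndexError on a 1x1 cache.
        # New check short-circuits on numel() first, so indexing is safe.
        return attn_mask_cache.numel() > 1 and bool(attn_mask_cache[0][1] > 0)

    print(needs_rescale(full_cache))  # True  -> mask later multiplied by -10000
    print(needs_rescale(tiny_cache))  # False -> no out-of-range indexing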
vllm_ascend/attention/mla_v1.py

Lines changed: 0 additions & 3 deletions
@@ -6,7 +6,6 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
-from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
@@ -21,8 +20,6 @@
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
-logger = init_logger(__name__)
-
 
 class AscendMLABackend(AttentionBackend):

vllm_ascend/core/scheduler.py

Lines changed: 1 addition & 3 deletions
@@ -16,14 +16,12 @@
 #
 from collections import deque
 
-from vllm.logger import init_logger
+from vllm.logger import logger
 from vllm.utils import cdiv
 from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.request import Request, RequestStatus
 
-logger = init_logger(__name__)
-
 
 class AscendScheduler(Scheduler):
     """This Scheduler extends vllm's original v1 scheduler

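The same logger cleanup recurs across this commit (mla_v1.py, scheduler.py, patch_spec_decode_worker.py, draft_model_runner.py): instead of building a per-module logger with `init_logger(__name__)`, each module imports the shared `logger` object that `vllm.logger` already exposes. A minimal before/after sketch (the log message is illustrative):

    # Before: every module created its own logger instance.
    # from vllm.logger import init_logger
    # logger = init_logger(__name__)

    # After: reuse the pre-configured logger exported by vllm.logger.
    from vllm.logger import logger

    logger.info("AscendScheduler ready")  # illustrative message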
vllm_ascend/distributed/parallel_state.py

Lines changed: 0 additions & 3 deletions
@@ -36,7 +36,6 @@ def init_ascend_model_parallel(
                               expert_tensor_parallel_size)
 
     global _EP
-    assert _EP is None, ("expert parallel group is already initialized")
     group_ranks = []
     for i in range(num_expert_parallel_groups):
         ranks = list(range(i, world_size, num_expert_parallel_groups))
@@ -49,8 +48,6 @@ def init_ascend_model_parallel(
 
     group_ranks = []
     global _ETP
-    assert _ETP is None, (
-        "expert tensor parallel group is already initialized")
     for i in range(num_expert_tensor_parallel_groups):
         ranks = list(
             range(i * expert_tensor_parallel_size,

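For context on the rank math in the lines this hunk keeps: `_EP` groups are built by striding over the world size, while `_ETP` groups take contiguous blocks of ranks (the end of the `_ETP` range is cut off in the hunk, so the block form below is an assumption). A worked example with hypothetical sizes:

    world_size = 8
    num_expert_parallel_groups = 2
    expert_tensor_parallel_size = 4
    num_expert_tensor_parallel_groups = world_size // expert_tensor_parallel_size

    # Expert-parallel (_EP) groups: stride by the number of groups.
    ep_group_ranks = [
        list(range(i, world_size, num_expert_parallel_groups))
        for i in range(num_expert_parallel_groups)
    ]
    print(ep_group_ranks)   # [[0, 2, 4, 6], [1, 3, 5, 7]]

    # Expert tensor-parallel (_ETP) groups: contiguous blocks (assumed end bound).
    etp_group_ranks = [
        list(range(i * expert_tensor_parallel_size,
                   (i + 1) * expert_tensor_parallel_size))
        for i in range(num_expert_tensor_parallel_groups)
    ]
    print(etp_group_ranks)  # [[0, 1, 2, 3], [4, 5, 6, 7]]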
vllm_ascend/models/deepseek_mtp.py

Lines changed: 10 additions & 11 deletions
@@ -1,6 +1,6 @@
 #
 # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-# Adapted from vllm/model_executor/models/qwen2_vl.py
+# Adapted from vllm/model_executor/models/deepseek_mtp.py
 # Copyright 2023 The vLLM team.
 #
 # This file is a part of the vllm-ascend project.
@@ -17,12 +17,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional
+from typing import Optional
 
 import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -70,8 +69,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         previous_hidden_states: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_index: int = 0,
@@ -91,8 +88,6 @@ def forward(
 
         hidden_states, residual = self.mtp_block(positions=positions,
                                                  hidden_states=hidden_states,
-                                                 kv_cache=kv_cache,
-                                                 attn_metadata=attn_metadata,
                                                  residual=None)
         hidden_states = residual + hidden_states
         return hidden_states
@@ -130,8 +125,6 @@ def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         previous_hidden_states: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_idx: int = 0,
@@ -140,8 +133,6 @@
         return self.layers_list[current_step_idx](
             input_ids,
             positions,
-            kv_caches[current_step_idx],
-            attn_metadata,
             previous_hidden_states,
             inputs_embeds,
             current_step_idx,
@@ -162,6 +153,14 @@ def compute_logits(
 
 
 class CustomDeepSeekMTP(DeepSeekMTP):
+    # NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
+    # NOTE 2.The description file generated by the current msmodelslim tool does not have
+    # MTP layer info. Please manually add it and set the value to FLOAT.
+    packed_modules_mapping = {
+        "gate_up_proj": ["gate_proj", "up_proj"],
+        "experts":
+        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
+    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         nn.Module.__init__(self)

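The class-level `packed_modules_mapping` added to `CustomDeepSeekMTP` maps fused module names back to the checkpoint's unfused shard names, which is what skip-quantization checks consult when deciding that the MTP layer stays in float. The helper below is a hypothetical illustration of that lookup, not the actual `is_layer_skipped_ascend` implementation; the layer names are made up:

    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"],
    }

    def is_layer_skipped(prefix: str, float_layers: set) -> bool:
        # A packed module is kept in float if all of its unfused shards are
        # marked FLOAT in the quantization description file (hypothetical rule).
        packed_name = prefix.split(".")[-1]
        shard_names = packed_modules_mapping.get(packed_name, [packed_name])
        return all(
            prefix.replace(packed_name, shard) in float_layers
            for shard in shard_names)

    float_layers = {"model.mtp.mlp.gate_proj", "model.mtp.mlp.up_proj"}
    print(is_layer_skipped("model.mtp.mlp.gate_up_proj", float_layers))  # True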
vllm_ascend/patch/worker/patch_common/patch_spec_decode_worker.py

Lines changed: 1 addition & 3 deletions
@@ -18,7 +18,7 @@
 from typing import Any, Dict, Optional
 
 from vllm.config import ParallelConfig
-from vllm.logger import init_logger
+from vllm.logger import logger
 from vllm.model_executor.layers.rejection_sampler import RejectionSampler
 from vllm.model_executor.layers.spec_decode_base_sampler import \
     SpecDecodeBaseSampler
@@ -34,8 +34,6 @@
 
 from vllm_ascend.worker.draft_model_runner import TP1DraftModelRunner
 
-logger = init_logger(__name__)
-
 
 def create_worker(
     cls,

vllm_ascend/quantization/quant_config.py

Lines changed: 2 additions & 3 deletions
@@ -23,8 +23,6 @@
 from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
-from vllm.model_executor.layers.fused_moe.layer import \
-    UnquantizedFusedMoEMethod
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                RowParallelLinear,
                                                UnquantizedLinearMethod)
@@ -36,6 +34,7 @@
 from vllm.model_executor.parameter import PerTensorScaleParameter
 from vllm.model_executor.utils import set_weight_attrs
 
+from ..ops.fused_moe import AscendUnquantizedFusedMoEMethod
 from .quantizer import AscendQuantizer
 
 
@@ -97,7 +96,7 @@ def get_quant_method(self, layer: torch.nn.Module,
         elif isinstance(layer, FusedMoE):
             if self.is_layer_skipped_ascend(prefix,
                                             self.packed_modules_mapping):
-                return UnquantizedFusedMoEMethod()
+                return AscendUnquantizedFusedMoEMethod()
             return AscendFusedMoEMethod(self, prefix,
                                         self.packed_modules_mapping)
         return None

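The last hunk swaps which unquantized MoE path a skipped FusedMoE layer receives: the Ascend-specific `AscendUnquantizedFusedMoEMethod` from `vllm_ascend.ops.fused_moe` replaces vLLM's generic `UnquantizedFusedMoEMethod`. A condensed sketch of the config's `get_quant_method` dispatch as it reads after this change (only the MoE branch is shown; the surrounding branches are elided):

    # Excerpt-style sketch, not the full method.
    def get_quant_method(self, layer, prefix):
        if isinstance(layer, FusedMoE):
            if self.is_layer_skipped_ascend(prefix, self.packed_modules_mapping):
                # Float (skip-quantized) MoE layers now use the Ascend kernel
                # instead of vLLM's generic unquantized implementation.
                return AscendUnquantizedFusedMoEMethod()
            return AscendFusedMoEMethod(self, prefix,
                                        self.packed_modules_mapping)
        return None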
vllm_ascend/worker/draft_model_runner.py

Lines changed: 1 addition & 3 deletions
@@ -19,7 +19,7 @@
 
 import torch
 from vllm.forward_context import set_forward_context
-from vllm.logger import init_logger
+from vllm.logger import logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.multimodal import MultiModalKwargs
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
@@ -29,8 +29,6 @@
 
 from vllm_ascend.attention.attention import AscendMetadata
 
-logger = init_logger(__name__)
-
 
 # A flag to enable debug prints for the updated input tensors
 # before each step.
 debug_advance_input = False
