Commit e5e9548

allow deepseek models to enable chunked prefill on NPUs
Signed-off-by: rjg-lyh <1318825571@qq.com>
1 parent 6725d90 commit e5e9548

File tree: 3 files changed, +153 −0 lines changed

vllm_ascend/patch/platform/patch_main/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -14,3 +14,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
+import vllm_ascend.patch.platform.patch_main.patch_vllm_config  # noqa
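
The only functional change in this file is the new import: importing patch_vllm_config at package-import time is what executes the module below and rebinds VllmConfig.__post_init__. A minimal sketch of this import-time monkey-patching pattern, with hypothetical names that are not part of the diff:

# patch_example.py -- minimal sketch of import-time monkey patching.
# `SomeClass` and `_patched_method` are hypothetical stand-ins.

class SomeClass:
    def method(self):
        return "original"

def _patched_method(self):
    return "patched"

# Rebinding the attribute on the class changes behaviour for every caller;
# putting this statement in a module means a bare `import` applies the patch.
SomeClass.method = _patched_method

assert SomeClass().method() == "patched"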
vllm_ascend/patch/platform/patch_main/patch_vllm_config.py

Lines changed: 148 additions & 0 deletions

@@ -0,0 +1,148 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from vllm/config.py
# This file is a part of the vllm-ascend project.

import torch
import vllm.envs as envs
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.logger import init_logger
from vllm.utils import random_uuid

logger = init_logger(__name__)


def __post_init__(self):
    """Verify configs are valid & consistent with each other.
    """
    if self.model_config is not None:
        self.model_config.verify_async_output_proc(self.parallel_config,
                                                   self.speculative_config,
                                                   self.device_config)
        self.model_config.verify_with_parallel_config(self.parallel_config)
        self.model_config.verify_dual_chunk_attention_config(self.load_config)

    if self.cache_config is not None:
        self.cache_config.verify_with_parallel_config(self.parallel_config)

    if self.lora_config:
        self.lora_config.verify_with_cache_config(self.cache_config)
        self.lora_config.verify_with_model_config(self.model_config)
        self.lora_config.verify_lora_support()
    if self.prompt_adapter_config:
        self.prompt_adapter_config.verify_with_model_config(self.model_config)

    if self.quant_config is None and \
            self.model_config is not None and self.load_config is not None:
        self.quant_config = VllmConfig._get_quantization_config(
            self.model_config, self.load_config)

    from vllm.platforms import current_platform
    if self.scheduler_config is not None and \
            self.model_config is not None and \
            self.scheduler_config.chunked_prefill_enabled and \
            self.model_config.dtype == torch.float32 and \
            current_platform.get_device_capability() == (7, 5):
        logger.warning_once(
            "Turing devices tensor cores do not support float32 matmul. "
            "To workaround this limitation, vLLM will set 'ieee' input "
            "precision for chunked prefill triton kernels.")

    if self.compilation_config is None:
        self.compilation_config = CompilationConfig()
    if self.compilation_config.pass_config.enable_sequence_parallelism:
        self.compilation_config.custom_ops.append("+rms_norm")
    if envs.VLLM_USE_V1 and self.model_config is not None and \
            not self.model_config.enforce_eager:
        # NOTE(woosuk): Currently, we use inductor because the piecewise
        # CUDA graphs do not work properly with the custom CUDA kernels.
        # FIXME(woosuk): Disable inductor to reduce the compilation time
        # and avoid any potential issues with the inductor.
        # FIXME(rob): Add function to set all of these.
        if not self.compilation_config.custom_ops:
            self.compilation_config.custom_ops = ["none"]
        self.compilation_config.use_cudagraph = True
        self.compilation_config.use_inductor = True
        self.compilation_config.cudagraph_num_of_warmups = 1
        self.compilation_config.pass_config.enable_fusion = False
        self.compilation_config.pass_config.enable_noop = False
        self.compilation_config.level = CompilationLevel.PIECEWISE
        self.compilation_config.set_splitting_ops_for_v1()

    if self.parallel_config is not None and \
            self.parallel_config.tensor_parallel_size > 1 and \
            self.parallel_config.pipeline_parallel_size > 1 and \
            self.compilation_config is not None and \
            self.compilation_config.pass_config is not None and \
            self.compilation_config.pass_config.enable_sequence_parallelism:
        logger.warning_once(
            "Sequence parallelism is not supported with pipeline "
            "parallelism. Disabling sequence parallelism.")
        self.compilation_config.pass_config.\
            enable_sequence_parallelism = False

    self._set_cudagraph_sizes()

    if self.cache_config is not None and \
            self.cache_config.cpu_offload_gb > 0 and \
            self.compilation_config.level != CompilationLevel.NO_COMPILATION \
            and not envs.VLLM_USE_V1:
        logger.warning(
            "CPU offload is not supported with `torch.compile` in v0 yet."
            " Disabling `torch.compile`.")
        self.compilation_config.level = CompilationLevel.NO_COMPILATION

    if ((not envs.VLLM_USE_V1) and self.lora_config is not None and
            self.compilation_config.level != CompilationLevel.NO_COMPILATION):
        logger.warning(
            "LoRA for V0 is not supported with `torch.compile` yet. "
            "Disabling `torch.compile`.")
        self.compilation_config.level = CompilationLevel.NO_COMPILATION

    if self.compilation_config.full_cuda_graph and \
            not self.model_config.disable_cascade_attn:
        logger.warning_once("full_cuda_graph is not supported with "
                            "cascade attention. Disabling cascade attention.")
        self.model_config.disable_cascade_attn = True

    if self.model_config and self.model_config.use_mla and \
            not (current_platform.is_cuda() or current_platform.is_rocm()):
        logger.info(
            "MLA is enabled on a non-GPU and NPU platform; just forcing "
            "prefix caching to be disabled.")

        if self.cache_config is not None:
            self.cache_config.enable_prefix_caching = False

    if (self.kv_events_config and self.kv_events_config.enable_kv_cache_events
            and not self.cache_config.enable_prefix_caching):
        logger.warning(
            "KV cache events are on, but prefix caching is not enabled."
            "Use --enable-prefix-caching to enable.")
    if (self.kv_events_config and self.kv_events_config.publisher != "null"
            and not self.kv_events_config.enable_kv_cache_events):
        logger.warning("KV cache events are disabled,"
                       "but the scheduler is configured to publish them."
                       "Modify KVEventsConfig.enable_kv_cache_events"
                       "to True to enable.")
    current_platform.check_and_update_config(self)

    if not self.instance_id:
        self.instance_id = random_uuid()[:5]


VllmConfig.__post_init__ = __post_init__
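
As far as this diff shows, the patched __post_init__ mirrors the upstream VllmConfig.__post_init__ except in the MLA branch: on a platform that is neither CUDA nor ROCm (such as an Ascend NPU), only prefix caching is forced off, and the user's chunked-prefill setting is left untouched, which is what lets DeepSeek (MLA) models keep chunked prefill enabled. A rough sketch of that branch's effect, using simplified stand-in config objects rather than the real vLLM classes:

# Sketch of the MLA branch's effect (hypothetical, simplified objects).
from dataclasses import dataclass

@dataclass
class FakeCacheConfig:
    enable_prefix_caching: bool = True

@dataclass
class FakeSchedulerConfig:
    chunked_prefill_enabled: bool = True

def apply_mla_branch(use_mla: bool, is_cuda_or_rocm: bool,
                     cache_config: FakeCacheConfig) -> None:
    """Mimics what the patched __post_init__ does for MLA models."""
    if use_mla and not is_cuda_or_rocm:
        # Only prefix caching is forced off; chunked prefill is untouched.
        cache_config.enable_prefix_caching = False

cache, sched = FakeCacheConfig(), FakeSchedulerConfig()
apply_mla_branch(use_mla=True, is_cuda_or_rocm=False, cache_config=cache)
assert cache.enable_prefix_caching is False
assert sched.chunked_prefill_enabled is True  # still enabled on the NPU path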

vllm_ascend/platform.py

Lines changed: 3 additions & 0 deletions
@@ -195,6 +195,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "ascend_scheduler_config", None) is not None:
             additional_scheduler_config = additional_config.get(
                 "ascend_scheduler_config")
+            if vllm_config.scheduler_config.enable_chunked_prefill:
+                additional_scheduler_config[
+                    "enable_chunked_prefill"] = True
             from vllm_ascend.core.schedule_config import \
                 AscendSchedulerConfig
             ascend_scheduler_config = AscendSchedulerConfig.initialize_from_config(
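
This hunk propagates the user's chunked-prefill setting into the Ascend scheduler config whenever ascend_scheduler_config is supplied through additional_config. A hedged usage sketch follows: the model name is only an example, and the exact keys accepted under ascend_scheduler_config are assumptions based on this diff rather than documented API:

# Hypothetical launch sketch (model name and accepted keys are assumptions).
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # an example DeepSeek (MLA) model
    enable_chunked_prefill=True,           # user-requested chunked prefill
    additional_config={
        # The presence of this key routes scheduling through AscendSchedulerConfig;
        # the added lines above copy enable_chunked_prefill=True into it.
        "ascend_scheduler_config": {},
    },
)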
