Commit 70dc428

allow deepseek models to enable chunked prefill on NPUs
Signed-off-by: rjg-lyh <1318825571@qq.com>
1 parent 6725d90 commit 70dc428

File tree

3 files changed, +153 -0 lines changed


vllm_ascend/patch/platform/patch_common/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,3 +16,4 @@
 #

 import vllm_ascend.patch.platform.patch_common.patch_distributed  # noqa
+import vllm_ascend.patch.platform.patch_common.patch_vllm_config  # noqa
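Registering the new module here is what activates the patch: importing patch_vllm_config (added below) overwrites VllmConfig.__post_init__ at import time. A minimal sketch of that import-time monkey-patch pattern, using hypothetical stand-in names rather than the real vllm classes:

# Illustrative sketch (hypothetical names, not the real vllm-ascend module):
# importing the patch module swaps a method on the target class, so every
# later caller sees the patched behaviour.

class EngineConfig:                      # stand-in for the class being patched
    def post_init(self):
        return "upstream behaviour"

def patched_post_init(self):
    # Replacement logic; in the real patch this is a copy of
    # VllmConfig.__post_init__ with the NPU-specific changes applied.
    return "patched behaviour"

EngineConfig.post_init = patched_post_init   # runs when the patch module is imported

assert EngineConfig().post_init() == "patched behaviour"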
vllm_ascend/patch/platform/patch_common/patch_vllm_config.py

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Adapted from vllm/config.py
+# This file is a part of the vllm-ascend project.
+
+import torch
+
+from vllm.logger import init_logger
+from vllm.config import (VllmConfig, CompilationConfig, CompilationLevel)
+from vllm.utils import random_uuid
+import vllm.envs as envs
+
+logger = init_logger(__name__)
+
+
+def __post_init__(self):
+    """Verify configs are valid & consistent with each other.
+    """
+    if self.model_config is not None:
+        self.model_config.verify_async_output_proc(self.parallel_config,
+                                                   self.speculative_config,
+                                                   self.device_config)
+        self.model_config.verify_with_parallel_config(self.parallel_config)
+        self.model_config.verify_dual_chunk_attention_config(self.load_config)
+
+    if self.cache_config is not None:
+        self.cache_config.verify_with_parallel_config(self.parallel_config)
+
+    if self.lora_config:
+        self.lora_config.verify_with_cache_config(self.cache_config)
+        self.lora_config.verify_with_model_config(self.model_config)
+        self.lora_config.verify_lora_support()
+    if self.prompt_adapter_config:
+        self.prompt_adapter_config.verify_with_model_config(self.model_config)
+
+    if self.quant_config is None and \
+            self.model_config is not None and self.load_config is not None:
+        self.quant_config = VllmConfig._get_quantization_config(
+            self.model_config, self.load_config)
+
+    from vllm.platforms import current_platform
+    if self.scheduler_config is not None and \
+            self.model_config is not None and \
+            self.scheduler_config.chunked_prefill_enabled and \
+            self.model_config.dtype == torch.float32 and \
+            current_platform.get_device_capability() == (7, 5):
+        logger.warning_once(
+            "Turing devices tensor cores do not support float32 matmul. "
+            "To workaround this limitation, vLLM will set 'ieee' input "
+            "precision for chunked prefill triton kernels.")
+
+    if self.compilation_config is None:
+        self.compilation_config = CompilationConfig()
+    if self.compilation_config.pass_config.enable_sequence_parallelism:
+        self.compilation_config.custom_ops.append("+rms_norm")
+    if envs.VLLM_USE_V1 and self.model_config is not None and \
+            not self.model_config.enforce_eager:
+        # NOTE(woosuk): Currently, we use inductor because the piecewise
+        # CUDA graphs do not work properly with the custom CUDA kernels.
+        # FIXME(woosuk): Disable inductor to reduce the compilation time
+        # and avoid any potential issues with the inductor.
+        # FIXME(rob): Add function to set all of these.
+        if not self.compilation_config.custom_ops:
+            self.compilation_config.custom_ops = ["none"]
+        self.compilation_config.use_cudagraph = True
+        self.compilation_config.use_inductor = True
+        self.compilation_config.cudagraph_num_of_warmups = 1
+        self.compilation_config.pass_config.enable_fusion = False
+        self.compilation_config.pass_config.enable_noop = False
+        self.compilation_config.level = CompilationLevel.PIECEWISE
+        self.compilation_config.set_splitting_ops_for_v1()
+
+    if self.parallel_config is not None and \
+            self.parallel_config.tensor_parallel_size > 1 and \
+            self.parallel_config.pipeline_parallel_size > 1 and \
+            self.compilation_config is not None and \
+            self.compilation_config.pass_config is not None and \
+            self.compilation_config.pass_config.enable_sequence_parallelism:
+        logger.warning_once(
+            "Sequence parallelism is not supported with pipeline "
+            "parallelism. Disabling sequence parallelism.")
+        self.compilation_config.pass_config.\
+            enable_sequence_parallelism = False
+
+    self._set_cudagraph_sizes()
+
+    if self.cache_config is not None and \
+            self.cache_config.cpu_offload_gb > 0 and \
+            self.compilation_config.level != CompilationLevel.NO_COMPILATION \
+            and not envs.VLLM_USE_V1:
+        logger.warning(
+            "CPU offload is not supported with `torch.compile` in v0 yet."
+            " Disabling `torch.compile`.")
+        self.compilation_config.level = CompilationLevel.NO_COMPILATION
+
+    if ((not envs.VLLM_USE_V1) and self.lora_config is not None and
+            self.compilation_config.level != CompilationLevel.NO_COMPILATION):
+        logger.warning(
+            "LoRA for V0 is not supported with `torch.compile` yet. "
+            "Disabling `torch.compile`.")
+        self.compilation_config.level = CompilationLevel.NO_COMPILATION
+
+    if self.compilation_config.full_cuda_graph and \
+            not self.model_config.disable_cascade_attn:
+        logger.warning_once("full_cuda_graph is not supported with "
+                            "cascade attention. Disabling cascade attention.")
+        self.model_config.disable_cascade_attn = True
+
+    if self.model_config and self.model_config.use_mla and \
+            not (current_platform.is_cuda() or current_platform.is_rocm()):
+        logger.info(
+            "MLA is enabled on a non-GPU and NPU platform; just forcing "
+            "prefix caching to be disabled.")
+
+        if self.cache_config is not None:
+            self.cache_config.enable_prefix_caching = False
+
+    if (self.kv_events_config and self.kv_events_config.enable_kv_cache_events
+            and not self.cache_config.enable_prefix_caching):
+        logger.warning(
+            "KV cache events are on, but prefix caching is not enabled."
+            "Use --enable-prefix-caching to enable.")
+    if (self.kv_events_config and self.kv_events_config.publisher != "null"
+            and not self.kv_events_config.enable_kv_cache_events):
+        logger.warning("KV cache events are disabled,"
+                       "but the scheduler is configured to publish them."
+                       "Modify KVEventsConfig.enable_kv_cache_events"
+                       "to True to enable.")
+    current_platform.check_and_update_config(self)
+
+    if not self.instance_id:
+        self.instance_id = random_uuid()[:5]
+
+
+VllmConfig.__post_init__ = __post_init__
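The key functional change relative to upstream vLLM's VllmConfig.__post_init__ appears to be in the MLA branch near the end: chunked prefill is no longer forced off when MLA runs on a platform that is neither CUDA nor ROCm; only prefix caching is disabled. That is what lets DeepSeek (MLA) models keep chunked prefill on NPUs. A usage sketch under that assumption; the model name and token budget are illustrative, not taken from this commit:

# Usage sketch (assumed model name and values): with the patch applied, an MLA
# model such as DeepSeek can keep chunked prefill enabled on NPU.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # hypothetical example model
    enable_chunked_prefill=True,           # no longer forced off for MLA on NPU
    max_num_batched_tokens=4096,           # prefill chunk budget per scheduler step
)
outputs = llm.generate(["Hello, NPU"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)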

vllm_ascend/platform.py

Lines changed: 3 additions & 0 deletions
@@ -195,6 +195,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "ascend_scheduler_config", None) is not None:
             additional_scheduler_config = additional_config.get(
                 "ascend_scheduler_config")
+            if vllm_config.scheduler_config.enable_chunked_prefill:
+                additional_scheduler_config[
+                    "enable_chunked_prefill"] = True
             from vllm_ascend.core.schedule_config import \
                 AscendSchedulerConfig
             ascend_scheduler_config = AscendSchedulerConfig.initialize_from_config(
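When the Ascend scheduler is enabled through additional_config, this change propagates the user's enable_chunked_prefill setting into ascend_scheduler_config before AscendSchedulerConfig is built from it. A sketch of how that might be driven from user code, assuming the vLLM build exposes the additional_config engine argument and that the Ascend scheduler config accepts an "enabled" key:

# Illustrative configuration sketch (assumed keys and model name): requesting
# chunked prefill together with the Ascend scheduler; the patched
# check_and_update_config copies enable_chunked_prefill into the
# ascend_scheduler_config dict before AscendSchedulerConfig is initialized.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",          # hypothetical example model
    enable_chunked_prefill=True,
    additional_config={
        "ascend_scheduler_config": {"enabled": True},
    },
)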
