
Commit fc1d9a3

[main][bugfix] disable the chunked prefill feature in Non-MLA LLMs
Signed-off-by: rjg-lyh <1318825571@qq.com>
1 parent: fc2bcbe

2 files changed: +30 -5 lines

vllm_ascend/core/schedule_config.py (0 additions, 5 deletions)

@@ -25,7 +25,6 @@
 class AscendSchedulerConfig(SchedulerConfig):
     enable_chunked_prefill: bool = False
     policy: str = "fcfs"
-    num_scheduler_steps: int = 1
     scheduler_cls: Union[str, Type[object]] = (
         "vllm_ascend.core.scheduler.AscendScheduler")
     enable_pd_transfer: bool = False
@@ -44,7 +43,6 @@ def initialize_from_config(
         # Override default values into original SchedulerConfig
         scheduler_config["enable_chunked_prefill"] = False
         scheduler_config["policy"] = "fcfs"
-        scheduler_config["num_scheduler_steps"] = 1
         scheduler_config["scheduler_cls"] = (
             "vllm_ascend.core.scheduler.AscendScheduler")
         scheduler_config["enable_pd_transfer"] = False
@@ -76,9 +74,6 @@ def __post_init__(self) -> None:
         if self.is_multimodal_model:
             raise NotImplementedError(
                 "currently AscendScheduler only supports LLM models.")
-        if self.num_scheduler_steps > 1:
-            raise NotImplementedError(
-                "currently AscendScheduler doesn't support multi-step.")
         if self.send_delta_data:
             raise NotImplementedError(
                 "currently AscendScheduler doesn't support send_delta_data.")

vllm_ascend/platform.py (30 additions, 0 deletions)

@@ -128,6 +128,36 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
         cache_config = vllm_config.cache_config
+        decoding_config = vllm_config.decoding_config
+        scheduler_config = vllm_config.scheduler_config
+        ascend_scheduler_config = ascend_config.ascend_scheduler_config
+
+        if model_config is not None and not model_config.use_mla:
+            logger.info(
+                "Non-MLA LLMs forcibly disable the chunked prefill feature, "
+                "as the performance of operators supporting this feature "
+                "is currently suboptimal.")
+            if not model_config.is_multimodal_model and \
+                decoding_config.backend == "auto" and \
+                not scheduler_config.delay_factor > 0 and \
+                not scheduler_config.send_delta_data and \
+                scheduler_config.policy == "fcfs":
+                ascend_scheduler_config.enabled = True
+                chunked_prefill_enabled_in_ascend_scheduler = False
+                if hasattr(ascend_scheduler_config, "enable_chunked_prefill") and \
+                    ascend_scheduler_config.enable_chunked_prefill is True:
+                    chunked_prefill_enabled_in_ascend_scheduler = True
+                    logger.warning(
+                        "Chunked prefill feature is enabled in ascend_scheduler, "
+                        "but note that the operator supporting this feature "
+                        "would lead to performance degradation.")
+                # In this situation, max_num_batched_tokens would have been rewritten.
+                # So we must make sure max_num_batched_tokens is not smaller than max_model_len.
+                if (scheduler_config.max_num_batched_tokens
+                        < scheduler_config.max_model_len
+                        and not chunked_prefill_enabled_in_ascend_scheduler):
+                    scheduler_config.max_num_batched_tokens = scheduler_config.max_model_len
+
         kv_cache_dtype = vllm_config.additional_config.get(
             "kv_cache_dtype", None)
         if kv_cache_dtype is not None:
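
For orientation, the ascend_scheduler_config object that the hasattr() check probes is the same knob users reach through additional_config. Below is a hedged usage sketch; the additional_config schema is inferred from this diff and the model name is a placeholder, so treat it as illustrative rather than definitive.

```python
# Hedged usage sketch: opting back into chunked prefill on a non-MLA model
# via the enable_chunked_prefill flag probed by the hasattr() check above.
# Schema inferred from this diff; the model name is a placeholder, and
# enabling the flag triggers the performance warning from the hunk above.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder non-MLA model
    additional_config={
        "ascend_scheduler_config": {
            "enabled": True,
            "enable_chunked_prefill": True,
        },
    },
)
```

When the flag stays unset, chunked prefill remains disabled, and the final clamp raises max_num_batched_tokens to max_model_len so that a full-length prompt still fits into a single prefill step.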
