From 418ed3d51234251e28fceaadd50081ca7cab3e08 Mon Sep 17 00:00:00 2001 From: rjg-lyh <1318825571@qq.com> Date: Wed, 14 May 2025 10:35:10 +0800 Subject: [PATCH 1/2] [BugFix] Fix chunked prefill bugs Signed-off-by: rjg-lyh <1318825571@qq.com> --- vllm_ascend/worker/model_runner_v1.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 18037439c4..89c2348c83 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -120,6 +120,13 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): self.max_num_tokens = self.scheduler_config.max_num_batched_tokens self.max_num_reqs = self.scheduler_config.max_num_seqs + additional_config = vllm_config.additional_config + if additional_config and additional_config.get( + "ascend_scheduler_config", None) is not None: + self.use_v0_scheduler = True + else: + self.use_v0_scheduler = False + self.graph_block_tables = np.zeros( (self.vllm_config.scheduler_config.max_num_seqs, (self.model_config.max_model_len + self.block_size - 1) // @@ -545,13 +552,14 @@ def _process_reqs( block_offsets, out=self.slot_mapping_np[:total_num_scheduled_tokens]) - if self.chunked_prefill_enabled: - attn_state = AscendAttentionState.ChunkedPrefill - elif np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens): + if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens): attn_state = AscendAttentionState.PrefillNoCache # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache. elif np.all(num_scheduled_tokens == 1): attn_state = AscendAttentionState.DecodeOnly + # splitfuse + elif not self.use_v0_scheduler or self.chunked_prefill_enabled: + attn_state = AscendAttentionState.ChunkedPrefill else: attn_state = AscendAttentionState.PrefillCacheHit From f0dac908bd49250da946c302f14d98b6a772f04f Mon Sep 17 00:00:00 2001 From: rjg-lyh <1318825571@qq.com> Date: Wed, 14 May 2025 16:05:19 +0800 Subject: [PATCH 2/2] allow deepseek models to enable chunked prefill on NPUs Signed-off-by: rjg-lyh <1318825571@qq.com> --- vllm_ascend/platform.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 2d8834b1b7..28cda892a1 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -204,6 +204,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "ascend_scheduler_config", None) is not None: additional_scheduler_config = additional_config.get( "ascend_scheduler_config") + if vllm_config.scheduler_config.enable_chunked_prefill: + additional_scheduler_config[ + "enable_chunked_prefill"] = True from vllm_ascend.core.schedule_config import \ AscendSchedulerConfig ascend_scheduler_config = AscendSchedulerConfig.initialize_from_config(