Skip to content

Commit 418ed3d

Browse files
committed
[BugFix] Fix chunked prefill bugs
Signed-off-by: rjg-lyh <1318825571@qq.com>
1 parent d5401a0 commit 418ed3d

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,13 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
120120
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
121121
self.max_num_reqs = self.scheduler_config.max_num_seqs
122122

123+
additional_config = vllm_config.additional_config
124+
if additional_config and additional_config.get(
125+
"ascend_scheduler_config", None) is not None:
126+
self.use_v0_scheduler = True
127+
else:
128+
self.use_v0_scheduler = False
129+
123130
self.graph_block_tables = np.zeros(
124131
(self.vllm_config.scheduler_config.max_num_seqs,
125132
(self.model_config.max_model_len + self.block_size - 1) //
@@ -545,13 +552,14 @@ def _process_reqs(
545552
block_offsets,
546553
out=self.slot_mapping_np[:total_num_scheduled_tokens])
547554

548-
if self.chunked_prefill_enabled:
549-
attn_state = AscendAttentionState.ChunkedPrefill
550-
elif np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
555+
if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
551556
attn_state = AscendAttentionState.PrefillNoCache
552557
# We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
553558
elif np.all(num_scheduled_tokens == 1):
554559
attn_state = AscendAttentionState.DecodeOnly
560+
# splitfuse
561+
elif not self.use_v0_scheduler or self.chunked_prefill_enabled:
562+
attn_state = AscendAttentionState.ChunkedPrefill
555563
else:
556564
attn_state = AscendAttentionState.PrefillCacheHit
557565

0 commit comments

Comments
 (0)