@@ -118,6 +118,13 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
         self.max_num_reqs = self.scheduler_config.max_num_seqs

+        additional_config = vllm_config.additional_config
+        if additional_config and additional_config.get(
+                "ascend_scheduler_config", None) is not None:
+            self.use_v0_scheduler = True
+        else:
+            self.use_v0_scheduler = False
+
         self.graph_block_tables = np.zeros(
             (self.vllm_config.scheduler_config.max_num_seqs,
              (self.model_config.max_model_len + self.block_size - 1) //
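
For context, the `ascend_scheduler_config` key is read from vLLM's dict-valued `additional_config` engine argument. Below is a minimal sketch of how a user might opt in, assuming `additional_config` passed to `LLM` is forwarded into `VllmConfig.additional_config`; the model name and the `{"enabled": True}` payload are illustrative, since the check above only requires the key to be present and non-None:

```python
from vllm import LLM

# A minimal sketch, assuming additional_config is forwarded to
# VllmConfig.additional_config. Any non-None value under
# "ascend_scheduler_config" flips use_v0_scheduler to True in the
# model runner; the inner payload shown here is illustrative.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # hypothetical model choice
    additional_config={"ascend_scheduler_config": {"enabled": True}},
)
```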
@@ -569,13 +576,14 @@ def _process_reqs(
                block_offsets,
                out=self.slot_mapping_np[:total_num_scheduled_tokens])

-        if self.chunked_prefill_enabled:
-            attn_state = AscendAttentionState.ChunkedPrefill
-        elif np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
+        if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
             attn_state = AscendAttentionState.PrefillNoCache
         # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
         elif np.all(num_scheduled_tokens == 1):
             attn_state = AscendAttentionState.DecodeOnly
+        # splitfuse
+        elif not self.use_v0_scheduler or self.chunked_prefill_enabled:
+            attn_state = AscendAttentionState.ChunkedPrefill
         else:
             attn_state = AscendAttentionState.PrefillCacheHit

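To make the new branch order concrete, here is a standalone sketch that mirrors the selection logic above; the `AscendAttentionState` member values and the `pick_attn_state` helper are illustrative, not part of the source. The point of the reordering: with the old code, `chunked_prefill_enabled` short-circuited every batch into `ChunkedPrefill`, whereas now the full-prefill and pure-decode cases are recognized first.

```python
from enum import Enum

import numpy as np


class AscendAttentionState(Enum):
    # Mirror of the real enum for illustration; actual values may differ.
    PrefillNoCache = 0
    PrefillCacheHit = 1
    DecodeOnly = 2
    ChunkedPrefill = 3


def pick_attn_state(seq_lens, num_scheduled_tokens,
                    use_v0_scheduler, chunked_prefill_enabled):
    """Standalone mirror of the branch order introduced above."""
    if np.array_equal(seq_lens, num_scheduled_tokens):
        # Every request schedules its full sequence length:
        # a fresh prefill with no cached blocks.
        return AscendAttentionState.PrefillNoCache
    if np.all(num_scheduled_tokens == 1):
        # Exactly one token per request: plain decode.
        return AscendAttentionState.DecodeOnly
    if not use_v0_scheduler or chunked_prefill_enabled:
        # splitfuse: the v1 scheduler may mix partial prefills and decodes.
        return AscendAttentionState.ChunkedPrefill
    return AscendAttentionState.PrefillCacheHit


# Under the pre-change order, chunked_prefill_enabled=True would have
# tagged this pure decode batch as ChunkedPrefill; the new order
# classifies it correctly.
assert pick_attn_state(
    np.array([17, 42]), np.array([1, 1]),
    use_v0_scheduler=False,
    chunked_prefill_enabled=True) is AscendAttentionState.DecodeOnly
```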