@@ -120,6 +120,13 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
         self.max_num_reqs = self.scheduler_config.max_num_seqs
 
+        additional_config = vllm_config.additional_config
+        if additional_config and additional_config.get(
+                "ascend_scheduler_config", None) is not None:
+            self.use_v0_scheduler = True
+        else:
+            self.use_v0_scheduler = False
+
         self.graph_block_tables = np.zeros(
             (self.vllm_config.scheduler_config.max_num_seqs,
              (self.model_config.max_model_len + self.block_size - 1) //
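
The new `use_v0_scheduler` flag keys off vLLM's `additional_config` passthrough: any non-`None` `"ascend_scheduler_config"` entry selects the V0-style Ascend scheduler. Below is a minimal sketch of how that might be toggled from the client side, assuming `additional_config` is forwarded through `LLM`'s engine arguments into `VllmConfig`; the inner payload schema is an assumption, since the diff only checks that the key is present.

```python
# Hypothetical sketch: opting into the Ascend (V0-style) scheduler.
# Only the presence of a non-None "ascend_scheduler_config" value matters
# to the check above; the {"enabled": True} payload is an assumption.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    additional_config={"ascend_scheduler_config": {"enabled": True}},
)
```
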
@@ -545,13 +552,14 @@ def _process_reqs(
             block_offsets,
             out=self.slot_mapping_np[:total_num_scheduled_tokens])
 
-        if self.chunked_prefill_enabled:
-            attn_state = AscendAttentionState.ChunkedPrefill
-        elif np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
+        if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
             attn_state = AscendAttentionState.PrefillNoCache
         # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
         elif np.all(num_scheduled_tokens == 1):
             attn_state = AscendAttentionState.DecodeOnly
+        # splitfuse
+        elif not self.use_v0_scheduler or self.chunked_prefill_enabled:
+            attn_state = AscendAttentionState.ChunkedPrefill
         else:
             attn_state = AscendAttentionState.PrefillCacheHit
 
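
Since this hunk reorders the branches, the resulting decision logic may be easier to read as a standalone function. The following is an illustrative sketch only (`classify_attn_state` is a hypothetical name; the real checks live inline in `_process_reqs`):

```python
# Sketch of the reordered attention-state selection. Full-length
# scheduling means fresh prefill; one token per request means decode;
# mixed batches fall through to chunked prefill (splitfuse) unless the
# V0 scheduler is active without chunked prefill.
import enum

import numpy as np


class AscendAttentionState(enum.Enum):
    PrefillNoCache = enum.auto()
    PrefillCacheHit = enum.auto()
    DecodeOnly = enum.auto()
    ChunkedPrefill = enum.auto()


def classify_attn_state(seq_lens: np.ndarray,
                        num_scheduled_tokens: np.ndarray,
                        use_v0_scheduler: bool,
                        chunked_prefill_enabled: bool) -> AscendAttentionState:
    if np.array_equal(seq_lens, num_scheduled_tokens):
        # Every request is scheduled for its full sequence length.
        return AscendAttentionState.PrefillNoCache
    if np.all(num_scheduled_tokens == 1):
        # Exactly one new token per request: pure decode batch.
        return AscendAttentionState.DecodeOnly
    if not use_v0_scheduler or chunked_prefill_enabled:
        # splitfuse-style mixed batch.
        return AscendAttentionState.ChunkedPrefill
    # V0 scheduler without chunked prefill: prefill with a cache hit.
    return AscendAttentionState.PrefillCacheHit
```

Note the effect of the reordering: `PrefillNoCache` and `DecodeOnly` are now detected even when chunked prefill is enabled, whereas the old code short-circuited every batch to `ChunkedPrefill` in that case.
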