diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py index fc4928f9ebd1..6fcff0d62045 100644 --- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py +++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py @@ -340,3 +340,84 @@ def test_full_block_prompt(): output = outputs[0] assert output.finish_reason == FinishReason.STOP assert_scheduler_empty(scheduler) + + +def test_cannot_schedule_after_recv(): + """ + Test that we can handle no schedule after recv due to not + enough remaining KV blocks. + """ + + # NOTE: the KVCacheManager will use 1 null block. + # So there are 5 total working blocks. + TOTAL_NUM_BLOCKS = 6 + vllm_config = create_vllm_config() + scheduler = create_scheduler(vllm_config, num_blocks=TOTAL_NUM_BLOCKS) + + # Prime the KVCache. + NUM_PROMPT_BLOCKS = 2 + BLOCK_SIZE = vllm_config.cache_config.block_size + # Prompt will use 2 blocks + 1 block after we schedule. + NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS) + NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5)) + + request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL) + request_remote = create_request(request_id=2, + num_tokens=NUM_TOKENS_REMOTE, + do_remote_prefill=True) + + # STEP 1: 3 blocks are in use (2 for prompt, 1 for decode). + scheduler.add_request(request_normal) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_normal]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 0 + + # Step 2: 5 blocks are in use (2 new for remote blocks). + scheduler.add_request(request_remote) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_normal]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + + # Step 3: finish recving (5 blocks in use) + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output( + reqs=[request_normal], finished_recving=[request_remote.request_id]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + + # Step 4: try to schedule, not enough blocks. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_normal]) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 1 + + # Step 5: finish the request, free it. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_normal], + use_eos=True) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 0 + assert len(scheduler.waiting) == 1 + + # Step 6: now we can schedule (with 2 blocks computed). + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_remote]) + assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens == + NUM_PROMPT_BLOCKS * BLOCK_SIZE) + scheduler.update_from_output(scheduler_output, model_runner_output) + assert len(scheduler.running) == 1 + assert len(scheduler.waiting) == 0 + + # Step 7: free everything. + scheduler_output = scheduler.schedule() + model_runner_output = create_model_runner_output(reqs=[request_remote], + use_eos=True) + scheduler.update_from_output(scheduler_output, model_runner_output) + _ = scheduler.schedule() + assert_scheduler_empty(scheduler) diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index c873ced343bf..fd3b7d540766 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -310,15 +310,16 @@ def schedule(self) -> SchedulerOutput: break request = self.waiting[0] - num_prealloc_computed_tokens = 0 - # P/D: skip request if still waiting for remote kvs. + + # KVTransfer: skip request if still waiting for remote kvs. if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: is_ready = self._update_waiting_for_remote_kv(request) if is_ready: request.status = RequestStatus.WAITING - num_prealloc_computed_tokens = ( - request.num_computed_tokens) else: + logger.debug( + "%s is still in WAITING_FOR_REMOTE_KVS state.", + request.request_id) self.waiting.popleft() skipped_waiting_requests.appendleft(request) continue @@ -349,8 +350,9 @@ def schedule(self) -> SchedulerOutput: load_kv_async = False # Get already-cached tokens. - if num_prealloc_computed_tokens == 0: - new_computed_blocks, num_native_computed_tokens = \ + if request.num_computed_tokens == 0: + # Get locally-cached tokens. + new_computed_blocks, num_new_local_computed_tokens = \ self.kv_cache_manager.get_computed_blocks( request) @@ -358,23 +360,22 @@ def schedule(self) -> SchedulerOutput: if self.connector is not None: num_external_computed_tokens, load_kv_async = ( self.connector.get_num_new_matched_tokens( - request, num_native_computed_tokens)) + request, num_new_local_computed_tokens)) # Total computed tokens (local + external). - num_computed_tokens = (num_native_computed_tokens + + num_computed_tokens = (num_new_local_computed_tokens + num_external_computed_tokens) + # KVTransfer: WAITING reqs have num_computed_tokens > 0 + # after async KV recvs are completed. else: - # P/D: skip checking prefix cache if loaded from remote kvs. new_computed_blocks = KVCacheBlocks.create_empty() - num_native_computed_tokens = 0 - - # Total computed tokens (allocated in prior step). - num_computed_tokens = num_prealloc_computed_tokens + num_new_local_computed_tokens = 0 + num_computed_tokens = request.num_computed_tokens encoder_inputs_to_schedule = None new_encoder_budget = encoder_budget - # P/D: loading remote KV, do not allocate for new work. + # KVTransfer: loading remote KV, do not allocate for new work. if load_kv_async: assert num_external_computed_tokens > 0 num_new_tokens = 0 @@ -405,7 +406,7 @@ def schedule(self) -> SchedulerOutput: new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens + num_external_computed_tokens, - num_native_computed_tokens, + num_new_local_computed_tokens, new_computed_blocks, num_lookahead_tokens=self.num_lookahead_tokens, delay_cache_blocks=load_kv_async,