@@ -362,7 +362,7 @@ def test_cannot_schedule_after_recv():
362362 BLOCK_SIZE = vllm_config .cache_config .block_size
363363 # Prompt will use 2 blocks + 1 block after we schedule.
364364 NUM_TOKENS_LOCAL = int (BLOCK_SIZE * NUM_PROMPT_BLOCKS )
365- NUM_TOKENS_REMOTE = int (BLOCK_SIZE * ( NUM_PROMPT_BLOCKS + 0.5 ) )
365+ NUM_TOKENS_REMOTE = int (BLOCK_SIZE * NUM_PROMPT_BLOCKS )
366366
367367 request_normal = create_request (request_id = 1 , num_tokens = NUM_TOKENS_LOCAL )
368368 request_remote = create_request (request_id = 2 ,
@@ -393,30 +393,124 @@ def test_cannot_schedule_after_recv():
393393 assert len (scheduler .running ) == 1
394394 assert len (scheduler .waiting ) == 1
395395
396- # Step 4: try to schedule, not enough blocks.
396+ # Step 4: try to schedule, remote request is put to running list
397+ # because the transfer is completed.
398+ scheduler_output = scheduler .schedule ()
399+ model_runner_output = create_model_runner_output (
400+ reqs = [request_normal , request_remote ])
401+ scheduler .update_from_output (scheduler_output , model_runner_output )
402+ assert len (scheduler .running ) == 2
403+ assert len (scheduler .waiting ) == 0
404+
405+ # Step 5: Remote request will be put back to waiting list
406+ # because it needs new block to hold generated token.
397407 scheduler_output = scheduler .schedule ()
398408 model_runner_output = create_model_runner_output (reqs = [request_normal ])
399409 scheduler .update_from_output (scheduler_output , model_runner_output )
400410 assert len (scheduler .running ) == 1
401411 assert len (scheduler .waiting ) == 1
402412
403- # Step 5 : finish the request, free it.
413+ # Step 6 : finish the request, free it.
404414 scheduler_output = scheduler .schedule ()
405415 model_runner_output = create_model_runner_output (reqs = [request_normal ],
406416 use_eos = True )
407417 scheduler .update_from_output (scheduler_output , model_runner_output )
408418 assert len (scheduler .running ) == 0
409419 assert len (scheduler .waiting ) == 1
410420
411- # Step 6: now we can schedule (with 2 blocks computed).
421+ # Step 7: now we can schedule (with 2 blocks computed),
422+ # request is retrieved from preempted list.
412423 scheduler_output = scheduler .schedule ()
413424 model_runner_output = create_model_runner_output (reqs = [request_remote ])
414- assert (scheduler_output .scheduled_new_reqs [0 ]. num_computed_tokens ==
425+ assert (scheduler_output .scheduled_cached_reqs . num_computed_tokens [0 ] ==
415426 NUM_PROMPT_BLOCKS * BLOCK_SIZE )
416427 scheduler .update_from_output (scheduler_output , model_runner_output )
417428 assert len (scheduler .running ) == 1
418429 assert len (scheduler .waiting ) == 0
419430
431+ # Step 8: free everything.
432+ scheduler_output = scheduler .schedule ()
433+ model_runner_output = create_model_runner_output (reqs = [request_remote ],
434+ use_eos = True )
435+ scheduler .update_from_output (scheduler_output , model_runner_output )
436+ _ = scheduler .schedule ()
437+ assert_scheduler_empty (scheduler )
438+
439+
440+ def test_cannot_recv ():
441+ """
442+ Test that we can handle no schedule KV block transfer due to not
443+ enough remaining KV blocks.
444+ """
445+
446+ # NOTE: the KVCacheManager will use 1 null block.
447+ # So there are 5 total working blocks.
448+ TOTAL_NUM_BLOCKS = 6
449+ vllm_config = create_vllm_config ()
450+ scheduler = create_scheduler (vllm_config , num_blocks = TOTAL_NUM_BLOCKS )
451+
452+ # Prime the KVCache.
453+ NUM_PROMPT_BLOCKS = 2
454+ BLOCK_SIZE = vllm_config .cache_config .block_size
455+ # Prompt will use 2 blocks + 1 block after we schedule.
456+ NUM_TOKENS_LOCAL = int (BLOCK_SIZE * NUM_PROMPT_BLOCKS )
457+ NUM_TOKENS_REMOTE = int (BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5 ))
458+
459+ request_normal = create_request (request_id = 1 , num_tokens = NUM_TOKENS_LOCAL )
460+ request_remote = create_request (request_id = 2 ,
461+ num_tokens = NUM_TOKENS_REMOTE ,
462+ do_remote_prefill = True )
463+
464+ # STEP 1: 3 blocks are in use (2 for prompt, 1 for decode).
465+ scheduler .add_request (request_normal )
466+ scheduler_output = scheduler .schedule ()
467+ model_runner_output = create_model_runner_output (reqs = [request_normal ])
468+ scheduler .update_from_output (scheduler_output , model_runner_output )
469+ assert len (scheduler .running ) == 1
470+ assert len (scheduler .waiting ) == 0
471+
472+ # Step 2: 3 blocks are in use,
473+ # need 3 new for remote blocks but only 2 are available.
474+ scheduler .add_request (request_remote )
475+ scheduler_output = scheduler .schedule ()
476+ model_runner_output = create_model_runner_output (reqs = [request_normal ])
477+ scheduler .update_from_output (scheduler_output , model_runner_output )
478+ assert len (scheduler .running ) == 1
479+ assert len (scheduler .waiting ) == 1
480+ # Should not have KV transfer in progress.
481+ assert (request_remote .status != RequestStatus .WAITING_FOR_REMOTE_KVS )
482+
483+ # Step 3: finish the request, free it.
484+ scheduler_output = scheduler .schedule ()
485+ model_runner_output = create_model_runner_output (reqs = [request_normal ],
486+ use_eos = True )
487+ scheduler .update_from_output (scheduler_output , model_runner_output )
488+ assert len (scheduler .running ) == 0
489+ assert len (scheduler .waiting ) == 1
490+
491+ # Step 4: now we can initiate KV transfer (with 2 blocks computed).
492+ scheduler_output = scheduler .schedule ()
493+ model_runner_output = create_model_runner_output (reqs = [])
494+ scheduler .update_from_output (scheduler_output , model_runner_output )
495+ assert len (scheduler .running ) == 0
496+ assert len (scheduler .waiting ) == 1
497+ assert (request_remote .status == RequestStatus .WAITING_FOR_REMOTE_KVS )
498+
499+ # Step 5: finish recving (5 blocks in use)
500+ scheduler_output = scheduler .schedule ()
501+ model_runner_output = create_model_runner_output (
502+ reqs = [], finished_recving = [request_remote .request_id ])
503+ scheduler .update_from_output (scheduler_output , model_runner_output )
504+ assert len (scheduler .running ) == 0
505+ assert len (scheduler .waiting ) == 1
506+
507+ # Step 6: schedule remote request
508+ scheduler_output = scheduler .schedule ()
509+ model_runner_output = create_model_runner_output (reqs = [request_remote ])
510+ scheduler .update_from_output (scheduler_output , model_runner_output )
511+ assert len (scheduler .running ) == 1
512+ assert len (scheduler .waiting ) == 0
513+
420514 # Step 7: free everything.
421515 scheduler_output = scheduler .schedule ()
422516 model_runner_output = create_model_runner_output (reqs = [request_remote ],
0 commit comments