@@ -600,3 +600,98 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
         prompt_logprobs_dict={},
     )
     scheduler.update_from_output(scheduler_output1, model_runner_output)
+
+
+# Note - these test cases mirror some of those in test_rejection_sampler.py
+@pytest.mark.parametrize(
+    "spec_tokens,output_tokens,expected",
+    [
+        ([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)),  # perfect match
+        ([[1, 2, 3]], [[1, 5]], (3, 1)),  # early mismatch
+        ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)),  # multiple sequences
+        ([[1]], [[1, 2]], (1, 1)),  # single token sequence
+        ([[]], [[5]], (0, 0)),  # empty sequence
+        ([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
+         (6, 3)),  # multiple mismatches
+    ])
+def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
+    """Test scheduling behavior with speculative decoding.
+
+    This test verifies that:
+    1. Speculated tokens get scheduled correctly
+    2. Spec decoding stats properly count number of draft and accepted tokens
+    """
+    scheduler = create_scheduler()
+    requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
+    req_ids = []
+    req_to_index = {}
+    for i, request in enumerate(requests):
+        scheduler.add_request(request)
+        req_ids.append(request.request_id)
+        req_to_index[request.request_id] = i
+
+    # Schedule a decode, which will also draft speculative tokens
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == len(requests)
+    assert output.total_num_scheduled_tokens == len(requests)
+    for i in range(len(requests)):
+        req_id = requests[i].request_id
+        assert output.num_scheduled_tokens[req_id] == 1
+        assert req_id not in output.scheduled_spec_decode_tokens
+
+    model_runner_output = ModelRunnerOutput(
+        req_ids=req_ids,
+        req_id_to_index=req_to_index,
+        sampled_token_ids=[[0] for _ in range(len(requests))],
+        spec_token_ids=spec_tokens,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    engine_core_outputs = scheduler.update_from_output(output,
+                                                       model_runner_output)
+
+    for i in range(len(requests)):
+        running_req = scheduler.running[i]
+        # The prompt token
+        assert running_req.num_computed_tokens == 1
+        # The prompt token and the sampled token
+        assert running_req.num_tokens == 2
+        # The prompt token, the sampled token, and the speculated tokens
+        assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])
+
+    # No draft or accepted tokens counted yet
+    assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None
+    stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
+    assert stats.num_draft_tokens == 0
+    assert stats.num_accepted_tokens == 0
+
+    # Schedule the speculated tokens for validation
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 0
+    # The sampled token and speculated tokens
+    assert output.total_num_scheduled_tokens == \
+        len(requests) + sum(len(ids) for ids in spec_tokens)
+    for i in range(len(requests)):
+        req_id = requests[i].request_id
+        assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i])
+        if spec_tokens[i]:
+            assert len(output.scheduled_spec_decode_tokens[req_id]) == \
+                len(spec_tokens[i])
+        else:
+            assert req_id not in output.scheduled_spec_decode_tokens
+
+    model_runner_output = ModelRunnerOutput(
+        req_ids=req_ids,
+        req_id_to_index=req_to_index,
+        sampled_token_ids=output_tokens,
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    engine_core_outputs = scheduler.update_from_output(output,
+                                                       model_runner_output)
+
+    assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None
+    stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
+    assert stats.num_draft_tokens == expected[0]
+    assert stats.num_accepted_tokens == expected[1]