
Commit 840f4ce

[V1][Spec Decoding] Add scheduler test cases
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
1 parent 85ce056 commit 840f4ce

1 file changed: +95 −0 lines


tests/v1/core/test_scheduler.py

Lines changed: 95 additions & 0 deletions
@@ -600,3 +600,98 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
         prompt_logprobs_dict={},
     )
     scheduler.update_from_output(scheduler_output1, model_runner_output)
+
+
+# Note - these test cases mirror some of those in test_rejection_sampler.py
+@pytest.mark.parametrize(
+    "spec_tokens,output_tokens,expected",
+    [
+        ([[1, 2, 3]], [[1, 2, 3, 4]], (3, 3)),  # perfect match
+        ([[1, 2, 3]], [[1, 5]], (3, 1)),  # early mismatch
+        ([[1, 2], [3]], [[1, 2, 5], [3, 4]], (3, 3)),  # multiple sequences
+        ([[1]], [[1, 2]], (1, 1)),  # single token sequence
+        ([[]], [[5]], (0, 0)),  # empty sequence
+        ([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
+         (6, 3)),  # multiple mismatches
+    ])
+def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
+    """Test scheduling behavior with speculative decoding.
+
+    This test verifies that:
+    1. Speculated tokens get scheduled correctly
+    2. Spec decoding stats properly count number of draft and accepted tokens
+    """
+    scheduler = create_scheduler()
+    requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
+    req_ids = []
+    req_to_index = {}
+    for i, request in enumerate(requests):
+        scheduler.add_request(request)
+        req_ids.append(request.request_id)
+        req_to_index[request.request_id] = i
+
+    # Schedule a decode, which will also draft speculative tokens
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == len(requests)
+    assert output.total_num_scheduled_tokens == len(requests)
+    for i in range(len(requests)):
+        req_id = requests[i].request_id
+        assert output.num_scheduled_tokens[req_id] == 1
+        assert req_id not in output.scheduled_spec_decode_tokens
+
+    model_runner_output = ModelRunnerOutput(
+        req_ids=req_ids,
+        req_id_to_index=req_to_index,
+        sampled_token_ids=[[0] for _ in range(len(requests))],
+        spec_token_ids=spec_tokens,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    engine_core_outputs = scheduler.update_from_output(output,
+                                                       model_runner_output)
+
+    for i in range(len(requests)):
+        running_req = scheduler.running[i]
+        # The prompt token
+        assert running_req.num_computed_tokens == 1
+        # The prompt token and the sampled token
+        assert running_req.num_tokens == 2
+        # The prompt token, the sampled token, and the speculated tokens
+        assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])
+
+    # No draft or accepted tokens counted yet
+    assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None
+    stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
+    assert stats.num_draft_tokens == 0
+    assert stats.num_accepted_tokens == 0
+
+    # Schedule the speculated tokens for validation
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 0
+    # The sampled token and speculated tokens
+    assert output.total_num_scheduled_tokens == \
+        len(requests) + sum(len(ids) for ids in spec_tokens)
+    for i in range(len(requests)):
+        req_id = requests[i].request_id
+        assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i])
+        if spec_tokens[i]:
+            assert len(output.scheduled_spec_decode_tokens[req_id]) == \
+                len(spec_tokens[i])
+        else:
+            assert req_id not in output.scheduled_spec_decode_tokens
+
+    model_runner_output = ModelRunnerOutput(
+        req_ids=req_ids,
+        req_id_to_index=req_to_index,
+        sampled_token_ids=output_tokens,
+        spec_token_ids=None,
+        logprobs=None,
+        prompt_logprobs_dict={},
+    )
+    engine_core_outputs = scheduler.update_from_output(output,
+                                                       model_runner_output)
+
+    assert engine_core_outputs.scheduler_stats.spec_decoding_stats is not None
+    stats = engine_core_outputs.scheduler_stats.spec_decoding_stats
+    assert stats.num_draft_tokens == expected[0]
+    assert stats.num_accepted_tokens == expected[1]
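
For anyone checking the parametrized expected tuples against the cases above: each tuple is (num_draft_tokens, num_accepted_tokens). The draft count is every token the drafter speculated across all requests; the accepted count is, per request, the length of the prefix of the speculated tokens that matches what the target model actually sampled (the trailing token in each sampled output is the bonus or recovered token and is never counted as accepted). Below is a minimal sketch of that arithmetic; it is not part of the commit, and the helper name expected_spec_decoding_stats is made up for illustration.

def expected_spec_decoding_stats(spec_tokens, output_tokens):
    # Total number of speculated (draft) tokens across all requests.
    num_draft = sum(len(spec) for spec in spec_tokens)
    # Accepted tokens: per request, count the matching prefix between the
    # speculated tokens and the sampled tokens, stopping at the first mismatch.
    num_accepted = 0
    for spec, sampled in zip(spec_tokens, output_tokens):
        for spec_tok, sampled_tok in zip(spec, sampled):
            if spec_tok != sampled_tok:
                break
            num_accepted += 1
    return num_draft, num_accepted


# Spot-check against two of the parametrized cases above.
assert expected_spec_decoding_stats([[1, 2, 3]], [[1, 5]]) == (3, 1)
assert expected_spec_decoding_stats([[1, 2, 3], [4, 5, 6]],
                                    [[1, 2, 7], [4, 8]]) == (6, 3)

For example, the "early mismatch" case drafts three tokens but only the first (1) matches the sampled output, so the test expects the scheduler stats to report (3, 1).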
