 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from typing import Any, Optional
+from unittest.mock import Mock

 import pytest
 import torch
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
                                               RejectionSampler)
+from vllm.v1.sample.sampler import Sampler, SamplerOutput
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata

 DEVICE = current_platform.device_type


 @pytest.fixture
 def rejection_sampler():
-    return RejectionSampler()
+    mock_sampler = Mock(spec=Sampler)
+    return RejectionSampler(mock_sampler)
+
+
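+# Note on the helper below: the fixture's Sampler is a Mock, so assigning its
+# `.return_value` makes the rejection sampler's bonus-token sampling step hand
+# back exactly the given bonus token ids (assumption: RejectionSampler draws
+# bonus tokens through its internal `self.sampler`).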
+def mock_sampler_output(rejection_sampler: RejectionSampler,
+                        bonus_token_ids: torch.Tensor):
+    rejection_sampler.sampler.return_value = SamplerOutput(
+        sampled_token_ids=bonus_token_ids, logprobs_tensors=None)
+
+
+def create_spec_decode_metadata(spec_tokens: list[list[int]],
+                                logits: torch.Tensor) -> SpecDecodeMetadata:
+    metadata = SpecDecodeMetadata.make_dummy(spec_tokens, device=logits.device)
+    metadata.target_logits_indices = torch.arange(logits.shape[0])
+    # Output bonus token ids are mocked, so the bonus logit indices should
+    # be empty.
+    metadata.bonus_logits_indices = torch.empty(0, dtype=torch.int32)
+    return metadata
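+
+
+# Usage sketch (mirrors the tests below; `spec_tokens`, `logits`, `expected`,
+# and `metadata` come from each test):
+#   mock_sampler_output(rejection_sampler, bonus_token_tensor)
+#   spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)
+#   output = rejection_sampler(spec_decode_metadata, draft_probs=None,
+#                              logits=logits, sampling_metadata=metadata)
+#   assert torch.equal(output.sampled_token_ids, expected)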


 def create_logits_tensor(output_token_ids: list[list[int]],
@@ -83,20 +102,19 @@ def test_perfect_match(rejection_sampler):
     logits = create_logits_tensor(output_tokens)
     bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
                                       device=logits.device)
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
-                                                         device=logits.device)
+    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)

+    mock_sampler_output(rejection_sampler, bonus_token_tensor)
     output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=None,
-        target_logits=logits,
-        bonus_token_ids=bonus_token_tensor,
+        logits=logits,
         sampling_metadata=metadata,
     )
     expected = torch.tensor([[1, 2, 3, 4]],
                             dtype=torch.int,
                             device=logits.device)
-    assert torch.equal(output, expected)
+    assert torch.equal(output.sampled_token_ids, expected)


 def test_early_mismatch(rejection_sampler):
@@ -108,22 +126,21 @@ def test_early_mismatch(rejection_sampler):
     logits = create_logits_tensor(output_tokens)
     bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
                                       device=logits.device)
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
-                                                         device=logits.device)
+    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)

+    mock_sampler_output(rejection_sampler, bonus_token_tensor)
     output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=None,
-        target_logits=logits,
-        bonus_token_ids=bonus_token_tensor,
+        logits=logits,
         sampling_metadata=metadata,
     )
     expected = torch.tensor(
         [[1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
         dtype=torch.int,
         device=logits.device,
     )
-    assert torch.equal(output, expected)
+    assert torch.equal(output.sampled_token_ids, expected)


 def test_multiple_sequences(rejection_sampler):
@@ -136,20 +153,19 @@ def test_multiple_sequences(rejection_sampler):
     logits = create_logits_tensor(output_tokens)
     bonus_token_tensor = torch.tensor(
         [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
-                                                         device=logits.device)
+    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)

+    mock_sampler_output(rejection_sampler, bonus_token_tensor)
     output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=None,
-        target_logits=logits,
-        bonus_token_ids=bonus_token_tensor,
+        logits=logits,
         sampling_metadata=metadata,
     )
     expected = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]],
                             dtype=torch.int,
                             device=logits.device)
-    assert torch.equal(output, expected)
+    assert torch.equal(output.sampled_token_ids, expected)


 def test_single_token_sequence(rejection_sampler):
@@ -161,18 +177,17 @@ def test_single_token_sequence(rejection_sampler):
     logits = create_logits_tensor(output_tokens)
     bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
                                       device=logits.device)
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
-                                                         device=logits.device)
+    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)

+    mock_sampler_output(rejection_sampler, bonus_token_tensor)
     output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=None,
-        target_logits=logits,
-        bonus_token_ids=bonus_token_tensor,
+        logits=logits,
         sampling_metadata=metadata,
     )
     expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
-    assert torch.equal(output, expected)
+    assert torch.equal(output.sampled_token_ids, expected)


 def test_empty_sequence(rejection_sampler):
@@ -184,18 +199,17 @@ def test_empty_sequence(rejection_sampler):
     logits = create_logits_tensor(output_tokens)
     bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
                                       device=logits.device)
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
-                                                         device=logits.device)
+    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)

+    mock_sampler_output(rejection_sampler, bonus_token_tensor)
     output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=None,
-        target_logits=logits,
-        bonus_token_ids=bonus_token_tensor,
+        logits=logits,
         sampling_metadata=metadata,
     )
     expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
-    assert torch.equal(output, expected)
+    assert torch.equal(output.sampled_token_ids, expected)


 def test_multiple_mismatches(rejection_sampler):
@@ -208,14 +222,13 @@ def test_multiple_mismatches(rejection_sampler):
     logits = create_logits_tensor(output_tokens)
     bonus_token_tensor = torch.tensor(
         [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
-                                                         device=logits.device)
+    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)

+    mock_sampler_output(rejection_sampler, bonus_token_tensor)
     output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=None,
-        target_logits=logits,
-        bonus_token_ids=bonus_token_tensor,
+        logits=logits,
         sampling_metadata=metadata,
     )
     expected = torch.tensor(
@@ -224,7 +237,7 @@ def test_multiple_mismatches(rejection_sampler):
         dtype=torch.int,
         device=logits.device,
     )
-    assert torch.equal(output, expected)
+    assert torch.equal(output.sampled_token_ids, expected)


 @pytest.mark.parametrize(
@@ -242,20 +255,19 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
     logits = create_logits_tensor(output_tokens)
     bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
                                       device=logits.device)
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
-                                                         device=logits.device)
+    spec_decode_metadata = create_spec_decode_metadata(spec_tokens, logits)

+    mock_sampler_output(rejection_sampler, bonus_token_tensor)
     output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=None,
-        target_logits=logits,
-        bonus_token_ids=bonus_token_tensor,
+        logits=logits,
         sampling_metadata=metadata,
     )
     expected_tensor = torch.tensor(expected,
                                    dtype=torch.int,
                                    device=logits.device)
-    assert torch.equal(output, expected_tensor)
+    assert torch.equal(output.sampled_token_ids, expected_tensor)


 ########################### Tests for Random Sampling ###################
@@ -305,17 +317,18 @@ def test_deterministic_when_seeded(
         sampling_metadata = create_sampling_metadata(all_greedy=False,
                                                      temperature=temperature,
                                                      generators=seeded_seqs)
-        spec_decode_metadata = SpecDecodeMetadata.make_dummy(
-            draft_token_ids.tolist(), device=DEVICE)
+        spec_decode_metadata = create_spec_decode_metadata(
+            draft_token_ids.tolist(), target_logits)
+
+        mock_sampler_output(rejection_sampler, bonus_token_ids)
         rep_result = rejection_sampler(
             spec_decode_metadata,
-            draft_probs=draft_probs,
-            target_logits=target_logits,
-            bonus_token_ids=bonus_token_ids,
+            draft_probs=None,
+            logits=target_logits,
             sampling_metadata=sampling_metadata,
         )

-        results.append(rep_result)
+        results.append(rep_result.sampled_token_ids)

     for i in range(batch_size):
         if seeded_mask[i]:
@@ -424,7 +437,9 @@ def estimate_rejection_sampling_pdf(
     Returns:
         Estimated probability distribution of the output tokens.
     """
-    rejection_sampler = RejectionSampler()
+    # Mock the sampler that RejectionSampler uses
+    mock_sampler = Mock(spec=Sampler)
+    rejection_sampler = RejectionSampler(mock_sampler)
     num_tokens = num_samples * k
     # Repeat draft probs num_samples * k times.
     draft_probs = draft_probs.reshape(1, 1,
@@ -447,16 +462,17 @@ def estimate_rejection_sampling_pdf(
     temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE)
     sampling_metadata = create_sampling_metadata(all_greedy=False,
                                                  temperature=temperature)
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
-        draft_token_ids.tolist(), device=bonus_token_ids.device)
-    output_token_ids = rejection_sampler(
+    spec_decode_metadata = create_spec_decode_metadata(
+        draft_token_ids.tolist(), target_logits)
+
+    mock_sampler_output(rejection_sampler, bonus_token_ids)
+    sampler_output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=draft_probs,
-        target_logits=target_logits,
-        bonus_token_ids=bonus_token_ids,
+        logits=target_logits,
         sampling_metadata=sampling_metadata,
     )
-    output_token_ids = output_token_ids[:, :-1].flatten()
+    output_token_ids = sampler_output.sampled_token_ids[:, :-1].flatten()

     hist = torch.histogram(output_token_ids.to(dtype=torch.float,
                                                device="cpu"),
@@ -496,22 +512,20 @@ def _test_masked_logits(
                                    device=DEVICE)

     # Create spec decode metadata
-    spec_decode_metadata = SpecDecodeMetadata.make_dummy(
-        draft_token_ids,
-        device=DEVICE,
-    )
+    spec_decode_metadata = create_spec_decode_metadata(draft_token_ids,
+                                                       target_logits)

     # Run rejection sampling
-    output_token_ids = rejection_sampler(
+    mock_sampler_output(rejection_sampler, bonus_token_ids)
+    output = rejection_sampler(
         spec_decode_metadata,
         draft_probs=draft_probs,
-        target_logits=target_logits,
-        bonus_token_ids=bonus_token_ids,
+        logits=target_logits,
         sampling_metadata=sampling_metadata,
     )

     # Remove bonus tokens and reshape
-    output_token_ids = output_token_ids[:, :-1].flatten().tolist()
+    output_token_ids = output.sampled_token_ids[:, :-1].flatten().tolist()

     # Check that all sampled tokens are within the unmasked indices.
     for i in range(num_tokens):