@@ -95,6 +95,10 @@ def test_perfect_match(rejection_sampler):
9595 device = logits .device )
9696 assert torch .equal (output , expected )
9797
98+ assert rejection_sampler .stats .num_draft_tokens == 3
99+ assert rejection_sampler .stats .num_accepted_tokens == 3
100+ assert rejection_sampler .stats .num_emitted_tokens == 4
101+
98102
99103def test_early_mismatch (rejection_sampler ):
100104 """Test when there's an early mismatch in tokens"""
@@ -122,6 +126,10 @@ def test_early_mismatch(rejection_sampler):
122126 )
123127 assert torch .equal (output , expected )
124128
129+ assert rejection_sampler .stats .num_draft_tokens == 3
130+ assert rejection_sampler .stats .num_accepted_tokens == 1
131+ assert rejection_sampler .stats .num_emitted_tokens == 2
132+
125133
126134def test_multiple_sequences (rejection_sampler ):
127135 """Test handling multiple sequences of speculated tokens"""
@@ -148,6 +156,10 @@ def test_multiple_sequences(rejection_sampler):
148156 device = logits .device )
149157 assert torch .equal (output , expected )
150158
159+ assert rejection_sampler .stats .num_draft_tokens == 3
160+ assert rejection_sampler .stats .num_accepted_tokens == 3
161+ assert rejection_sampler .stats .num_emitted_tokens == 5
162+
151163
152164def test_single_token_sequence (rejection_sampler ):
153165 """Test handling sequences with single token"""
@@ -171,6 +183,10 @@ def test_single_token_sequence(rejection_sampler):
171183 expected = torch .tensor ([[1 , 2 ]], dtype = torch .int , device = logits .device )
172184 assert torch .equal (output , expected )
173185
186+ assert rejection_sampler .stats .num_draft_tokens == 1
187+ assert rejection_sampler .stats .num_accepted_tokens == 1
188+ assert rejection_sampler .stats .num_emitted_tokens == 2
189+
174190
175191def test_empty_sequence (rejection_sampler ):
176192 """Test handling empty sequence of speculated tokens"""
@@ -194,6 +210,10 @@ def test_empty_sequence(rejection_sampler):
194210 expected = torch .tensor ([[5 ]], dtype = torch .int , device = logits .device )
195211 assert torch .equal (output , expected )
196212
213+ assert rejection_sampler .stats .num_draft_tokens == 0
214+ assert rejection_sampler .stats .num_accepted_tokens == 0
215+ assert rejection_sampler .stats .num_emitted_tokens == 1
216+
197217
198218def test_multiple_mismatches (rejection_sampler ):
199219 """Test handling multiple sequences with mismatches"""
@@ -223,17 +243,24 @@ def test_multiple_mismatches(rejection_sampler):
223243 )
224244 assert torch .equal (output , expected )
225245
246+ assert rejection_sampler .stats .num_draft_tokens == 6
247+ assert rejection_sampler .stats .num_accepted_tokens == 3
248+ assert rejection_sampler .stats .num_emitted_tokens == 5
249+
226250
227251@pytest .mark .parametrize (
228- "spec_tokens,output_tokens,expected" ,
252+ "spec_tokens,output_tokens,expected,expected_stats " ,
229253 [
230- ([[1 , 2 ]], [[1 , 2 , 3 ]], [[1 , 2 , 3 ]]), # Perfect match with bonus
231- ([[1 ]], [[2 , 3 ]], [[2 , PLACEHOLDER_TOKEN_ID ]]), # First mismatch
232- ([[1 , 2 ], [3 , 4 ]], [[1 , 5 , 6 ], [3 , 4 , 7 ]],
233- [[1 , 5 , PLACEHOLDER_TOKEN_ID ], [3 , 4 , 7 ]]), # Mixed matches
254+ ([[1 , 2 ]], [[1 , 2 , 3 ]], [[1 , 2 , 3 ]],
255+ (2 , 2 , 3 )), # Perfect match with bonus
256+ ([[1 ]], [[2 , 3 ]], [[2 , PLACEHOLDER_TOKEN_ID ]],
257+ (1 , 0 , 1 )), # First mismatch
258+ ([[1 , 2 ], [3 , 4 ]], [[1 , 5 , 6 ], [3 , 4 , 7 ]
259+ ], [[1 , 5 , PLACEHOLDER_TOKEN_ID ], [3 , 4 , 7 ]],
260+ (4 , 3 , 5 )), # Mixed matches
234261 ])
235262def test_parametrized_cases (rejection_sampler , spec_tokens , output_tokens ,
236- expected ):
263+ expected , expected_stats ):
237264 """Parametrized test for various matching scenarios"""
238265 metadata = create_sampling_metadata (all_greedy = True )
239266 logits = create_logits_tensor (output_tokens )
@@ -254,6 +281,10 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
254281 device = logits .device )
255282 assert torch .equal (output , expected_tensor )
256283
284+ assert rejection_sampler .stats .num_draft_tokens == expected_stats [0 ]
285+ assert rejection_sampler .stats .num_accepted_tokens == expected_stats [1 ]
286+ assert rejection_sampler .stats .num_emitted_tokens == expected_stats [2 ]
287+
257288
258289########################### Tests for Random Sampling ###################
259290@pytest .mark .parametrize ("k" , [1 , 3 , 5 ])
@@ -314,6 +345,12 @@ def test_deterministic_when_seeded(
314345
315346 results .append (rep_result )
316347
348+ stats = rejection_sampler .stats .take ()
349+ assert stats .num_draft_tokens == num_tokens
350+ assert stats .num_emitted_tokens >= batch_size
351+ assert (stats .num_emitted_tokens -
352+ batch_size ) == stats .num_accepted_tokens
353+
317354 for i in range (batch_size ):
318355 if seeded_mask [i ]:
319356 for j in range (1 , n_rep ):
0 commit comments