 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Dict, List, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple
 
 import numpy as np
 import pytest
@@ -41,7 +41,7 @@ def _remove_requests(
     for index in req_indices_to_remove:
         input_batch.remove_request(reqs[index].req_id)
         req_ids_to_remove.add(reqs[index].req_id)
-    return (req_ids_to_remove, req_indices_to_remove_list)
+    return req_ids_to_remove, req_indices_to_remove_list
 
 
 def _construct_expected_sampling_metadata(
@@ -64,8 +64,7 @@ def _construct_expected_sampling_metadata(
     top_p = [0.0 for _ in range(num_reqs)]
     min_p = [0.0 for _ in range(num_reqs)]
     temperature = [0.0 for _ in range(num_reqs)]
-    stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)]
-    min_tokens = [0 for _ in range(num_reqs)]
+    min_tokens = {}
     logit_bias = [None] * num_reqs
     for req in reqs:
         if req.req_id not in req_ids_retained:
@@ -83,22 +82,21 @@ def _construct_expected_sampling_metadata(
         top_p[index_in_input_batch] = req.sampling_params.top_p
         min_p[index_in_input_batch] = req.sampling_params.min_p
         temperature[index_in_input_batch] = req.sampling_params.temperature
-        stop_token_ids[
-            index_in_input_batch] = req.sampling_params.all_stop_token_ids
-        min_tokens[index_in_input_batch] = req.sampling_params.min_tokens
+        min_tokens[index_in_input_batch] = (
+            req.sampling_params.min_tokens,
+            req.sampling_params.all_stop_token_ids)
         logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
     return SamplingMetadata(
         temperature=torch.tensor(temperature, dtype=torch.float,
                                  device=device),
         all_greedy=False,
         all_random=True,
-        rejection_sampling=False,
-        top_p=torch.tensor(top_p, dtype=torch.float, device=device),
-        top_k=torch.tensor(top_k, dtype=torch.int, device=device),
-        no_top_p=all(x == 1.0 for x in top_p),
-        no_top_k=all(x == 0 for x in top_k),
-        min_p=torch.tensor(min_p, dtype=torch.float, device=device),
-        no_min_p=all(x == 0.0 for x in min_p),
+        top_p=None if all(x == 1.0 for x in top_p) else torch.tensor(
+            top_p, dtype=torch.float, device=device),
+        top_k=None if all(x == 0 for x in top_k) else torch.tensor(
+            top_k, dtype=torch.int, device=device),
+        min_p=None if all(x == 0.0 for x in min_p) else torch.tensor(
+            min_p, dtype=torch.float, device=device),
         generators={},
         max_num_logprobs=0,
         prompt_token_ids=make_tensor_with_pad(
@@ -117,9 +115,8 @@ def _construct_expected_sampling_metadata(
                                           dtype=torch.float,
                                           device=device),
         output_token_ids=output_token_ids,
-        spec_token_ids=[],
+        spec_token_ids=None,
         min_tokens=min_tokens,
-        stop_token_ids=stop_token_ids,
         no_penalties=(all(x == 0 for x in presence_penalties)
                       and all(x == 0 for x in frequency_penalties)
                       and all(x == 1 for x in repetition_penalties)),
@@ -206,8 +203,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
     input_batch.condense(req_indices_to_remove)
 
     # Generate the sampling metadata
-    sampling_metadata = input_batch.make_sampling_metadata(
-        req_id_output_token_ids, req_id_to_spec_token_ids={}, skip_copy=False)
+    sampling_metadata = input_batch._make_sampling_metadata()
 
     # Create expected output.
     expected_sampling_metadata = _construct_expected_sampling_metadata(
@@ -216,13 +212,16 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
         input_batch.req_id_to_index,
         device=torch.device(device))
 
+    def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
+        return (t1 is None
+                and t2 is None) or (t1 is not None and t2 is not None
+                                    and torch.allclose(t1, t2))
+
     # Assert the actual and expected output.
     assert torch.allclose(expected_sampling_metadata.temperature,
                           sampling_metadata.temperature)
-    assert torch.allclose(expected_sampling_metadata.top_p,
-                          sampling_metadata.top_p)
-    assert torch.allclose(expected_sampling_metadata.top_k,
-                          sampling_metadata.top_k)
+    assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p)
+    assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k)
     assert torch.allclose(
         expected_sampling_metadata.frequency_penalties,
         sampling_metadata.frequency_penalties,
@@ -240,10 +239,6 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
     assert (expected_sampling_metadata.output_token_ids ==
             sampling_metadata.output_token_ids)
     assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
-    assert expected_sampling_metadata.stop_token_ids == \
-        sampling_metadata.stop_token_ids
     assert expected_sampling_metadata.no_penalties == \
         sampling_metadata.no_penalties
-    assert expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p
-    assert expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k
     assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias