@@ -45,9 +45,11 @@ def _remove_requests(
 
 
 def _construct_expected_sampling_metadata(
-        reqs: List[CachedRequestState], req_ids_retained: Set[int],
-        req_id_index_in_input_batch: Dict[str, int],
-        device: torch.device) -> SamplingMetadata:
+    reqs: List[CachedRequestState],
+    req_ids_retained: Set[int],
+    req_id_index_in_input_batch: Dict[str, int],
+    device: torch.device,
+) -> SamplingMetadata:
     """
     Constructs and returns the expected SamplingMetadata for this
     batch.
@@ -63,6 +65,7 @@ def _construct_expected_sampling_metadata(
     temperature = [0.0 for _ in range(num_reqs)]
     stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)]
     min_tokens = [0 for _ in range(num_reqs)]
+    logit_bias = [None] * num_reqs
     for req in reqs:
         if req.req_id not in req_ids_retained:
             continue
@@ -71,20 +74,21 @@ def _construct_expected_sampling_metadata(
         prompt_token_ids[index_in_input_batch] = req.prompt_token_ids
         presence_penalties[
             index_in_input_batch] = req.sampling_params.presence_penalty
-        frequency_penalties[
-            index_in_input_batch] = req.sampling_params.frequency_penalty
-        repetition_penalties[
-            index_in_input_batch] = req.sampling_params.repetition_penalty
+        frequency_penalties[index_in_input_batch] = (
+            req.sampling_params.frequency_penalty)
+        repetition_penalties[index_in_input_batch] = (
+            req.sampling_params.repetition_penalty)
         top_k[index_in_input_batch] = req.sampling_params.top_k
         top_p[index_in_input_batch] = req.sampling_params.top_p
         temperature[index_in_input_batch] = req.sampling_params.temperature
         stop_token_ids[
             index_in_input_batch] = req.sampling_params.all_stop_token_ids
         min_tokens[index_in_input_batch] = req.sampling_params.min_tokens
-
+        logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
 
     return SamplingMetadata(
-        temperature=torch.tensor(temperature, dtype=torch.float, device=device),
+        temperature=torch.tensor(temperature, dtype=torch.float,
+                                 device=device),
         all_greedy=False,
         all_random=True,
         top_p=torch.tensor(top_p, dtype=torch.float, device=device),
@@ -93,41 +97,45 @@ def _construct_expected_sampling_metadata(
         no_top_k=all(x == 0 for x in top_k),
         generators={},
         max_num_logprobs=0,
-        prompt_token_ids=make_tensor_with_pad(
+        prompt_token_ids=make_tensor_with_pad(
             prompt_token_ids,
             pad=VOCAB_SIZE,
             device=torch.device(device),
             dtype=torch.int64,
         ),
-        frequency_penalties=torch.tensor(
-            frequency_penalties, dtype=torch.float,
-            device=device),
-        presence_penalties=torch.tensor(
-            presence_penalties, dtype=torch.float,
-            device=device),
-        repetition_penalties=torch.tensor(
-            repetition_penalties, dtype=torch.float,
-            device=device),
+        frequency_penalties=torch.tensor(frequency_penalties,
+                                         dtype=torch.float,
+                                         device=device),
+        presence_penalties=torch.tensor(presence_penalties,
+                                        dtype=torch.float,
+                                        device=device),
+        repetition_penalties=torch.tensor(repetition_penalties,
+                                          dtype=torch.float,
+                                          device=device),
         output_token_ids=output_token_ids,
         min_tokens=min_tokens,
         stop_token_ids=stop_token_ids,
-        no_penalties=(all(x == 0 for x in presence_penalties) and \
-                      all(x == 0 for x in frequency_penalties) and \
-                      all(x == 1 for x in repetition_penalties))
+        no_penalties=(all(x == 0 for x in presence_penalties)
+                      and all(x == 0 for x in frequency_penalties)
+                      and all(x == 1 for x in repetition_penalties)),
+        logit_bias=logit_bias,
     )
 
 
 def _create_sampling_params():
-    return SamplingParams(top_k=np.random.randint(1, 10),
-                          top_p=np.random.uniform(0.0, 1.0),
-                          presence_penalty=np.random.uniform(-2.0, 2.0),
-                          repetition_penalty=np.random.uniform(0.0, 2.0),
-                          frequency_penalty=np.random.uniform(-2.0, 2.0),
-                          min_tokens=np.random.randint(1, 10),
-                          stop_token_ids=[
-                              np.random.randint(0, VOCAB_SIZE)
-                              for _ in range(np.random.randint(10))
-                          ])
+    return SamplingParams(
+        top_k=np.random.randint(1, 10),
+        top_p=np.random.uniform(0.0, 1.0),
+        presence_penalty=np.random.uniform(-2.0, 2.0),
+        repetition_penalty=np.random.uniform(0.0, 2.0),
+        frequency_penalty=np.random.uniform(-2.0, 2.0),
+        min_tokens=np.random.randint(1, 10),
+        stop_token_ids=[
+            np.random.randint(0, VOCAB_SIZE)
+            for _ in range(np.random.randint(10))
+        ],
+        logit_bias={0: np.random.uniform(-3.0, 3.0)},
+    )
 
 
 def _construct_cached_request_state(req_id_suffix: int):
@@ -139,16 +147,18 @@ def _construct_cached_request_state(req_id_suffix: int):
         np.random.randint(0, VOCAB_SIZE)
         for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS))
     ]
-    return CachedRequestState(req_id=f"req_id_{req_id_suffix}",
-                              prompt_token_ids=prompt_token_ids,
-                              prompt=None,
-                              sampling_params=_create_sampling_params(),
-                              mm_inputs=[],
-                              mm_positions=[],
-                              block_ids=[],
-                              generator=None,
-                              num_computed_tokens=len(output_token_ids),
-                              output_token_ids=output_token_ids)
+    return CachedRequestState(
+        req_id=f"req_id_{req_id_suffix}",
+        prompt_token_ids=prompt_token_ids,
+        prompt=None,
+        sampling_params=_create_sampling_params(),
+        mm_inputs=[],
+        mm_positions=[],
+        block_ids=[],
+        generator=None,
+        num_computed_tokens=len(output_token_ids),
+        output_token_ids=output_token_ids,
+    )
 
 
 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -163,12 +173,14 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
     output of `make_sampling_metadata` is then compared against the expected
     results to ensure correctness.
     """
-    input_batch: InputBatch = InputBatch(max_num_reqs=batch_size,
-                                         max_model_len=1024,
-                                         max_num_blocks_per_req=10,
-                                         device=torch.device(device),
-                                         pin_memory=is_pin_memory_available(),
-                                         vocab_size=1024)
+    input_batch: InputBatch = InputBatch(
+        max_num_reqs=batch_size,
+        max_model_len=1024,
+        max_num_blocks_per_req=10,
+        device=torch.device(device),
+        pin_memory=is_pin_memory_available(),
+        vocab_size=1024,
+    )
     reqs: List[CachedRequestState] = []
     req_id_reqs = {}
     req_id_output_token_ids = {}
@@ -206,21 +218,27 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
                           sampling_metadata.top_p)
     assert torch.allclose(expected_sampling_metadata.top_k,
                           sampling_metadata.top_k)
-    assert torch.allclose(expected_sampling_metadata.frequency_penalties,
-                          sampling_metadata.frequency_penalties)
-    assert torch.allclose(expected_sampling_metadata.presence_penalties,
-                          sampling_metadata.presence_penalties)
-    assert torch.allclose(expected_sampling_metadata.repetition_penalties,
-                          sampling_metadata.repetition_penalties)
+    assert torch.allclose(
+        expected_sampling_metadata.frequency_penalties,
+        sampling_metadata.frequency_penalties,
+    )
+    assert torch.allclose(
+        expected_sampling_metadata.presence_penalties,
+        sampling_metadata.presence_penalties,
+    )
+    assert torch.allclose(
+        expected_sampling_metadata.repetition_penalties,
+        sampling_metadata.repetition_penalties,
+    )
     assert torch.allclose(expected_sampling_metadata.prompt_token_ids,
                           sampling_metadata.prompt_token_ids)
     assert (expected_sampling_metadata.output_token_ids ==
             sampling_metadata.output_token_ids)
-    assert (
-        expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens)
-    assert (expected_sampling_metadata.stop_token_ids ==
-            sampling_metadata.stop_token_ids)
-    assert (expected_sampling_metadata.no_penalties ==
-            sampling_metadata.no_penalties)
-    assert (expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p)
-    assert (expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k)
+    assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
+    assert expected_sampling_metadata.stop_token_ids == \
+        sampling_metadata.stop_token_ids
+    assert expected_sampling_metadata.no_penalties == \
+        sampling_metadata.no_penalties
+    assert expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p
+    assert expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k
+    assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
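
Note (not part of the commit): the new `logit_bias` plumbing carries a per-request
Optional[Dict[int, float]] from `SamplingParams` into `SamplingMetadata`, with one
slot per batch row defaulting to None, just as the test's `[None] * num_reqs` does.
As a rough illustration only, here is a minimal sketch of how such a list could be
applied to a [num_reqs, vocab_size] logits tensor before sampling; the helper name
`apply_logit_bias` is hypothetical and is not the implementation this test exercises.

    from typing import Dict, List, Optional

    import torch


    def apply_logit_bias(
        logits: torch.Tensor,
        logit_bias: List[Optional[Dict[int, float]]],
    ) -> torch.Tensor:
        """Add each request's {token_id: bias} entries to its logits row."""
        for i, bias in enumerate(logit_bias):
            if bias:  # None (or an empty dict) means no bias for request i
                for token_id, value in bias.items():
                    logits[i, token_id] += value
        return logits


    # Usage: bias token 0 of the first request, mirroring the dict that
    # _create_sampling_params builds above.
    logits = torch.zeros(2, 8)
    out = apply_logit_bias(logits, [{0: 2.5}, None])
    assert out[0, 0] == 2.5 and out[1].abs().sum() == 0

Because each bias dict is tiny and sparse, a per-entry loop like this can be cheaper
than materializing a dense bias tensor, which is presumably why the metadata keeps
`logit_bias` as a plain Python list rather than converting it to a tensor like the
penalty fields.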