 _PAD_SLOT_ID = 0  # FIXME(woosuk)
 # FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
 _ENABLE_TOP_P = False
+# FIXME(woosuk): A temporary hack to support `n > 1`.
+# This can significantly affect the performance if too large.
+_MAX_NUM_SAMPLES = 128


 class TPUModelRunner:
@@ -143,8 +146,9 @@ def _dummy_run(
         p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)

         # Dummy run.
+        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
         self.model(token_ids, position_ids, kv_caches, attn_metadata,
-                   input_lens, t, p)
+                   input_lens, t, p, num_samples)

     def warmup_model(
         self,
@@ -268,14 +272,11 @@ def _prepare_decode(
         input_positions: List[List[int]] = []
         slot_mapping: List[List[int]] = []
         context_lens: List[int] = []
-        num_seq_groups = len(seq_group_metadata_list)
-        batch_size = _get_padded_batch_size(num_seq_groups)

-        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+        batch_idx = 0
+        for seq_group_metadata in seq_group_metadata_list:
             assert not seq_group_metadata.is_prompt
-
             seq_ids = list(seq_group_metadata.seq_data.keys())
-
             for seq_id in seq_ids:
                 seq_data = seq_group_metadata.seq_data[seq_id]
                 generation_token = seq_data.get_last_token_id()
@@ -288,14 +289,16 @@ def _prepare_decode(

                 assert seq_group_metadata.block_tables is not None
                 block_table = seq_group_metadata.block_tables[seq_id]
-                self.block_tables[i, :len(block_table)] = block_table
+                self.block_tables[batch_idx, :len(block_table)] = block_table
+                batch_idx += 1

                 block_number = block_table[position // self.block_size]
                 block_offset = position % self.block_size
                 slot = block_number * self.block_size + block_offset
                 slot_mapping.append([slot])

-        num_paddings = batch_size - num_seq_groups
+        batch_size = _get_padded_batch_size(batch_idx)
+        num_paddings = batch_size - batch_idx
         input_tokens = input_tokens + [[0]] * num_paddings
         input_positions = input_positions + [[0]] * num_paddings
         slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
@@ -333,14 +336,13 @@ def _prepare_sample(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
         padded_batch_size: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
         assert len(seq_group_metadata_list) > 0
         t = []
         p = []
+        best_of = []
         for seq_group_metadata in seq_group_metadata_list:
-            assert seq_group_metadata.sampling_params is not None
             sampling_params = seq_group_metadata.sampling_params
-
             # NOTE(woosuk): Here we mimic argmax sampling by applying a very
             # low temperature. This is not accurate.
             t.append(sampling_params.temperature
@@ -354,10 +356,11 @@ def _prepare_sample(
                 raise NotImplementedError(
                     "Top-k sampling is currently disabled for the TPU backend "
                     "due to performance issues.")
-            if sampling_params.best_of > 1:
+            if sampling_params.best_of > _MAX_NUM_SAMPLES:
                 raise NotImplementedError(
-                    "best_of > 1 is not currently supported by the TPU "
+                    f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
                     "backend.")
+            best_of.append(sampling_params.best_of)
             if sampling_params.use_beam_search:
                 raise NotImplementedError(
                     "Beam search is not supported by the TPU backend.")
@@ -369,13 +372,19 @@ def _prepare_sample(
                     "prompt_logprobs is not currently supported by the TPU "
                     "backend.")

-        num_paddings = padded_batch_size - len(seq_group_metadata_list)
+            # Repeat the sampling params if the seq group has multiple seqs.
+            num_seqs = len(seq_group_metadata.seq_data)
+            t += [t[-1]] * (num_seqs - 1)
+            p += [p[-1]] * (num_seqs - 1)
+            best_of += [best_of[-1]] * (num_seqs - 1)
+
+        num_paddings = padded_batch_size - len(t)
         t += [1.0] * num_paddings
         p += [1.0] * num_paddings

         t = torch.tensor(t, dtype=torch.float32, device=self.device)
         p = torch.tensor(p, dtype=torch.float32, device=self.device)
-        return t, p
+        return t, p, best_of

     def _execute_model(
         self,
@@ -392,28 +401,41 @@ def _execute_model(
         else:
             inputs = self._prepare_decode(seq_group_metadata_list)
         padded_batch_size = inputs[0].shape[0]
-        t, p = self._prepare_sample(seq_group_metadata_list, padded_batch_size)
+        t, p, best_of = self._prepare_sample(seq_group_metadata_list,
+                                             padded_batch_size)
+        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1

         # Execute the model.
         next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
-                                    *inputs[2:], t, p)
+                                    *inputs[2:], t, p, num_samples)
         # Retrieve the outputs to CPU.
         next_token_ids = next_token_ids.cpu().tolist()

         # NOTE(woosuk): Minimal code to construct the sampler outputs.
         # The TPU backend does not reuse the sampler, since the TPU backend
         # does not support the advanced sampling parameters such as logprobs.
-        i = 0
+        zero_logprob = Logprob(0.0)
+        batch_idx = 0
         sampler_outputs = []
         for seq_group_metadata in seq_group_metadata_list:
             seq_outputs = []
             seq_ids = list(seq_group_metadata.seq_data.keys())
-            for seq_id in seq_ids:
-                next_token_id = next_token_ids[i]
-                seq_outputs.append(
-                    SequenceOutput(seq_id, next_token_id,
-                                   {next_token_id: Logprob(0.0)}))
-                i += 1
+            if is_prompt:
+                assert len(seq_ids) == 1
+                seq_id = seq_ids[0]
+                for i in range(best_of[batch_idx]):
+                    next_token_id = next_token_ids[batch_idx][i]
+                    seq_outputs.append(
+                        SequenceOutput(seq_id, next_token_id,
+                                       {next_token_id: zero_logprob}))
+                batch_idx += 1
+            else:
+                for seq_id in seq_ids:
+                    next_token_id = next_token_ids[batch_idx][0]
+                    seq_outputs.append(
+                        SequenceOutput(seq_id, next_token_id,
+                                       {next_token_id: zero_logprob}))
+                    batch_idx += 1
             sampler_outputs.append(
                 CompletionSequenceGroupOutput(seq_outputs, None))
         return sampler_outputs
@@ -458,6 +480,7 @@ def forward(
         input_lens: torch.Tensor,
         t: torch.Tensor,
         p: torch.Tensor,
+        num_samples: int,
     ) -> torch.Tensor:
         """Executes the forward pass of the model and samples the next token.

@@ -520,8 +543,9 @@ def forward(
         if _ENABLE_TOP_P:
             logits = _apply_top_p(logits, p.unsqueeze(dim=1))
         probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
-        # FIXME(woosuk): best_of > 1 is not supported.
-        next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)
+        next_token_ids = torch.multinomial(probs,
+                                           num_samples,
+                                           replacement=True)
         return next_token_ids


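A rough, standalone sketch of the sampling change above (not code from this commit): prompt steps now draw `_MAX_NUM_SAMPLES` candidate tokens per sequence with `torch.multinomial(..., replacement=True)`, and the runner keeps only the first `best_of` of them per sequence group. The helper name `sample_candidates` and the shapes used here are illustrative assumptions.

import torch

_MAX_NUM_SAMPLES = 128  # mirrors the constant introduced above


def sample_candidates(logits: torch.Tensor, t: torch.Tensor,
                      is_prompt: bool) -> torch.Tensor:
    # Temperature-scale the logits, then convert to probabilities.
    probs = torch.softmax(logits / t.unsqueeze(dim=1), dim=-1,
                          dtype=torch.float32)
    # Prompt steps draw many candidates so that best_of > 1 is served by a
    # single forward pass; decode steps draw one token per sequence.
    num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
    # Returned shape: [batch_size, num_samples].
    return torch.multinomial(probs, num_samples, replacement=True)


# Illustrative usage: keep the first `best_of` candidates per prompt.
logits = torch.randn(2, 32000)  # [batch_size, vocab_size], dummy values
t = torch.tensor([1.0, 0.7])    # per-sequence temperatures
best_of = [4, 1]                # per-sequence-group best_of
candidates = sample_candidates(logits, t, is_prompt=True).tolist()
for batch_idx, n in enumerate(best_of):
    chosen = candidates[batch_idx][:n]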