@@ -65,8 +65,7 @@ def mm_inputs(self) -> list[MultiModalKwargsItems]:
     def get_token_id(self, idx: int) -> int:
         if idx < self.num_prompt_tokens:
             return self.prompt_token_ids[idx]
-        else:
-            return self.output_token_ids[idx - self.num_prompt_tokens]
+        return self.output_token_ids[idx - self.num_prompt_tokens]
 
 
 class InputBatch:
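The simplification above leans on the fact that a cached request exposes one logical token stream: the prompt followed by generated output. A standalone sketch of that indexing scheme, with illustrative data rather than vLLM's real API:

```python
# Minimal sketch of the combined prompt/output indexing used by
# get_token_id. Names and data here are illustrative only.
prompt_token_ids = [101, 2009, 2003]   # 3 prompt tokens
output_token_ids = [1037, 2742]        # 2 generated tokens
num_prompt_tokens = len(prompt_token_ids)

def get_token_id(idx: int) -> int:
    # Indices [0, num_prompt_tokens) address the prompt; anything beyond
    # addresses the generated tokens, shifted by the prompt length.
    if idx < num_prompt_tokens:
        return prompt_token_ids[idx]
    return output_token_ids[idx - num_prompt_tokens]

assert get_token_id(2) == 2003   # last prompt token
assert get_token_id(3) == 1037   # first generated token
```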
@@ -261,30 +260,27 @@ def _register_add_request(self, request: "CachedRequestState") -> int:
         Not applicable to pooling models.
         """
 
-        # Detailed added request metadata is only required for non-pooling
-        # models, to support logitsprocs
-        assert request.sampling_params
-
         # Fill the next empty index if there is one.
         if (new_req_index := self.batch_update_builder.pop_removed()) is None:
             # Append to end otherwise.
             new_req_index = self.num_reqs
 
         assert new_req_index < self.max_num_reqs
-        self.batch_update_builder.added.append(
-            (new_req_index, request.sampling_params, request.prompt_token_ids,
-             request.output_token_ids))
+        self.batch_update_builder.batch_changed = True
+        if request.sampling_params:
+            # Detailed added request metadata is only required for non-pooling
+            # models, to support logitsprocs.
+            self.batch_update_builder.added.append(
+                (new_req_index, request.sampling_params,
+                 request.prompt_token_ids, request.output_token_ids))
+
         return new_req_index
 
     def add_request(
         self,
         request: "CachedRequestState",
     ) -> int:
-        if not self.is_pooling_model:
-            # New request index bookkeeping for autoregressive models.
-            req_index = self._register_add_request(request)
-        else:
-            req_index = self.num_reqs
+        req_index = self._register_add_request(request)
 
         req_id = request.req_id
         if req_index == len(self._req_ids):
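After this hunk, every request, pooling or not, flows through `_register_add_request`: the builder always flags `batch_changed`, but the detailed `(index, params, tokens)` record is only kept when `sampling_params` is present, since logits processors do not apply to pooling requests. A hedged sketch of that flow; the field names (`added`, `removed`, `batch_changed`, `pop_removed`) come from the diff, while the internals below are assumptions, not vLLM's implementation:

```python
from dataclasses import dataclass, field
from typing import Any, Optional

@dataclass
class BuilderSketch:
    added: list = field(default_factory=list)
    removed: list = field(default_factory=list)  # sorted vacated slot indices
    batch_changed: bool = False

    def pop_removed(self) -> Optional[int]:
        # Reuse the lowest vacated slot first, if any.
        return self.removed.pop(0) if self.removed else None

def register_add(builder: BuilderSketch, request: Any, num_reqs: int) -> int:
    if (idx := builder.pop_removed()) is None:
        idx = num_reqs  # no hole to fill; append at the end
    builder.batch_changed = True
    if request.sampling_params:
        # Only sampling requests carry metadata for logits processors;
        # pooling requests have sampling_params=None and skip this.
        builder.added.append((idx, request.sampling_params,
                              request.prompt_token_ids,
                              request.output_token_ids))
    return idx

@dataclass
class ReqSketch:
    sampling_params: Any
    prompt_token_ids: list
    output_token_ids: list

builder = BuilderSketch(removed=[2])
req = ReqSketch(sampling_params=None, prompt_token_ids=[1], output_token_ids=[])
assert register_add(builder, req, num_reqs=5) == 2   # pooling: slot reused,
assert builder.batch_changed and not builder.added   # but no logitsproc record
```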
@@ -389,7 +385,7 @@ def add_request(
             self.logits_processing_needs_token_ids[req_index] = (
                 pooling_params.requires_token_ids)
         else:
-            raise NotImplementedError(request)
+            raise NotImplementedError("Unrecognized request type")
 
         # Add request lora ID
         if request.lora_request:
@@ -419,13 +415,25 @@ def remove_request(self, req_id: str) -> Optional[int]:
         req_index = self.req_id_to_index.pop(req_id, None)
         if req_index is None:
             return None
-        if not self.is_pooling_model:
-            # Autoregressive models require bookkeeping of removed requests to
-            # support logitsprocs.
-            self.batch_update_builder.removed_append(req_index)
+
+        self.batch_update_builder.removed_append(req_index)
         self._req_ids[req_index] = None
         self.req_output_token_ids[req_index] = None
 
+        # LoRA
+        lora_id = self.request_lora_mapping[req_index]
+        if lora_id != 0:
+            lora_req_ids = self.lora_id_to_request_ids[lora_id]
+            lora_req_ids.discard(req_id)
+            if not lora_req_ids:
+                del self.lora_id_to_request_ids[lora_id]
+                del self.lora_id_to_lora_request[lora_id]
+            self.request_lora_mapping[req_index] = 0
+
+        if self.is_pooling_model:
+            self.pooling_params.pop(req_id, None)
+            return req_index
+
         self.greedy_reqs.discard(req_id)
         self.random_reqs.discard(req_id)
         self.top_p_reqs.discard(req_id)
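The LoRA cleanup now runs before the pooling early return, so adapter bookkeeping stays correct for both model kinds while pooling requests skip the sampling-state discards. A standalone sketch of that refcount-style cleanup; array and dict names follow the diff, the data is illustrative:

```python
import numpy as np

request_lora_mapping = np.array([1, 1, 0, 2])       # per-slot LoRA id (0 = none)
lora_id_to_request_ids = {1: {"a", "b"}, 2: {"d"}}
lora_id_to_lora_request = {1: "lora-1", 2: "lora-2"}

def clear_lora(req_index: int, req_id: str) -> None:
    lora_id = request_lora_mapping[req_index]
    if lora_id != 0:
        req_ids = lora_id_to_request_ids[lora_id]
        req_ids.discard(req_id)
        if not req_ids:
            # Last request using this adapter: drop both mappings.
            del lora_id_to_request_ids[lora_id]
            del lora_id_to_lora_request[lora_id]
        request_lora_mapping[req_index] = 0

clear_lora(3, "d")
assert 2 not in lora_id_to_lora_request   # adapter 2 fully released
```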
@@ -439,29 +447,14 @@ def remove_request(self, req_id: str) -> Optional[int]:
         self.num_prompt_logprobs.pop(req_id, None)
         self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
 
-        # LoRA
-        lora_id = self.request_lora_mapping[req_index]
-        if lora_id != 0:
-            lora_req_ids = self.lora_id_to_request_ids[lora_id]
-            lora_req_ids.discard(req_id)
-            if not lora_req_ids:
-                del self.lora_id_to_request_ids[lora_id]
-                del self.lora_id_to_lora_request[lora_id]
-            self.request_lora_mapping[req_index] = 0
-
         self.has_allowed_token_ids.discard(req_id)
         if self.allowed_token_ids_mask_cpu_tensor is not None:
             # False means we don't fill with -inf.
             self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
         self.bad_words_token_ids.pop(req_index, None)
-        self.pooling_params.pop(req_id, None)
         return req_index
 
     def swap_states(self, i1: int, i2: int) -> None:
-        # For autoregressive models, track detailed request reordering info
-        # to support logitsprocs
-        self.batch_update_builder.moved.append(
-            (i1, i2, MoveDirectionality.SWAP))
         old_id_i1 = self._req_ids[i1]
         old_id_i2 = self._req_ids[i2]
         self._req_ids[i1], self._req_ids[i2] = \
@@ -479,18 +472,6 @@ def swap_states(self, i1: int, i2: int) -> None:
             self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
         self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] = \
             self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
-        self.temperature_cpu[i1], self.temperature_cpu[i2] = \
-            self.temperature_cpu[i2], self.temperature_cpu[i1]
-        self.top_p_cpu[i1], self.top_p_cpu[i2] = \
-            self.top_p_cpu[i2], self.top_p_cpu[i1]
-        self.top_k_cpu[i1], self.top_k_cpu[i2] = \
-            self.top_k_cpu[i2], self.top_k_cpu[i1]
-        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \
-            self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
-        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \
-            self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
-        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \
-            self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
 
         # NOTE: the following is unsafe
         # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
@@ -501,18 +482,41 @@ def swap_states(self, i1: int, i2: int) -> None:
         self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
         self.token_ids_cpu[i2, ...] = tmp
 
-        swap_dict_values(self.generators, i1, i2)
-        swap_dict_values(self.bad_words_token_ids, i1, i2)
+        self.block_table.swap_row(i1, i2)
 
-        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
+        self.request_lora_mapping[i1], self.request_lora_mapping[i2] = \
             self.request_lora_mapping[i2], self.request_lora_mapping[i1]
 
+        if self.is_pooling_model:
+            # Sampling and logits parameters don't apply to pooling models.
+            return
+
+        # For autoregressive models, track detailed request reordering info
+        # to support logitsprocs.
+        self.batch_update_builder.moved.append(
+            (i1, i2, MoveDirectionality.SWAP))
+
+        self.temperature_cpu[i1], self.temperature_cpu[i2] = \
+            self.temperature_cpu[i2], self.temperature_cpu[i1]
+        self.top_p_cpu[i1], self.top_p_cpu[i2] = \
+            self.top_p_cpu[i2], self.top_p_cpu[i1]
+        self.top_k_cpu[i1], self.top_k_cpu[i2] = \
+            self.top_k_cpu[i2], self.top_k_cpu[i1]
+        self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] = \
+            self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
+        self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] = \
+            self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
+        self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] = \
+            self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
+
+        swap_dict_values(self.generators, i1, i2)
+        swap_dict_values(self.bad_words_token_ids, i1, i2)
+
         if self.allowed_token_ids_mask_cpu_tensor is not None:
             self.allowed_token_ids_mask_cpu_tensor[i1], \
                 self.allowed_token_ids_mask_cpu_tensor[i2] = \
                 self.allowed_token_ids_mask_cpu_tensor[i2], \
                 self.allowed_token_ids_mask_cpu_tensor[i1]
-        self.block_table.swap_row(i1, i2)
 
     def condense(self) -> None:
         """Slide non-empty requests down into lower, empty indices.
@@ -529,12 +533,6 @@ def condense(self) -> None:
529533 """
530534 num_reqs = self .num_reqs
531535
532- if self .is_pooling_model :
533- # Will be contiguous in pooling case, just trim the lists.
534- del self ._req_ids [num_reqs :]
535- del self .req_output_token_ids [num_reqs :]
536- return
537-
538536 if not (empty_req_indices := self .batch_update_builder .removed ):
539537 # All removed requests were replaced by added requests, or else no
540538 # requests were removed at all. No condense() needed
@@ -562,11 +560,6 @@ def condense(self) -> None:
             # Move active request down into empty request
             # index.
             self.batch_update_builder.pop_removed()
-            # Autoregressive models require detailed tracking of condense
-            # operations to support logitsprocs
-            self.batch_update_builder.moved.append(
-                (last_req_index, empty_index,
-                 MoveDirectionality.UNIDIRECTIONAL))
             req_id = self._req_ids[last_req_index]
             output_token_ids = self.req_output_token_ids[last_req_index]
             assert req_id is not None
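For orientation, `condense()` is a two-pointer compaction: the smallest vacated index is filled from the highest live index until the live region is contiguous. A minimal standalone sketch of the idea, with a list and a heap standing in for the real `InputBatch` state and `batch_update_builder.removed`:

```python
import heapq

req_ids = ["a", None, "c", None, "e"]   # None marks removed slots
removed = [1, 3]                         # vacated indices, smallest first
heapq.heapify(removed)

last = len(req_ids) - 1
while removed:
    while req_ids[last] is None:         # find the highest live request
        last -= 1
    empty = removed[0]                   # smallest vacated index
    if empty >= last:
        break                            # live requests already contiguous
    heapq.heappop(removed)
    req_ids[empty], req_ids[last] = req_ids[last], None
    last -= 1

assert req_ids[:3] == ["a", "e", "c"]    # compacted into the low indices
```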
@@ -587,6 +580,21 @@ def condense(self) -> None:
             self.num_computed_tokens_cpu[
                 empty_index] = self.num_computed_tokens_cpu[last_req_index]
             self.block_table.move_row(last_req_index, empty_index)
+
+            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
+                last_req_index]
+
+            if self.is_pooling_model:
+                last_req_index -= 1
+                # Sampling state not used by pooling models.
+                continue
+
+            # Autoregressive models require detailed tracking of condense
+            # operations to support logitsprocs
+            self.batch_update_builder.moved.append(
+                (last_req_index, empty_index,
+                 MoveDirectionality.UNIDIRECTIONAL))
+
             self.temperature_cpu[empty_index] = self.temperature_cpu[
                 last_req_index]
             self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
@@ -601,9 +609,6 @@ def condense(self) -> None:
             if generator is not None:
                 self.generators[empty_index] = generator
 
-            self.request_lora_mapping[empty_index] = self.request_lora_mapping[
-                last_req_index]
-
             # TODO convert these to LogitsProcessors
             if self.allowed_token_ids_mask_cpu_tensor is not None:
                 self.allowed_token_ids_mask_cpu_tensor[
@@ -626,8 +631,9 @@ def refresh_metadata(self):
626631 """Apply any batch updates to sampling metadata."""
627632
628633 if self .is_pooling_model :
629- # Batch changes every step for pooling models.
630- self .sampling_metadata = self ._make_sampling_metadata ()
634+ batch_changed = self .batch_update_builder .reset ()
635+ if batch_changed :
636+ self .sampling_metadata = self ._make_sampling_metadata ()
631637 return
632638
633639 # For non-pooling models - generate and apply logitsprocs update;
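The pooling branch now rebuilds metadata only when something actually changed, which presumes that `BatchUpdateBuilder.reset()` both clears the pending per-step records and reports whether any were present. A sketch of that assumed contract, not vLLM's actual implementation:

```python
from dataclasses import dataclass, field

@dataclass
class BuilderResetSketch:
    added: list = field(default_factory=list)
    moved: list = field(default_factory=list)
    batch_changed: bool = False

    def reset(self) -> bool:
        # Report whether anything changed this step, then clear the records
        # so the next step starts from a clean slate.
        changed = self.batch_changed or bool(self.added) or bool(self.moved)
        self.added.clear()
        self.moved.clear()
        self.batch_changed = False
        return changed

builder = BuilderResetSketch(batch_changed=True)
assert builder.reset() is True    # first call reports the change...
assert builder.reset() is False   # ...and resets for the next step
```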
@@ -720,19 +726,19 @@ def pooling_metadata(self) -> PoolingMetadata:
         )
 
     def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
-        max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
+        num_reqs = self.num_reqs
+        max_prompt_len = self.num_prompt_tokens[:num_reqs].max()
         prompt_token_ids_cpu_tensor = torch.empty(
             (self.num_reqs, max_prompt_len),
             device="cpu",
             dtype=torch.int64,
             pin_memory=self.pin_memory,
         )
         prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
-        prompt_token_ids[:] = self.token_ids_cpu[:self.
-                                                 num_reqs, :max_prompt_len]
+        prompt_token_ids[:] = self.token_ids_cpu[:num_reqs, :max_prompt_len]
         # Use the value of vocab_size as a pad since we don't have a
         # token_id of this value.
-        for i in range(self.num_reqs):
+        for i in range(num_reqs):
             prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
         return prompt_token_ids_cpu_tensor.to(device=self.device,
                                               non_blocking=True)
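Besides hoisting `num_reqs` into a local to avoid repeated property lookups, the padding trick here is worth spelling out: `vocab_size` serves as the pad value precisely because no real token id can equal it. A standalone numpy sketch with illustrative data:

```python
import numpy as np

vocab_size = 32000
token_ids_cpu = np.array([[5, 6, 7, 0],     # req 0: 3 prompt tokens
                          [8, 9, 0, 0]])    # req 1: 2 prompt tokens
num_prompt_tokens = np.array([3, 2])

num_reqs = 2
max_prompt_len = int(num_prompt_tokens[:num_reqs].max())
# Copy the prompt region, then right-pad each row with vocab_size.
prompt_token_ids = token_ids_cpu[:num_reqs, :max_prompt_len].copy()
for i in range(num_reqs):
    prompt_token_ids[i, num_prompt_tokens[i]:] = vocab_size

assert prompt_token_ids.tolist() == [[5, 6, 7], [8, 9, vocab_size]]
```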