@@ -192,7 +192,7 @@ def __init__(
192192
193193 self .logit_bias : List [Optional [Dict [int ,
194194 float ]]] = [None ] * max_num_reqs
195- self .has_allowed_token_ids : List [ bool ] = [ False ] * max_num_reqs
195+ self .has_allowed_token_ids : Set [ str ] = set ()
196196 self .allowed_token_ids_mask : Optional [torch .Tensor ] = None
197197 self .allowed_token_ids_mask_cpu_tensor : Optional [torch .Tensor ] = None
198198
@@ -297,7 +297,7 @@ def add_request(
297297 for tid in sampling_params .allowed_token_ids ):
298298 raise ValueError (
299299 "allowed_token_ids contains out-of-vocab token id" )
300- self .has_allowed_token_ids [ req_index ] = True
300+ self .has_allowed_token_ids . add ( req_id )
301301 if self .allowed_token_ids_mask_cpu_tensor is None :
302302 # Lazy allocation for this tensor, which can be large.
303303 self .allowed_token_ids_mask = torch .zeros (self .max_num_reqs ,
@@ -357,7 +357,7 @@ def remove_request(self, req_id: str) -> Optional[int]:
357357 self .request_lora_mapping [req_index ] = 0
358358
359359 self .logit_bias [req_index ] = None
360- self .has_allowed_token_ids [ req_index ] = False
360+ self .has_allowed_token_ids . discard ( req_id )
361361 if self .allowed_token_ids_mask_cpu_tensor is not None :
362362 self .allowed_token_ids_mask_cpu_tensor [req_index ].fill_ (False )
363363 return req_index
@@ -428,8 +428,6 @@ def condense(self, empty_req_indices: List[int]) -> None:
428428
429429 self .logit_bias [empty_index ] = self .logit_bias [last_req_index ]
430430
431- self .has_allowed_token_ids [
432- empty_index ] = self .has_allowed_token_ids [last_req_index ]
433431 if self .allowed_token_ids_mask_cpu_tensor is not None :
434432 self .allowed_token_ids_mask_cpu_tensor [
435433 empty_index ] = self .allowed_token_ids_mask_cpu_tensor [
@@ -478,8 +476,8 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
478476 prompt_token_ids = None
479477
480478 allowed_token_ids_mask : Optional [torch .Tensor ] = None
481- if not self .no_allowed_token_ids and \
482- self .allowed_token_ids_mask is not None :
479+ if not self .no_allowed_token_ids :
480+ assert self .allowed_token_ids_mask is not None
483481 copy_slice (self .allowed_token_ids_mask_cpu_tensor ,
484482 self .allowed_token_ids_mask , num_reqs )
485483 allowed_token_ids_mask = self .allowed_token_ids_mask [:num_reqs ]
@@ -502,7 +500,6 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
502500 min_tokens = self .min_tokens ,
503501 no_penalties = self .no_penalties ,
504502 logit_bias = self .logit_bias [:num_reqs ],
505- no_allowed_token_ids = self .no_allowed_token_ids ,
506503 allowed_token_ids_mask = allowed_token_ids_mask ,
507504 )
508505
@@ -597,4 +594,4 @@ def no_prompt_logprob(self) -> bool:
597594
598595 @property
599596 def no_allowed_token_ids (self ) -> bool :
600- return not any (self .has_allowed_token_ids )
597+ return len (self .has_allowed_token_ids ) == 0
0 commit comments