Commit 60b4393

move the check in the Processor
Signed-off-by: Lu Fang <lufang@fb.com>
1 parent 368369c commit 60b4393

File tree

3 files changed: +16 -9 lines changed


vllm/v1/engine/processor.py

Lines changed: 14 additions & 0 deletions

@@ -83,6 +83,19 @@ def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
 
+    def _validate_allowed_token_ids(
+        self,
+        params: Union[SamplingParams, PoolingParams],
+    ) -> None:
+        if not isinstance(params, SamplingParams):
+            return
+        if params.allowed_token_ids is None:
+            return
+        if not all(0 <= tid < self.model_config.vocab_size
+                   for tid in params.allowed_token_ids):
+            raise ValueError(
+                "allowed_token_ids contains out-of-vocab token id")
+
     def process_inputs(
         self,
         request_id: str,
@@ -100,6 +113,7 @@ def process_inputs(
 
         self._validate_logprobs(params)
         self._validate_lora(lora_request)
+        self._validate_allowed_token_ids(params)
 
         if arrival_time is None:
             arrival_time = time.time()
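With the check living in the Processor, an out-of-vocab allowed_token_ids entry is rejected once per request, before the request ever reaches the engine core or the per-step input batch. A stand-alone sketch of the relocated logic (the toy VOCAB_SIZE and free function are illustrative; the real check reads self.model_config.vocab_size inside Processor):

```python
VOCAB_SIZE = 32_000  # toy value; the Processor reads model_config.vocab_size

def validate_allowed_token_ids(allowed_token_ids, vocab_size):
    # Mirrors Processor._validate_allowed_token_ids from the diff above.
    if allowed_token_ids is None:
        return
    if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
        raise ValueError("allowed_token_ids contains out-of-vocab token id")

validate_allowed_token_ids([0, 17, 31_999], VOCAB_SIZE)  # passes silently
try:
    validate_allowed_token_ids([VOCAB_SIZE], VOCAB_SIZE)  # one past the last valid id
except ValueError as e:
    print(e)  # raised at request time, not when building the GPU input batch
```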

vllm/v1/sample/sampler.py

Lines changed: 2 additions & 3 deletions

@@ -58,8 +58,8 @@ def forward(
 
         # Gather the logprobs of the topk and sampled token (if requested).
         # Get logprobs and rank tensors (if requested)
-        logprobs_tensors = (None if num_logprobs is None else \
-            self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled))
+        logprobs_tensors = None if num_logprobs is None else \
+            self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)
 
         # Use int32 to reduce the tensor size.
         sampled = sampled.to(torch.int32)
@@ -183,7 +183,6 @@ def apply_penalties(
         apply_min_token_penalties(logits,
                                   sampling_metadata.output_token_ids,
                                   sampling_metadata.min_tokens)
-
         if not sampling_metadata.no_penalties:
             assert sampling_metadata.prompt_token_ids is not None
             logits = apply_all_penalties(
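The reformatted conditional guards an optional logprobs gather. As a rough torch sketch of what a gather_logprobs-style step computes (the function name, shapes, and return layout here are assumptions for illustration, not vLLM's actual implementation):

```python
import torch

def gather_logprobs_sketch(logits, sampled, num_logprobs):
    # Log-normalize once, then pull the top-k entries and the
    # sampled token's own logprob from the same tensor.
    logprobs = torch.log_softmax(logits, dim=-1)                   # [batch, vocab]
    topk_logprobs, topk_ids = logprobs.topk(num_logprobs, dim=-1)  # [batch, k]
    sampled_logprobs = logprobs.gather(-1, sampled.unsqueeze(-1))  # [batch, 1]
    return topk_logprobs, topk_ids, sampled_logprobs

logits = torch.randn(2, 128)
sampled = torch.randint(0, 128, (2,))
topk_lp, topk_ids, sampled_lp = gather_logprobs_sketch(logits, sampled, num_logprobs=5)
print(topk_lp.shape, sampled_lp.shape)  # torch.Size([2, 5]) torch.Size([2, 1])
```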

vllm/v1/worker/gpu_input_batch.py

Lines changed: 0 additions & 6 deletions

@@ -291,12 +291,6 @@ def add_request(
             self.logit_bias[req_index] = sampling_params.logit_bias
 
         if sampling_params.allowed_token_ids:
-            # NOTE(houseroad): put the check here since no vocab_size info
-            # available in vllm/sampling_params.py
-            if not all(0 <= tid < self.vocab_size
-                       for tid in sampling_params.allowed_token_ids):
-                raise ValueError(
-                    "allowed_token_ids contains out-of-vocab token id")
             self.has_allowed_token_ids.add(req_id)
             if self.allowed_token_ids_mask_cpu_tensor is None:
                 # Lazy allocation for this tensor, which can be large.
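With the range check moved upstream, add_request only has to record the request and build its mask. As a generic illustration of what such a mask does (not the exact InputBatch code), disallowed token ids get their logits pushed to -inf so only the allowed ids can be sampled:

```python
import torch

vocab_size = 128
allowed_token_ids = [3, 7, 42]

# True marks tokens that must NOT be sampled; vLLM allocates the per-request
# mask tensor lazily because [max_num_reqs, vocab_size] bools can be large.
mask = torch.ones(vocab_size, dtype=torch.bool)
mask[torch.tensor(allowed_token_ids)] = False

logits = torch.randn(vocab_size)
logits = logits.masked_fill(mask, float("-inf"))  # only allowed ids stay finite
assert torch.isfinite(logits).sum().item() == len(allowed_token_ids)
```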
