
Commit bb78fb3

[v1] Support allowed_token_ids in v1 Sampler (#13210)
Signed-off-by: Lu Fang <lufang@fb.com>
1 parent 8aca27f commit bb78fb3

7 files changed: +168 additions, -19 deletions
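
For context, the sketch below shows how the new option is meant to be exercised from the user side. It relies on the existing SamplingParams.allowed_token_ids field; the model name and token IDs are placeholders, and running this against the v1 engine may require enabling v1 explicitly (e.g. via the VLLM_USE_V1 environment variable). It is illustrative only and not part of this commit.

# Illustrative sketch (not part of this commit): constrain generation to a
# small set of token IDs through SamplingParams.allowed_token_ids.
from vllm import LLM, SamplingParams

allowed = [11, 13, 198]  # placeholder token IDs; real values are tokenizer-specific

llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(
    max_tokens=8,
    allowed_token_ids=allowed,
)
outputs = llm.generate(["The answer is"], params)
print(outputs[0].outputs[0].text)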

tests/v1/sample/test_rejection_sampler.py

Lines changed: 1 addition & 0 deletions
@@ -43,6 +43,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
         output_token_ids=[],
         min_tokens={},
         logit_bias=[None] * batch_size,
+        allowed_token_ids_mask=None,
     )

tests/v1/sample/test_sampler.py

Lines changed: 79 additions & 15 deletions
@@ -57,6 +57,26 @@ def _create_logit_bias(
     return res


+def _create_allowed_token_ids(
+    batch_size: int,
+    vocab_size: int,
+    num_allowed_token_ids: int,
+    device: torch.device,
+) -> Optional[torch.Tensor]:
+    mask: Optional[torch.Tensor] = None
+    for i in range(batch_size):
+        if i % 2 == 1:
+            continue
+        if mask is None:
+            mask = torch.zeros((batch_size, vocab_size),
+                               dtype=torch.bool,
+                               device=device)
+        start = min(i, vocab_size - 1)
+        end = min(i + num_allowed_token_ids, vocab_size - 1)
+        mask[i, start:end] = True
+    return mask
+
+
 def _create_default_sampling_metadata(
     num_output_tokens: int,
     batch_size: int,
@@ -92,6 +112,7 @@ def _create_default_sampling_metadata(
         no_penalties=True,
         min_tokens={},
         logit_bias=[None] * batch_size,
+        allowed_token_ids_mask=None,
     )
     return fake_sampling_metadata
@@ -253,7 +274,10 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
     sampling_metadata.frequency_penalties = _create_penalty_tensor(
         batch_size, frequency_penalty, torch.device(device))
     output_token_ids, sorted_token_ids_in_output = \
-        _create_weighted_output_token_list(batch_size, VOCAB_SIZE)
+        _create_weighted_output_token_list(
+            batch_size,
+            VOCAB_SIZE,
+        )
     sampling_metadata.output_token_ids = output_token_ids
     sampling_metadata.no_penalties = False
     sampler = Sampler()
@@ -262,8 +286,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
     for batch_idx in range(batch_size):
         non_penalized_token_id = logits[batch_idx].argmax().item()
         penalized_token_id = logits[batch_idx].argmin().item()
-        distinct_sorted_token_ids_in_output = \
-            sorted_token_ids_in_output[batch_idx]
+        distinct_sorted_token_ids_in_output = sorted_token_ids_in_output[
+            batch_idx]
         most_frequent_token_id = distinct_sorted_token_ids_in_output[
             len(distinct_sorted_token_ids_in_output) - 1]
         if frequency_penalty > 0:
@@ -272,8 +296,8 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
             # non-penalized token ID is not present in the output, while the
             # most penalized token is the one that occurs most frequently in
             # the output.
-            assert non_penalized_token_id \
-                not in distinct_sorted_token_ids_in_output
+            assert (non_penalized_token_id
+                    not in distinct_sorted_token_ids_in_output)
             assert penalized_token_id == most_frequent_token_id
         elif frequency_penalty < 0:
             # If `frequency_penalty` is set to < 0, it indicates
@@ -282,8 +306,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
             # in the output, while the penalized token ID is one that has not
             # yet appeared.
             assert non_penalized_token_id == most_frequent_token_id
-            assert penalized_token_id \
-                not in distinct_sorted_token_ids_in_output
+            assert penalized_token_id not in distinct_sorted_token_ids_in_output


 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -318,18 +341,18 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
             # If `repetition_penalty` > 1.0, verify that the non-penalized
             # token ID has not been seen before, while the penalized token ID
             # exists either in the prompt or the output.
-            assert (non_penalized_token_id not in prompt_tokens and \
-                    non_penalized_token_id not in output_tokens)
-            assert (penalized_token_id in prompt_tokens or \
-                    penalized_token_id in output_tokens)
+            assert (non_penalized_token_id not in prompt_tokens
+                    and non_penalized_token_id not in output_tokens)
+            assert (penalized_token_id in prompt_tokens
+                    or penalized_token_id in output_tokens)
         elif repetition_penalty < 1.0:
             # If `repetition_penalty` < 1.0, verify that the penalized
             # token ID has not been seen before, while the non-penalized
             # token ID exists either in the prompt or the output.
-            assert (penalized_token_id not in prompt_tokens and \
-                    penalized_token_id not in output_tokens)
-            assert (non_penalized_token_id in prompt_tokens or \
-                    non_penalized_token_id in output_tokens)
+            assert (penalized_token_id not in prompt_tokens
+                    and penalized_token_id not in output_tokens)
+            assert (non_penalized_token_id in prompt_tokens
+                    or non_penalized_token_id in output_tokens)


 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -404,3 +427,44 @@ def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
                 1e-2)
         else:
             assert logits_for_req[token_id] == pytest.approx(1e-2)
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("num_allowed_token_ids", [0, 1, 2])
+def test_sampler_allowed_token_ids(device: str, batch_size: int,
+                                   num_allowed_token_ids: int):
+    """
+    Test to verify that the allowed_token_ids_mask is applied correctly:
+    token positions marked in the mask have their logits set to -inf,
+    while all other positions are left untouched.
+    """
+    torch.set_default_device(device)
+    # Create fake logits where each token is assigned the same
+    # logit value.
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+    mask = _create_allowed_token_ids(
+        batch_size=batch_size,
+        vocab_size=VOCAB_SIZE,
+        num_allowed_token_ids=num_allowed_token_ids,
+        device=device,
+    )
+    sampling_metadata.allowed_token_ids_mask = mask
+    sampler = Sampler()
+    logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
+    logits = logits.cpu()
+    for batch_idx in range(batch_size):
+        logits_for_req = logits[batch_idx]
+        if batch_idx % 2 == 1:
+            assert torch.all(logits_for_req != -float("inf"))
+            continue
+        for token_id in range(VOCAB_SIZE):
+            start = min(batch_idx, VOCAB_SIZE - 1)
+            end = min(batch_idx + num_allowed_token_ids, VOCAB_SIZE - 1)
+            if token_id >= start and token_id < end:
+                assert logits_for_req[token_id] == -float(
+                    "inf"), f"{batch_idx}, {token_id}"
+            else:
+                assert logits_for_req[token_id] != -float("inf")
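
The `_create_allowed_token_ids` helper above only builds mask rows for even batch indices and marks a window of `num_allowed_token_ids` positions starting at the row index; the assertions then check that exactly the marked positions come back as -inf. Below is a minimal standalone sketch of the same windowing pattern, with hypothetical names and sizes chosen for illustration:

# Standalone sketch of the windowed boolean mask used by the test: for each
# even row i, positions [start, end) are set to True; odd rows stay all False.
import torch


def window_mask(batch_size: int, vocab_size: int, width: int) -> torch.Tensor:
    mask = torch.zeros((batch_size, vocab_size), dtype=torch.bool)
    for i in range(0, batch_size, 2):
        start = min(i, vocab_size - 1)
        end = min(i + width, vocab_size - 1)
        mask[i, start:end] = True
    return mask


print(window_mask(batch_size=4, vocab_size=8, width=2).int())
# Row 0 marks tokens 0-1, row 2 marks tokens 2-3; rows 1 and 3 are untouched.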

tests/v1/worker/test_gpu_input_batch.py

Lines changed: 13 additions & 0 deletions
@@ -66,6 +66,10 @@ def _construct_expected_sampling_metadata(
     temperature = [0.0 for _ in range(num_reqs)]
     min_tokens = {}
     logit_bias = [None] * num_reqs
+    allowed_token_ids_mask = torch.zeros(num_reqs,
+                                         VOCAB_SIZE,
+                                         dtype=torch.bool,
+                                         device=device)
     for req in reqs:
         if req.req_id not in req_ids_retained:
             continue
@@ -86,6 +90,10 @@ def _construct_expected_sampling_metadata(
             req.sampling_params.min_tokens,
             req.sampling_params.all_stop_token_ids)
         logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
+        if req.sampling_params.allowed_token_ids:
+            allowed_token_ids_mask[index_in_input_batch][
+                req.sampling_params.allowed_token_ids] = True
+
     return SamplingMetadata(
         temperature=torch.tensor(temperature, dtype=torch.float,
                                  device=device),
@@ -121,6 +129,7 @@ def _construct_expected_sampling_metadata(
                      and all(x == 0 for x in frequency_penalties)
                      and all(x == 1 for x in repetition_penalties)),
        logit_bias=logit_bias,
+       allowed_token_ids_mask=allowed_token_ids_mask,
    )
@@ -242,3 +251,7 @@ def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
    assert expected_sampling_metadata.no_penalties == \
        sampling_metadata.no_penalties
    assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
+    if sampling_metadata.allowed_token_ids_mask is not None:
+        assert torch.allclose(
+            expected_sampling_metadata.allowed_token_ids_mask,
+            sampling_metadata.allowed_token_ids_mask)

vllm/v1/engine/processor.py

Lines changed: 14 additions & 0 deletions
@@ -83,6 +83,19 @@ def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
            raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                             "not enabled!")

+    def _validate_allowed_token_ids(
+        self,
+        params: Union[SamplingParams, PoolingParams],
+    ) -> None:
+        if not isinstance(params, SamplingParams):
+            return
+        if params.allowed_token_ids is None:
+            return
+        if not all(0 <= tid < self.model_config.vocab_size
+                   for tid in params.allowed_token_ids):
+            raise ValueError(
+                "allowed_token_ids contains out-of-vocab token id")
+
     def process_inputs(
         self,
         request_id: str,
@@ -100,6 +113,7 @@ def process_inputs(

         self._validate_logprobs(params)
         self._validate_lora(lora_request)
+        self._validate_allowed_token_ids(params)

         if arrival_time is None:
             arrival_time = time.time()
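
The new `_validate_allowed_token_ids` hook simply rejects requests whose allowed set contains an ID outside the model vocabulary. Below is a standalone sketch of the same bounds check, with an explicit `vocab_size` argument standing in for `self.model_config.vocab_size`; the function name is hypothetical.

# Sketch of the bounds check: every allowed token ID must satisfy
# 0 <= tid < vocab_size, otherwise the request is rejected.
from typing import Optional, Sequence


def check_allowed_token_ids(allowed_token_ids: Optional[Sequence[int]],
                            vocab_size: int) -> None:
    if allowed_token_ids is None:
        return
    if not all(0 <= tid < vocab_size for tid in allowed_token_ids):
        raise ValueError("allowed_token_ids contains out-of-vocab token id")


check_allowed_token_ids([1, 5, 7], vocab_size=10)    # OK
try:
    check_allowed_token_ids([1, 10], vocab_size=10)  # 10 is out of range
except ValueError as err:
    print(err)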

vllm/v1/sample/metadata.py

Lines changed: 4 additions & 0 deletions
@@ -37,3 +37,7 @@ class SamplingMetadata:
     min_tokens: Dict[int, Tuple[int, Set[int]]]

     logit_bias: List[Optional[Dict[int, float]]]
+
+    # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
+    # vocab size).
+    allowed_token_ids_mask: Optional[torch.Tensor]

vllm/v1/sample/sampler.py

Lines changed: 16 additions & 2 deletions
@@ -47,6 +47,8 @@ def forward(

         # Use float32 for the logits.
         logits = logits.to(torch.float32)
+        # Apply allowed token ids.
+        logits = self.apply_allowed_token_ids(logits, sampling_metadata)
         # Apply logits bias.
         logits = self.apply_logits_bias(logits, sampling_metadata)
         # Apply penalties (e.g., min_tokens, freq_penalties).
@@ -184,11 +186,13 @@ def apply_penalties(
         if not sampling_metadata.no_penalties:
             assert sampling_metadata.prompt_token_ids is not None
             logits = apply_all_penalties(
-                logits, sampling_metadata.prompt_token_ids,
+                logits,
+                sampling_metadata.prompt_token_ids,
                 sampling_metadata.presence_penalties,
                 sampling_metadata.frequency_penalties,
                 sampling_metadata.repetition_penalties,
-                sampling_metadata.output_token_ids)
+                sampling_metadata.output_token_ids,
+            )
         return logits

     def apply_min_p(
@@ -226,3 +230,13 @@ def apply_logits_bias(
                 for token_id, bias in logit_bias.items():
                     logits[i, token_id] += bias
         return logits
+
+    def apply_allowed_token_ids(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
+        if sampling_metadata.allowed_token_ids_mask is not None:
+            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
+                                float("-inf"))
+        return logits
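
The new `apply_allowed_token_ids` applies the mask in place with `masked_fill_`: every position that is True in `allowed_token_ids_mask` has its logit set to -inf and can no longer be sampled. Below is a self-contained sketch of that step with toy tensors, outside the Sampler class:

# Toy demonstration of the masking step: True entries in the boolean mask are
# filled with -inf in place.
import torch

logits = torch.zeros(2, 6)                # (batch, vocab), all logits equal
mask = torch.zeros(2, 6, dtype=torch.bool)
mask[0, 3:] = True                        # block tokens 3..5 for request 0

logits.masked_fill_(mask, float("-inf"))
print(logits)
# Request 0 keeps finite logits only for tokens 0-2; request 1 is untouched.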

vllm/v1/worker/gpu_input_batch.py

Lines changed: 41 additions & 2 deletions
@@ -143,7 +143,7 @@ def __init__(
            device="cpu",
            pin_memory=pin_memory)
         self.frequency_penalties_cpu = \
-                self.frequency_penalties_cpu_tensor.numpy()
+            self.frequency_penalties_cpu_tensor.numpy()
         self.frequency_penalties_reqs: Set[str] = set()

         # Presence penalty related data structures
@@ -168,7 +168,7 @@ def __init__(
            device="cpu",
            pin_memory=pin_memory)
         self.repetition_penalties_cpu = \
-                self.repetition_penalties_cpu_tensor.numpy()
+            self.repetition_penalties_cpu_tensor.numpy()
         self.repetition_penalties_reqs: Set[str] = set()

         # req_index -> (min_tokens, stop_token_ids)
@@ -192,6 +192,9 @@ def __init__(

         self.logit_bias: List[Optional[Dict[int,
                                             float]]] = [None] * max_num_reqs
+        self.has_allowed_token_ids: Set[str] = set()
+        self.allowed_token_ids_mask: Optional[torch.Tensor] = None
+        self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None

         self.req_output_token_ids: List[Optional[List[int]]] = []

@@ -287,6 +290,22 @@ def add_request(
         if sampling_params.logit_bias is not None:
             self.logit_bias[req_index] = sampling_params.logit_bias

+        if sampling_params.allowed_token_ids:
+            self.has_allowed_token_ids.add(req_id)
+            if self.allowed_token_ids_mask_cpu_tensor is None:
+                # Lazy allocation for this tensor, which can be large.
+                self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
+                                                          self.vocab_size,
+                                                          dtype=torch.bool,
+                                                          device=self.device)
+                self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
+                    self.max_num_reqs,
+                    self.vocab_size,
+                    dtype=torch.bool,
+                    device="cpu")
+            self.allowed_token_ids_mask_cpu_tensor[req_index][
+                sampling_params.allowed_token_ids] = True
+
         # Add request lora ID
         if request.lora_request:
             lora_id = request.lora_request.lora_int_id
@@ -332,6 +351,9 @@ def remove_request(self, req_id: str) -> Optional[int]:
         self.request_lora_mapping[req_index] = 0

         self.logit_bias[req_index] = None
+        self.has_allowed_token_ids.discard(req_id)
+        if self.allowed_token_ids_mask_cpu_tensor is not None:
+            self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
         return req_index

     def condense(self, empty_req_indices: List[int]) -> None:
@@ -400,6 +422,11 @@ def condense(self, empty_req_indices: List[int]) -> None:

            self.logit_bias[empty_index] = self.logit_bias[last_req_index]

+            if self.allowed_token_ids_mask_cpu_tensor is not None:
+                self.allowed_token_ids_mask_cpu_tensor[
+                    empty_index] = self.allowed_token_ids_mask_cpu_tensor[
+                        last_req_index]
+
             # Decrement last_req_index since it is now empty.
             last_req_index -= 1

@@ -442,6 +469,13 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
         else:
             prompt_token_ids = None

+        allowed_token_ids_mask: Optional[torch.Tensor] = None
+        if not self.no_allowed_token_ids:
+            assert self.allowed_token_ids_mask is not None
+            copy_slice(self.allowed_token_ids_mask_cpu_tensor,
+                       self.allowed_token_ids_mask, num_reqs)
+            allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
+
         return SamplingMetadata(
             temperature=temperature,
             all_greedy=self.all_greedy,
@@ -460,6 +494,7 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
             min_tokens=self.min_tokens,
             no_penalties=self.no_penalties,
             logit_bias=self.logit_bias[:num_reqs],
+            allowed_token_ids_mask=allowed_token_ids_mask,
         )

     def get_sampling_metadata(
@@ -550,3 +585,7 @@ def max_num_logprobs(self) -> Optional[int]:
     @property
     def no_prompt_logprob(self) -> bool:
         return not self.num_prompt_logprobs
+
+    @property
+    def no_allowed_token_ids(self) -> bool:
+        return len(self.has_allowed_token_ids) == 0
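
The input batch keeps a lazily allocated (max_num_reqs, vocab_size) CPU mask alongside a device copy, updates rows as requests are added, removed, or condensed, and hands a `[:num_reqs]` slice to SamplingMetadata. Below is a reduced, CPU-only sketch of that bookkeeping pattern; the class name and methods are hypothetical, and the device copy and `copy_slice` helper are omitted.

# Reduced sketch of the persistent-batch bookkeeping: a lazily allocated
# boolean mask, one row per request slot, set on add, cleared on remove,
# and sliced to the live batch when sampling metadata is built.
from typing import List, Optional
import torch


class AllowedTokenIdsState:

    def __init__(self, max_num_reqs: int, vocab_size: int):
        self.max_num_reqs = max_num_reqs
        self.vocab_size = vocab_size
        self.mask: Optional[torch.Tensor] = None  # allocated on first use

    def add(self, req_index: int, allowed_token_ids: List[int]) -> None:
        if self.mask is None:
            self.mask = torch.zeros(self.max_num_reqs,
                                    self.vocab_size,
                                    dtype=torch.bool)
        self.mask[req_index][allowed_token_ids] = True

    def remove(self, req_index: int) -> None:
        if self.mask is not None:
            self.mask[req_index].fill_(False)

    def batch_slice(self, num_reqs: int) -> Optional[torch.Tensor]:
        return None if self.mask is None else self.mask[:num_reqs]


state = AllowedTokenIdsState(max_num_reqs=4, vocab_size=8)
state.add(0, [2, 3])
print(state.batch_slice(num_reqs=1))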
