Commit 368369c

address comments
Signed-off-by: Lu Fang <lufang@fb.com>
1 parent 4897f02 · commit 368369c

File tree

5 files changed: +11 −30 lines

  tests/v1/sample/test_sampler.py
  tests/v1/worker/test_gpu_input_batch.py
  vllm/v1/sample/metadata.py
  vllm/v1/sample/sampler.py
  vllm/v1/worker/gpu_input_batch.py

tests/v1/sample/test_sampler.py

Lines changed: 3 additions & 8 deletions
@@ -62,9 +62,8 @@ def _create_allowed_token_ids(
     vocab_size: int,
     num_allowed_token_ids: int,
     device: torch.device,
-) -> Tuple[bool, Optional[torch.Tensor]]:
+) -> Optional[torch.Tensor]:
     mask: Optional[torch.Tensor] = None
-    no_allowed_token_ids = True
     for i in range(batch_size):
         if i % 2 == 1:
             continue
@@ -75,8 +74,7 @@ def _create_allowed_token_ids(
         start = min(i, vocab_size - 1)
         end = min(i + num_allowed_token_ids, vocab_size - 1)
         mask[i, start:end] = True
-        no_allowed_token_ids = False
-    return (no_allowed_token_ids, mask)
+    return mask


 def _create_default_sampling_metadata(
@@ -114,7 +112,6 @@ def _create_default_sampling_metadata(
         no_penalties=True,
         min_tokens={},
         logit_bias=[None] * batch_size,
-        no_allowed_token_ids=True,
         allowed_token_ids_mask=None,
     )
     return fake_sampling_metadata
@@ -448,18 +445,16 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int,
     fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
     sampling_metadata = _create_default_sampling_metadata(
         NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
-    no_allowed_token_ids, mask = _create_allowed_token_ids(
+    mask = _create_allowed_token_ids(
         batch_size=batch_size,
         vocab_size=VOCAB_SIZE,
         num_allowed_token_ids=num_allowed_token_ids,
         device=device,
     )
-    sampling_metadata.no_allowed_token_ids = no_allowed_token_ids
     sampling_metadata.allowed_token_ids_mask = mask
     sampler = Sampler()
     logits = sampler.apply_allowed_token_ids(fake_logits, sampling_metadata)
     logits = logits.cpu()
-    assert not sampling_metadata.no_allowed_token_ids
     for batch_idx in range(batch_size):
         logits_for_req = logits[batch_idx]
         if batch_idx % 2 == 1:
tests/v1/worker/test_gpu_input_batch.py

Lines changed: 1 addition & 6 deletions
@@ -66,7 +66,6 @@ def _construct_expected_sampling_metadata(
     temperature = [0.0 for _ in range(num_reqs)]
     min_tokens = {}
     logit_bias = [None] * num_reqs
-    has_allowed_token_ids = [False] * num_reqs
     allowed_token_ids_mask = torch.zeros(num_reqs,
                                          VOCAB_SIZE,
                                          dtype=torch.bool,
@@ -92,7 +91,6 @@ def _construct_expected_sampling_metadata(
             req.sampling_params.all_stop_token_ids)
         logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
         if req.sampling_params.allowed_token_ids:
-            has_allowed_token_ids[index_in_input_batch] = True
             allowed_token_ids_mask[index_in_input_batch][
                 req.sampling_params.allowed_token_ids] = True

@@ -131,7 +129,6 @@ def _construct_expected_sampling_metadata(
             and all(x == 0 for x in frequency_penalties)
             and all(x == 1 for x in repetition_penalties)),
         logit_bias=logit_bias,
-        no_allowed_token_ids=not any(has_allowed_token_ids),
         allowed_token_ids_mask=allowed_token_ids_mask,
     )

@@ -254,9 +251,7 @@ def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
     assert expected_sampling_metadata.no_penalties == \
         sampling_metadata.no_penalties
     assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
-    assert (expected_sampling_metadata.no_allowed_token_ids ==
-            sampling_metadata.no_allowed_token_ids)
-    if not sampling_metadata.no_allowed_token_ids:
+    if sampling_metadata.allowed_token_ids_mask:
         assert torch.allclose(
             expected_sampling_metadata.allowed_token_ids_mask,
             sampling_metadata.allowed_token_ids_mask)
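The comparison above only runs when a mask is present; a minimal standalone sketch of comparing two optional boolean masks (using torch.equal here rather than the test's torch.allclose, purely for illustration):

from typing import Optional

import torch

def masks_match(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
    # Both absent counts as a match; otherwise both must be present and
    # element-wise identical.
    if t1 is None or t2 is None:
        return t1 is None and t2 is None
    return torch.equal(t1, t2)

expected = torch.tensor([[True, False], [False, False]])
actual = torch.tensor([[True, False], [False, False]])
print(masks_match(expected, actual))  # True
print(masks_match(expected, None))    # False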

vllm/v1/sample/metadata.py

Lines changed: 0 additions & 4 deletions
@@ -38,10 +38,6 @@ class SamplingMetadata:

     logit_bias: List[Optional[Dict[int, float]]]

-    # These two parameters are for allowed_token_ids.
-    # `no_allowed_token_ids` is a bool to indicate whether we have
-    # allowed_token_ids.
     # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
     # vocab size).
-    no_allowed_token_ids: bool
     allowed_token_ids_mask: Optional[torch.Tensor]
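With the boolean flag gone, the single Optional field is enough to tell whether any request is constrained; a minimal sketch of that idea using a hypothetical stand-in class (not the full SamplingMetadata):

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class MaskHolder:  # hypothetical illustration, not a vLLM class
    # 2D bool tensor of shape (max batch size, vocab size), or None when no
    # request supplied allowed_token_ids.
    allowed_token_ids_mask: Optional[torch.Tensor]

    def is_constrained(self) -> bool:
        # Presence of the tensor replaces the removed no_allowed_token_ids flag.
        return self.allowed_token_ids_mask is not None

print(MaskHolder(None).is_constrained())                                  # False
print(MaskHolder(torch.zeros(2, 4, dtype=torch.bool)).is_constrained())  # True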

vllm/v1/sample/sampler.py

Lines changed: 1 addition & 3 deletions
@@ -237,9 +237,7 @@ def apply_allowed_token_ids(
         logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> torch.Tensor:
-        # One idea is implement this as a PyTorch C++ op, and we may
-        # even optimize the logit_bias layout.
-        if not sampling_metadata.no_allowed_token_ids:
+        if sampling_metadata.allowed_token_ids_mask is not None:
             logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
                                 float("-inf"))
         return logits
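As a quick illustration of what the retained branch does: masked_fill_ writes -inf wherever the boolean mask is True and leaves everything else untouched. A minimal standalone sketch (toy tensors, not the vLLM Sampler):

import torch

logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[False, True, False, True]])

# In-place fill: positions where the mask is True become -inf.
logits.masked_fill_(mask, float("-inf"))
print(logits)  # tensor([[1., -inf, 3., -inf]])

# When allowed_token_ids_mask is None, the branch is skipped and the logits
# pass through unchanged, which is what the new `is not None` check encodes.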

vllm/v1/worker/gpu_input_batch.py

Lines changed: 6 additions & 9 deletions
@@ -192,7 +192,7 @@ def __init__(

         self.logit_bias: List[Optional[Dict[int,
                                             float]]] = [None] * max_num_reqs
-        self.has_allowed_token_ids: List[bool] = [False] * max_num_reqs
+        self.has_allowed_token_ids: Set[str] = set()
         self.allowed_token_ids_mask: Optional[torch.Tensor] = None
         self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None

@@ -297,7 +297,7 @@ def add_request(
                     for tid in sampling_params.allowed_token_ids):
                 raise ValueError(
                     "allowed_token_ids contains out-of-vocab token id")
-            self.has_allowed_token_ids[req_index] = True
+            self.has_allowed_token_ids.add(req_id)
             if self.allowed_token_ids_mask_cpu_tensor is None:
                 # Lazy allocation for this tensor, which can be large.
                 self.allowed_token_ids_mask = torch.zeros(self.max_num_reqs,
@@ -357,7 +357,7 @@ def remove_request(self, req_id: str) -> Optional[int]:
         self.request_lora_mapping[req_index] = 0

         self.logit_bias[req_index] = None
-        self.has_allowed_token_ids[req_index] = False
+        self.has_allowed_token_ids.discard(req_id)
         if self.allowed_token_ids_mask_cpu_tensor is not None:
             self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
         return req_index
@@ -428,8 +428,6 @@ def condense(self, empty_req_indices: List[int]) -> None:

         self.logit_bias[empty_index] = self.logit_bias[last_req_index]

-        self.has_allowed_token_ids[
-            empty_index] = self.has_allowed_token_ids[last_req_index]
         if self.allowed_token_ids_mask_cpu_tensor is not None:
             self.allowed_token_ids_mask_cpu_tensor[
                 empty_index] = self.allowed_token_ids_mask_cpu_tensor[
@@ -478,8 +476,8 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
             prompt_token_ids = None

         allowed_token_ids_mask: Optional[torch.Tensor] = None
-        if not self.no_allowed_token_ids and \
-                self.allowed_token_ids_mask is not None:
+        if not self.no_allowed_token_ids:
+            assert self.allowed_token_ids_mask is not None
             copy_slice(self.allowed_token_ids_mask_cpu_tensor,
                        self.allowed_token_ids_mask, num_reqs)
             allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
@@ -502,7 +500,6 @@ def _make_sampling_metadata(self) -> SamplingMetadata:
             min_tokens=self.min_tokens,
             no_penalties=self.no_penalties,
             logit_bias=self.logit_bias[:num_reqs],
-            no_allowed_token_ids=self.no_allowed_token_ids,
             allowed_token_ids_mask=allowed_token_ids_mask,
         )

@@ -597,4 +594,4 @@ def no_prompt_logprob(self) -> bool:

     @property
     def no_allowed_token_ids(self) -> bool:
-        return not any(self.has_allowed_token_ids)
+        return len(self.has_allowed_token_ids) == 0
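The bookkeeping above replaces a per-slot boolean list with a set of request ids, so condense() no longer needs to copy a flag between slots; a minimal standalone sketch of the pattern (hypothetical request ids, not the InputBatch class):

# Standalone sketch of the set-based tracking used above.
has_allowed_token_ids = set()

# add_request path: remember the request id once it supplies allowed_token_ids.
has_allowed_token_ids.add("req-1")

# remove_request path: discard() is a no-op for requests that never set
# allowed_token_ids, so no per-slot reset is needed.
has_allowed_token_ids.discard("req-2")

# no_allowed_token_ids property: an empty set means no request is constrained.
print(len(has_allowed_token_ids) == 0)  # False while "req-1" is tracked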
