
Commit 30172b4

[V1] Optimize handling of sampling metadata and req_ids list (#13244)
Signed-off-by: Nick Hill <nhill@redhat.com>
1 parent a4d577b commit 30172b4

15 files changed (+255, -298 lines)


tests/v1/sample/test_rejection_sampler.py

Lines changed: 2 additions & 7 deletions
@@ -26,17 +26,13 @@ def create_logits_tensor(token_ids: List[int],
 def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
     batch_size = len(spec_tokens)
     return SamplingMetadata(
-        temperature=0.0,
+        temperature=torch.tensor([]),
         all_greedy=True,
         all_random=False,
-        rejection_sampling=True,
         spec_token_ids=spec_tokens,
         top_p=None,
         top_k=None,
-        no_top_p=False,
-        no_top_k=False,
         min_p=torch.empty(batch_size, ),
-        no_min_p=True,
         generators={},
         max_num_logprobs=0,
         no_penalties=False,
@@ -45,8 +41,7 @@ def create_sampling_metadata(spec_tokens: List[List[int]]) -> SamplingMetadata:
         presence_penalties=torch.tensor([]),
         repetition_penalties=torch.tensor([]),
         output_token_ids=[],
-        min_tokens=[],
-        stop_token_ids=[],
+        min_tokens={},
         logit_bias=[None] * batch_size,
     )
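Taken together, these changes reshape the metadata this test builds: the `rejection_sampling` and `no_top_p` / `no_top_k` / `no_min_p` flags disappear, optional knobs become `None` when unused, and `min_tokens` collapses two parallel lists into one dict. Below is a minimal sketch of just the fields the test touches; the standalone dataclass here is illustrative, not the full `vllm.v1.sample.metadata.SamplingMetadata` definition.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple

import torch


@dataclass
class SamplingMetadataSketch:
    """Illustrative subset of the post-change sampling metadata fields."""
    temperature: torch.Tensor
    all_greedy: bool
    all_random: bool
    # None means "no request in the batch uses this knob"; the separate
    # no_top_p / no_top_k / no_min_p booleans are gone.
    top_p: Optional[torch.Tensor]
    top_k: Optional[torch.Tensor]
    min_p: Optional[torch.Tensor]
    # None replaces the old empty-list default, so a dedicated
    # rejection_sampling flag is no longer needed.
    spec_token_ids: Optional[List[List[int]]]
    # Batch index -> (min_tokens, stop_token_ids), replacing the two
    # parallel min_tokens / stop_token_ids lists.
    min_tokens: Dict[int, Tuple[int, Set[int]]]
```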

tests/v1/sample/test_sampler.py

Lines changed: 18 additions & 26 deletions
@@ -77,25 +77,20 @@ def _create_default_sampling_metadata(
         temperature=torch.full((batch_size, ), 0.0),
         all_greedy=True,
         all_random=False,
-        rejection_sampling=False,
-        top_p=torch.empty(batch_size, ),
-        top_k=torch.empty(batch_size, ),
-        no_top_p=True,
-        no_top_k=True,
-        min_p=torch.empty(batch_size, ),
-        no_min_p=True,
+        top_p=None,
+        top_k=None,
+        min_p=None,
         generators={},
         max_num_logprobs=0,
         prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
                                                       vocab_size, device),
         output_token_ids=output_token_ids,
-        spec_token_ids=[],
+        spec_token_ids=None,
         frequency_penalties=_create_penalty_tensor(batch_size, 0.0, device),
         presence_penalties=_create_penalty_tensor(batch_size, 0.0, device),
         repetition_penalties=_create_penalty_tensor(batch_size, 1.0, device),
         no_penalties=True,
-        min_tokens=[],
-        stop_token_ids=[],
+        min_tokens={},
         logit_bias=[None] * batch_size,
     )
     return fake_sampling_metadata
@@ -104,33 +99,30 @@ def _create_default_sampling_metadata(
 def _generate_min_token_penalties_and_stop_tokens(
         num_output_tokens: int, batch_size: int, vocab_size: int,
         batch_indices_for_min_token_penalty: List[int]
-) -> Tuple[List[int], List[Set[int]]]:
+) -> Dict[int, Tuple[int, Set[int]]]:
     """
-    Generates and returns a list of minimum token penalties (`min_tokens`)
-    and a corresponding list of stop token IDs (`stop_token_ids`) for each
+    Generates and returns a dict of minimum token penalties and
+    corresponding stop token IDs (`min_tokens`, `stop_token_ids`) for each
     batch.

     If a batch index is included in `batch_indices_for_min_token_penalty`,
     a higher `min_tokens` value is assigned (within a randomized range),
     and a random set of stop token IDs is created. Otherwise, a lower
     `min_tokens` value is assigned, and the stop token IDs set is empty.
     """
-    stop_token_ids: List[Set[int]] = []
-    min_tokens: List[int] = []
+    min_tokens: Dict[int, Tuple[int, Set[int]]] = {}
     for index in range(batch_size):
         if index in batch_indices_for_min_token_penalty:
-            min_tokens.append(
+            min_tokens[index] = (
                 np.random.randint(num_output_tokens + 1,
-                                  2 * num_output_tokens))
-            stop_token_ids.append(
+                                  2 * num_output_tokens),
                 set(
                     np.random.randint(0, vocab_size - 1)
                     for _ in range(np.random.randint(0, vocab_size))))
-
         else:
-            min_tokens.append(np.random.randint(0, num_output_tokens))
-            stop_token_ids.append(set())
-    return (min_tokens, stop_token_ids)
+            min_tokens[index] = (np.random.randint(0,
+                                                   num_output_tokens), set())
+    return min_tokens
@@ -165,7 +157,7 @@ def _create_weighted_output_token_list(
         output_token_ids_for_batch.extend(
             [token_id for _ in range(index + 1)])
         output_token_ids.append(output_token_ids_for_batch)
-    return (output_token_ids, sorted_token_ids_in_output)
+    return output_token_ids, sorted_token_ids_in_output


 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -182,17 +174,17 @@ def test_sampler_min_tokens_penalty(device: str, batch_size: int):
         NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
     batch_indices_for_min_token_penalty = np.random.randint(
         0, batch_size - 1, size=np.random.randint(0, batch_size)).tolist()
-    min_tokens, stop_token_ids = _generate_min_token_penalties_and_stop_tokens(
+    min_tokens = _generate_min_token_penalties_and_stop_tokens(
         NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE,
         batch_indices_for_min_token_penalty)
     sampling_metadata.min_tokens = min_tokens
-    sampling_metadata.stop_token_ids = stop_token_ids
     sampler = Sampler()
     logits = sampler.apply_penalties(fake_logits, sampling_metadata)
     logits = logits.cpu()
     for batch_idx in range(batch_size):
         for token_id in range(VOCAB_SIZE):
-            if token_id in stop_token_ids[batch_idx]:
+            _, stop_token_ids = min_tokens.get(batch_idx, (0, set()))
+            if token_id in stop_token_ids:
                 assert logits[batch_idx][token_id] == -float("inf")
             else:
                 assert logits[batch_idx][token_id] != -float("inf")
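The new `min_tokens` dict is sparse: only batch indices that actually carry a min-tokens constraint appear, and each index maps directly to its stop-token set. Below is a hedged sketch of how a consumer could apply the penalty the test asserts (stop-token logits forced to `-inf` until `min_tokens` outputs exist); it mirrors the test's expectations, not vLLM's actual sampler code.

```python
from typing import Dict, List, Set, Tuple

import torch


def apply_min_token_penalties(
        logits: torch.Tensor,
        output_token_ids: List[List[int]],
        min_tokens: Dict[int, Tuple[int, Set[int]]]) -> None:
    """Mask stop tokens for requests that have not yet produced
    min_tokens output tokens (sketch matching the test's assertions)."""
    for batch_idx, (min_toks, stop_token_ids) in min_tokens.items():
        if len(output_token_ids[batch_idx]) < min_toks:
            for token_id in stop_token_ids:
                logits[batch_idx, token_id] = -float("inf")
```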

tests/v1/worker/test_gpu_input_batch.py

Lines changed: 21 additions & 26 deletions
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import Dict, List, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple

 import numpy as np
 import pytest
@@ -41,7 +41,7 @@ def _remove_requests(
     for index in req_indices_to_remove:
         input_batch.remove_request(reqs[index].req_id)
         req_ids_to_remove.add(reqs[index].req_id)
-    return (req_ids_to_remove, req_indices_to_remove_list)
+    return req_ids_to_remove, req_indices_to_remove_list


 def _construct_expected_sampling_metadata(
@@ -64,8 +64,7 @@ def _construct_expected_sampling_metadata(
     top_p = [0.0 for _ in range(num_reqs)]
     min_p = [0.0 for _ in range(num_reqs)]
     temperature = [0.0 for _ in range(num_reqs)]
-    stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)]
-    min_tokens = [0 for _ in range(num_reqs)]
+    min_tokens = {}
     logit_bias = [None] * num_reqs
     for req in reqs:
         if req.req_id not in req_ids_retained:
@@ -83,22 +82,21 @@ def _construct_expected_sampling_metadata(
         top_p[index_in_input_batch] = req.sampling_params.top_p
         min_p[index_in_input_batch] = req.sampling_params.min_p
         temperature[index_in_input_batch] = req.sampling_params.temperature
-        stop_token_ids[
-            index_in_input_batch] = req.sampling_params.all_stop_token_ids
-        min_tokens[index_in_input_batch] = req.sampling_params.min_tokens
+        min_tokens[index_in_input_batch] = (
+            req.sampling_params.min_tokens,
+            req.sampling_params.all_stop_token_ids)
         logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
     return SamplingMetadata(
         temperature=torch.tensor(temperature, dtype=torch.float,
                                  device=device),
         all_greedy=False,
         all_random=True,
-        rejection_sampling=False,
-        top_p=torch.tensor(top_p, dtype=torch.float, device=device),
-        top_k=torch.tensor(top_k, dtype=torch.int, device=device),
-        no_top_p=all(x == 1.0 for x in top_p),
-        no_top_k=all(x == 0 for x in top_k),
-        min_p=torch.tensor(min_p, dtype=torch.float, device=device),
-        no_min_p=all(x == 0.0 for x in min_p),
+        top_p=None if all(x == 1.0 for x in top_p) else torch.tensor(
+            top_p, dtype=torch.float, device=device),
+        top_k=None if all(x == 0 for x in top_k) else torch.tensor(
+            top_k, dtype=torch.int, device=device),
+        min_p=None if all(x == 0.0 for x in min_p) else torch.tensor(
+            min_p, dtype=torch.float, device=device),
         generators={},
         max_num_logprobs=0,
         prompt_token_ids=make_tensor_with_pad(
@@ -117,9 +115,8 @@ def _construct_expected_sampling_metadata(
             dtype=torch.float,
             device=device),
         output_token_ids=output_token_ids,
-        spec_token_ids=[],
+        spec_token_ids=None,
         min_tokens=min_tokens,
-        stop_token_ids=stop_token_ids,
         no_penalties=(all(x == 0 for x in presence_penalties)
                       and all(x == 0 for x in frequency_penalties)
                       and all(x == 1 for x in repetition_penalties)),
@@ -206,8 +203,7 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
     input_batch.condense(req_indices_to_remove)

     # Generate the sampling metadata
-    sampling_metadata = input_batch.make_sampling_metadata(
-        req_id_output_token_ids, req_id_to_spec_token_ids={}, skip_copy=False)
+    sampling_metadata = input_batch._make_sampling_metadata()

     # Create expected output.
     expected_sampling_metadata = _construct_expected_sampling_metadata(
@@ -216,13 +212,16 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
         input_batch.req_id_to_index,
         device=torch.device(device))

+    def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
+        return (t1 is None
+                and t2 is None) or (t1 is not None and t2 is not None
+                                    and torch.allclose(t1, t2))
+
     # Assert the actual and expected output.
     assert torch.allclose(expected_sampling_metadata.temperature,
                           sampling_metadata.temperature)
-    assert torch.allclose(expected_sampling_metadata.top_p,
-                          sampling_metadata.top_p)
-    assert torch.allclose(expected_sampling_metadata.top_k,
-                          sampling_metadata.top_k)
+    assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p)
+    assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k)
     assert torch.allclose(
         expected_sampling_metadata.frequency_penalties,
         sampling_metadata.frequency_penalties,
@@ -240,10 +239,6 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
     assert (expected_sampling_metadata.output_token_ids ==
             sampling_metadata.output_token_ids)
     assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
-    assert expected_sampling_metadata.stop_token_ids == \
-        sampling_metadata.stop_token_ids
     assert expected_sampling_metadata.no_penalties == \
         sampling_metadata.no_penalties
-    assert expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p
-    assert expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k
     assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias
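The expected-metadata construction above encodes a convention worth calling out: a knob's tensor is only materialized when at least one request deviates from the default; otherwise the field is plain `None`, which is why the test needs the None-aware `same()` comparison. A small illustrative helper (not part of vLLM) capturing the same ternary pattern:

```python
from typing import List, Optional

import torch


def tensor_or_none(values: List[float], default: float, dtype: torch.dtype,
                   device: torch.device) -> Optional[torch.Tensor]:
    """Return None when every request uses the default value, else a tensor."""
    if all(v == default for v in values):
        return None
    return torch.tensor(values, dtype=dtype, device=device)


# Mirrors the inline ternaries in _construct_expected_sampling_metadata:
# top_p -> default 1.0, top_k -> default 0, min_p -> default 0.0.
```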

tests/v1/worker/test_gpu_model_runner.py

Lines changed: 22 additions & 11 deletions
@@ -5,6 +5,7 @@
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
                                            SchedulerOutput)
+from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner


@@ -82,14 +83,21 @@ def _is_req_added(model_runner, req_id: str) -> bool:
     return req_id in model_runner.requests


+def _is_sampling_metadata_changed(model_runner,
+                                  sampling_metadata_before: SamplingMetadata):
+    return model_runner.input_batch.sampling_metadata is not (
+        sampling_metadata_before)
+
+
 def test_update_states_new_request(model_runner):
     req_id = "req_0"

     # new req
     scheduler_output = _schedule_new_request(req_id)

-    batch_changed = model_runner._update_states(scheduler_output)
-    assert batch_changed is True
+    metadata_before = model_runner.input_batch.sampling_metadata
+    model_runner._update_states(scheduler_output)
+    assert _is_sampling_metadata_changed(model_runner, metadata_before)
     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)

@@ -117,8 +125,9 @@ def test_update_states_request_finished(model_runner):
         free_encoder_input_ids=[],
     )

-    batch_changed = model_runner._update_states(scheduler_output)
-    assert batch_changed is True
+    metadata_before = model_runner.input_batch.sampling_metadata
+    model_runner._update_states(scheduler_output)
+    assert _is_sampling_metadata_changed(model_runner, metadata_before)
     assert not _is_req_added(model_runner, req_id)
     assert not _is_req_scheduled(model_runner, req_id)

@@ -142,7 +151,7 @@ def test_update_states_request_resumed(model_runner):
         scheduled_spec_decode_tokens={},
         scheduled_encoder_inputs={},
         num_common_prefix_blocks=0,
-        finished_req_ids={},
+        finished_req_ids=set(),
         free_encoder_input_ids=[],
     )

@@ -171,8 +180,9 @@ def test_update_states_request_resumed(model_runner):
         free_encoder_input_ids=[],
     )

-    batch_changed = model_runner._update_states(scheduler_output)
-    assert batch_changed is True
+    metadata_before = model_runner.input_batch.sampling_metadata
+    model_runner._update_states(scheduler_output)
+    assert _is_sampling_metadata_changed(model_runner, metadata_before)
     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)

@@ -200,8 +210,9 @@ def test_update_states_no_changes(model_runner):
         free_encoder_input_ids=[],
     )

-    batch_changed = model_runner._update_states(scheduler_output)
-    assert batch_changed is False
+    metadata_before = model_runner.input_batch.sampling_metadata
+    model_runner._update_states(scheduler_output)
+    assert not _is_sampling_metadata_changed(model_runner, metadata_before)
     assert _is_req_added(model_runner, req_id)
     assert _is_req_scheduled(model_runner, req_id)

@@ -233,8 +244,8 @@ def test_update_states_request_unscheduled(model_runner):
         free_encoder_input_ids=[],
     )

-    batch_changed = model_runner._update_states(scheduler_output)
-    assert batch_changed is True
+    metadata_before = model_runner._update_states(scheduler_output)
+    assert _is_sampling_metadata_changed(model_runner, metadata_before)

     assert _is_req_added(model_runner, req_ids[0])
     assert _is_req_scheduled(model_runner, req_ids[0])
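These tests stop checking a `batch_changed` return value and instead compare the identity of `input_batch.sampling_metadata` before and after `_update_states`. That only works if the input batch caches one metadata object and rebuilds it only when the batch actually changes. A rough sketch of that caching pattern follows; the refresh bookkeeping here is assumed for illustration and is not taken from the real InputBatch implementation.

```python
class InputBatchSketch:
    """Illustrative caching pattern; not the actual vLLM InputBatch."""

    def __init__(self):
        self._sampling_metadata = None
        self._metadata_stale = True  # flipped by add/remove/condense paths

    @property
    def sampling_metadata(self):
        # Rebuild only when the batch composition changed; otherwise return
        # the same object, so callers can detect "no change" with `is not`.
        if self._metadata_stale or self._sampling_metadata is None:
            self._sampling_metadata = self._make_sampling_metadata()
            self._metadata_stale = False
        return self._sampling_metadata

    def _make_sampling_metadata(self):
        raise NotImplementedError  # built from per-request state in vLLM
```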

vllm/model_executor/layers/utils.py

Lines changed: 3 additions & 3 deletions
@@ -45,14 +45,14 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                                                     vocab_size, num_seqs)
     output_bin_counts, output_mask = get_token_bin_counts_and_mask(
         output_tokens_tensor, vocab_size, num_seqs)
-    repetition_penalties = repetition_penalties.unsqueeze_(dim=1).repeat(
+    repetition_penalties = repetition_penalties.unsqueeze(dim=1).repeat(
         1, vocab_size)
     logits[logits > 0] /= torch.where(prompt_mask | output_mask,
                                       repetition_penalties, 1.0)[logits > 0]
     logits[logits <= 0] *= torch.where(prompt_mask | output_mask,
                                        repetition_penalties, 1.0)[logits <= 0]
     # We follow the definition in OpenAI API.
     # Refer to https://platform.openai.com/docs/api-reference/parameter-details
-    logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts
-    logits -= presence_penalties.unsqueeze_(dim=1) * output_mask
+    logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts
+    logits -= presence_penalties.unsqueeze(dim=1) * output_mask
     return logits
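Switching from `unsqueeze_` to `unsqueeze` matters because the trailing underscore denotes PyTorch's in-place variant: it reshapes the caller's tensor itself, whereas `unsqueeze` returns a new view and leaves the original untouched. With the sampling metadata now reused across steps rather than rebuilt every time (presumably the motivation here), mutating the shared penalty tensors in place would corrupt them for later calls. A quick demonstration of the difference:

```python
import torch

# Per-sequence penalties, shape (num_seqs,), potentially reused across steps.
penalties = torch.tensor([1.2, 1.0, 1.1])

# Out-of-place: returns a (num_seqs, 1) view; the original is unchanged.
expanded = penalties.unsqueeze(dim=1)
assert penalties.shape == (3,) and expanded.shape == (3, 1)

# In-place (old code): would permanently reshape `penalties` itself to
# (num_seqs, 1), breaking any later use that expects a 1-D tensor.
# penalties.unsqueeze_(dim=1)
```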

vllm/v1/core/scheduler.py

Lines changed: 4 additions & 2 deletions
@@ -195,8 +195,10 @@ def schedule(self) -> "SchedulerOutput":
                     request.num_computed_tokens -
                     request.num_tokens)
                 if num_scheduled_spec_tokens > 0:
+                    # Trim spec_token_ids list to num_scheduled_spec_tokens.
+                    del request.spec_token_ids[num_scheduled_spec_tokens:]
                     scheduled_spec_decode_tokens[request.request_id] = (
-                        request.spec_token_ids[:num_scheduled_spec_tokens])
+                        request.spec_token_ids)

             # Encoder-related.
             if encoder_inputs_to_schedule:
@@ -567,7 +569,7 @@ def update_from_output(
             outputs.append(
                 EngineCoreOutput(
                     request_id=req_id,
-                    new_token_ids=new_token_ids or [],
+                    new_token_ids=new_token_ids,
                     finish_reason=request.get_finished_reason(),
                     new_logprobs=new_logprobs,
                     new_prompt_logprobs_tensors=prompt_logprobs_tensors,
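The scheduler change replaces a slicing copy with an in-place truncation: `del lst[n:]` shortens the request's existing `spec_token_ids` list, so the very same list object can be stored in `scheduled_spec_decode_tokens` without allocating a new one each step. A small standalone illustration:

```python
spec_token_ids = [101, 102, 103, 104, 105]
num_scheduled_spec_tokens = 3

# Old approach: slicing allocates a fresh list every scheduling pass.
trimmed_copy = spec_token_ids[:num_scheduled_spec_tokens]

# New approach: truncate in place and reuse the request's own list.
del spec_token_ids[num_scheduled_spec_tokens:]
assert spec_token_ids == [101, 102, 103] == trimmed_copy
assert trimmed_copy is not spec_token_ids  # the slice was a separate object
```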
