
Commit 6224a9f

Support logit_bias in v1 Sampler (#13079)

1 parent 085b7b2 · commit 6224a9f

File tree

6 files changed: +200 -101 lines changed
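For orientation, a minimal usage sketch of what this commit enables, assuming the offline LLM entry point and the v1 engine (VLLM_USE_V1=1 at the time of this commit); the model name, prompt, and token ids are placeholders, and the logit_bias keyword mirrors the one exercised in the tests below:

    # Sketch only: nudge token id 1234 upward and effectively ban token id 42.
    # Biases are kept within [-100, 100], the range SamplingParams.from_optional
    # clamps to (see vllm/sampling_params.py below).
    from vllm import LLM, SamplingParams

    params = SamplingParams(
        max_tokens=16,
        logit_bias={1234: 5.0, 42: -100.0},
    )
    llm = LLM(model="facebook/opt-125m")  # placeholder model
    outputs = llm.generate(["Hello, my name is"], params)
    print(outputs[0].outputs[0].text)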

tests/v1/sample/test_sampler.py

Lines changed: 59 additions & 12 deletions

@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0

-from typing import List, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple

 import numpy as np
 import pytest
@@ -45,6 +45,18 @@ def _create_prompt_tokens_tensor(
     )


+def _create_logit_bias(
+    batch_size: int,
+    vocab_size: int,
+    bias_value: float,
+) -> List[Optional[Dict[int, float]]]:
+    res: List[Optional[Dict[int, float]]] = []
+    for i in range(batch_size):
+        logit_bias = {min(i, vocab_size - 1): bias_value}
+        res.append(logit_bias)
+    return res
+
+
 def _create_default_sampling_metadata(
     num_output_tokens: int,
     batch_size: int,
@@ -80,6 +92,7 @@ def _create_default_sampling_metadata(
         no_penalties=True,
         min_tokens=[],
         stop_token_ids=[],
+        logit_bias=[None] * batch_size,
     )
     return fake_sampling_metadata

@@ -89,14 +102,14 @@ def _generate_min_token_penalties_and_stop_tokens(
     batch_indices_for_min_token_penalty: List[int]
 ) -> Tuple[List[int], List[Set[int]]]:
     """
-    Generates and returns a list of minimum token penalties (`min_tokens`)
-    and a corresponding list of stop token IDs (`stop_token_ids`) for each
+    Generates and returns a list of minimum token penalties (`min_tokens`)
+    and a corresponding list of stop token IDs (`stop_token_ids`) for each
     batch.

-    If a batch index is included in `batch_indices_for_min_token_penalty`,
-    a higher `min_tokens` value is assigned (within a randomized range),
-    and a random set of stop token IDs is created. Otherwise, a lower
-    `min_tokens` value is assigned, and the stop token IDs set is empty.
+    If a batch index is included in `batch_indices_for_min_token_penalty`,
+    a higher `min_tokens` value is assigned (within a randomized range),
+    and a random set of stop token IDs is created. Otherwise, a lower
+    `min_tokens` value is assigned, and the stop token IDs set is empty.
     """
     stop_token_ids: List[Set[int]] = []
     min_tokens: List[int] = []
@@ -120,7 +133,7 @@ def _create_weighted_output_token_list(
         batch_size: int,
         vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]:
     """
-    Creates an output token list where each token occurs a distinct
+    Creates an output token list where each token occurs a distinct
     number of times.

     For each batch, a random subset of token IDs is selected from the
@@ -129,8 +142,8 @@

     Returns:
         Tuple[List[List[int]], List[List[int]]]:
-            - The first element is the output token list, where each sublist
-              corresponds to a batch and contains tokens with weighted
+            - The first element is the output token list, where each sublist
+              corresponds to a batch and contains tokens with weighted
              frequencies.
            - The second element is a list of distinct token IDs for each
              batch, ordered by their frequency in the corresponding output
@@ -155,7 +168,7 @@ def _create_weighted_output_token_list(
 @pytest.mark.parametrize("batch_size", [1, 2, 32])
 def test_sampler_min_tokens_penalty(device: str, batch_size: int):
     """
-    Tests that if the number of output tokens is less than
+    Tests that if the number of output tokens is less than
     SamplingParams.min_tokens then we will set the logits for
     the stop token ids to -inf.
     """
@@ -283,7 +296,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
 def test_sampler_repetition_penalty(device: str, batch_size: int,
                                     repetition_penalty: float):
     """
-    Test to verify that when the repetition penalty is enabled, tokens
+    Test to verify that when the repetition penalty is enabled, tokens
     are penalized based on their presence in the prompt or the existing
     output.
     """
@@ -321,3 +334,37 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
             penalized_token_id not in output_tokens)
     assert (non_penalized_token_id in prompt_tokens or \
             non_penalized_token_id in output_tokens)
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("bias_value", [-0.1, 1.2])
+def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
+    """
+    Test to verify that when a logit bias is supplied for a token, that
+    token's logit is shifted by the bias value while all other logits
+    are left unchanged.
+    """
+    torch.set_default_device(device)
+    # Create fake logits where each token is assigned the same
+    # logit value.
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+    sampling_metadata.logit_bias = _create_logit_bias(
+        batch_size=batch_size,
+        vocab_size=VOCAB_SIZE,
+        bias_value=bias_value,
+    )
+    sampler = Sampler()
+    logits = sampler.apply_logits_bias(fake_logits, sampling_metadata)
+    logits = logits.cpu()
+    for batch_idx in range(batch_size):
+        logits_for_req = logits[batch_idx]
+        biased_index = min(batch_idx, VOCAB_SIZE - 1)
+        for token_id in range(VOCAB_SIZE):
+            if biased_index == token_id:
+                assert logits_for_req[token_id] == pytest.approx(bias_value +
+                                                                 1e-2)
+            else:
+                assert logits_for_req[token_id] == pytest.approx(1e-2)
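To run just the new test locally (assuming at least one CUDA device, since it is parametrized over CUDA_DEVICES):

    pytest tests/v1/sample/test_sampler.py -k test_sampler_logit_bias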

tests/v1/worker/test_gpu_input_batch.py

Lines changed: 80 additions & 62 deletions

@@ -45,9 +45,11 @@ def _remove_requests(


 def _construct_expected_sampling_metadata(
-        reqs: List[CachedRequestState], req_ids_retained: Set[int],
-        req_id_index_in_input_batch: Dict[str, int],
-        device: torch.device) -> SamplingMetadata:
+    reqs: List[CachedRequestState],
+    req_ids_retained: Set[int],
+    req_id_index_in_input_batch: Dict[str, int],
+    device: torch.device,
+) -> SamplingMetadata:
     """
     Constructs and returns the expected SamplingMetadata for this
     batch.
@@ -63,6 +65,7 @@ def _construct_expected_sampling_metadata(
     temperature = [0.0 for _ in range(num_reqs)]
     stop_token_ids: List[Set[int]] = [set() for _ in range(num_reqs)]
     min_tokens = [0 for _ in range(num_reqs)]
+    logit_bias = [None] * num_reqs
     for req in reqs:
         if req.req_id not in req_ids_retained:
             continue
@@ -71,20 +74,21 @@ def _construct_expected_sampling_metadata(
         prompt_token_ids[index_in_input_batch] = req.prompt_token_ids
         presence_penalties[
             index_in_input_batch] = req.sampling_params.presence_penalty
-        frequency_penalties[
-            index_in_input_batch] = req.sampling_params.frequency_penalty
-        repetition_penalties[
-            index_in_input_batch] = req.sampling_params.repetition_penalty
+        frequency_penalties[index_in_input_batch] = (
+            req.sampling_params.frequency_penalty)
+        repetition_penalties[index_in_input_batch] = (
+            req.sampling_params.repetition_penalty)
         top_k[index_in_input_batch] = req.sampling_params.top_k
         top_p[index_in_input_batch] = req.sampling_params.top_p
         temperature[index_in_input_batch] = req.sampling_params.temperature
         stop_token_ids[
             index_in_input_batch] = req.sampling_params.all_stop_token_ids
         min_tokens[index_in_input_batch] = req.sampling_params.min_tokens
-
+        logit_bias[index_in_input_batch] = req.sampling_params.logit_bias

     return SamplingMetadata(
-        temperature=torch.tensor(temperature, dtype=torch.float, device=device),
+        temperature=torch.tensor(temperature, dtype=torch.float,
+                                 device=device),
         all_greedy=False,
         all_random=True,
         top_p=torch.tensor(top_p, dtype=torch.float, device=device),
@@ -93,41 +97,45 @@ def _construct_expected_sampling_metadata(
         no_top_k=all(x == 0 for x in top_k),
         generators={},
         max_num_logprobs=0,
-        prompt_token_ids= make_tensor_with_pad(
+        prompt_token_ids=make_tensor_with_pad(
             prompt_token_ids,
             pad=VOCAB_SIZE,
             device=torch.device(device),
             dtype=torch.int64,
         ),
-        frequency_penalties=torch.tensor(
-            frequency_penalties, dtype=torch.float,
-            device=device),
-        presence_penalties=torch.tensor(
-            presence_penalties, dtype=torch.float,
-            device=device),
-        repetition_penalties=torch.tensor(
-            repetition_penalties, dtype=torch.float,
-            device=device),
+        frequency_penalties=torch.tensor(frequency_penalties,
+                                         dtype=torch.float,
+                                         device=device),
+        presence_penalties=torch.tensor(presence_penalties,
+                                        dtype=torch.float,
+                                        device=device),
+        repetition_penalties=torch.tensor(repetition_penalties,
+                                          dtype=torch.float,
+                                          device=device),
         output_token_ids=output_token_ids,
         min_tokens=min_tokens,
         stop_token_ids=stop_token_ids,
-        no_penalties=(all(x ==0 for x in presence_penalties) and \
-                      all(x ==0 for x in frequency_penalties) and \
-                      all(x ==1 for x in repetition_penalties))
+        no_penalties=(all(x == 0 for x in presence_penalties)
+                      and all(x == 0 for x in frequency_penalties)
+                      and all(x == 1 for x in repetition_penalties)),
+        logit_bias=logit_bias,
     )


 def _create_sampling_params():
-    return SamplingParams(top_k=np.random.randint(1, 10),
-                          top_p=np.random.uniform(0.0, 1.0),
-                          presence_penalty=np.random.uniform(-2.0, 2.0),
-                          repetition_penalty=np.random.uniform(0.0, 2.0),
-                          frequency_penalty=np.random.uniform(-2.0, 2.0),
-                          min_tokens=np.random.randint(1, 10),
-                          stop_token_ids=[
-                              np.random.randint(0, VOCAB_SIZE)
-                              for _ in range(np.random.randint(10))
-                          ])
+    return SamplingParams(
+        top_k=np.random.randint(1, 10),
+        top_p=np.random.uniform(0.0, 1.0),
+        presence_penalty=np.random.uniform(-2.0, 2.0),
+        repetition_penalty=np.random.uniform(0.0, 2.0),
+        frequency_penalty=np.random.uniform(-2.0, 2.0),
+        min_tokens=np.random.randint(1, 10),
+        stop_token_ids=[
+            np.random.randint(0, VOCAB_SIZE)
+            for _ in range(np.random.randint(10))
+        ],
+        logit_bias={0: np.random.uniform(-3.0, 3.0)},
+    )


 def _construct_cached_request_state(req_id_suffix: int):
@@ -139,16 +147,18 @@ def _construct_cached_request_state(req_id_suffix: int):
         np.random.randint(0, VOCAB_SIZE)
         for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS))
     ]
-    return CachedRequestState(req_id=f"req_id_{req_id_suffix}",
-                              prompt_token_ids=prompt_token_ids,
-                              prompt=None,
-                              sampling_params=_create_sampling_params(),
-                              mm_inputs=[],
-                              mm_positions=[],
-                              block_ids=[],
-                              generator=None,
-                              num_computed_tokens=len(output_token_ids),
-                              output_token_ids=output_token_ids)
+    return CachedRequestState(
+        req_id=f"req_id_{req_id_suffix}",
+        prompt_token_ids=prompt_token_ids,
+        prompt=None,
+        sampling_params=_create_sampling_params(),
+        mm_inputs=[],
+        mm_positions=[],
+        block_ids=[],
+        generator=None,
+        num_computed_tokens=len(output_token_ids),
+        output_token_ids=output_token_ids,
+    )


 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -163,12 +173,14 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
     output of `make_sampling_metadata` is then compared against the expected
     results to ensure correctness.
     """
-    input_batch: InputBatch = InputBatch(max_num_reqs=batch_size,
-                                         max_model_len=1024,
-                                         max_num_blocks_per_req=10,
-                                         device=torch.device(device),
-                                         pin_memory=is_pin_memory_available(),
-                                         vocab_size=1024)
+    input_batch: InputBatch = InputBatch(
+        max_num_reqs=batch_size,
+        max_model_len=1024,
+        max_num_blocks_per_req=10,
+        device=torch.device(device),
+        pin_memory=is_pin_memory_available(),
+        vocab_size=1024,
+    )
     reqs: List[CachedRequestState] = []
     req_id_reqs = {}
     req_id_output_token_ids = {}
@@ -206,21 +218,27 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
                           sampling_metadata.top_p)
     assert torch.allclose(expected_sampling_metadata.top_k,
                           sampling_metadata.top_k)
-    assert torch.allclose(expected_sampling_metadata.frequency_penalties,
-                          sampling_metadata.frequency_penalties)
-    assert torch.allclose(expected_sampling_metadata.presence_penalties,
-                          sampling_metadata.presence_penalties)
-    assert torch.allclose(expected_sampling_metadata.repetition_penalties,
-                          sampling_metadata.repetition_penalties)
+    assert torch.allclose(
+        expected_sampling_metadata.frequency_penalties,
+        sampling_metadata.frequency_penalties,
+    )
+    assert torch.allclose(
+        expected_sampling_metadata.presence_penalties,
+        sampling_metadata.presence_penalties,
+    )
+    assert torch.allclose(
+        expected_sampling_metadata.repetition_penalties,
+        sampling_metadata.repetition_penalties,
+    )
     assert torch.allclose(expected_sampling_metadata.prompt_token_ids,
                           sampling_metadata.prompt_token_ids)
     assert (expected_sampling_metadata.output_token_ids ==
             sampling_metadata.output_token_ids)
-    assert (
-        expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens)
-    assert (expected_sampling_metadata.stop_token_ids ==
-            sampling_metadata.stop_token_ids)
-    assert (expected_sampling_metadata.no_penalties ==
-            sampling_metadata.no_penalties)
-    assert (expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p)
-    assert (expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k)
+    assert expected_sampling_metadata.min_tokens == sampling_metadata.min_tokens
+    assert expected_sampling_metadata.stop_token_ids == \
+        sampling_metadata.stop_token_ids
+    assert expected_sampling_metadata.no_penalties == \
+        sampling_metadata.no_penalties
+    assert expected_sampling_metadata.no_top_p == sampling_metadata.no_top_p
+    assert expected_sampling_metadata.no_top_k == sampling_metadata.no_top_k
+    assert expected_sampling_metadata.logit_bias == sampling_metadata.logit_bias

vllm/sampling_params.py

Lines changed: 3 additions & 1 deletion

@@ -243,8 +243,10 @@ def from_optional(
         allowed_token_ids: Optional[List[int]] = None,
     ) -> "SamplingParams":
         if logit_bias is not None:
+            # Convert token_id to integer
+            # Clamp the bias between -100 and 100 per OpenAI API spec
             logit_bias = {
-                int(token): bias
+                int(token): min(100.0, max(-100.0, bias))
                 for token, bias in logit_bias.items()
             }
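A standalone illustration of the conversion above, using made-up token ids; it applies the same expression as from_optional:

    # Keys may arrive as strings (e.g. from a JSON request body); they are
    # coerced to int and the bias is clamped to the OpenAI range [-100, 100].
    raw_bias = {"50256": -150.0, "42": 3.5}
    clamped = {
        int(token): min(100.0, max(-100.0, bias))
        for token, bias in raw_bias.items()
    }
    assert clamped == {50256: -100.0, 42: 3.5}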

vllm/v1/sample/metadata.py

Lines changed: 2 additions & 0 deletions

@@ -32,3 +32,5 @@ class SamplingMetadata:
     output_token_ids: List[List[int]]
     min_tokens: List[int]
     stop_token_ids: List[Set[int]]
+
+    logit_bias: List[Optional[Dict[int, float]]]
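For reference, the new field holds one entry per request in the batch, aligned with the other per-request lists; None means the request has no bias. An illustrative value:

    logit_bias = [
        {50256: -100.0},     # request 0: suppress one token
        None,                # request 1: no bias
        {11: 2.0, 13: 1.5},  # request 2: boost two tokens
    ]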

vllm/v1/sample/sampler.py

Lines changed: 16 additions & 0 deletions

@@ -37,6 +37,8 @@ def forward(

         # Use float32 for the logits.
         logits = logits.to(torch.float32)
+        # Apply logits bias.
+        logits = self.apply_logits_bias(logits, sampling_metadata)
         # Apply penalties (e.g., min_tokens, freq_penalties).
         logits = self.apply_penalties(logits, sampling_metadata)
         # Apply temperature.
@@ -166,3 +168,17 @@ def apply_penalties(
             sampling_metadata.repetition_penalties,
             sampling_metadata.output_token_ids)
         return logits
+
+    def apply_logits_bias(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
+        # TODO(houseroad): this implementation is extremely inefficient.
+        # One idea is to implement this as a PyTorch C++ op, and we may
+        # even optimize the logit_bias layout.
+        for i, logit_bias in enumerate(sampling_metadata.logit_bias):
+            if logit_bias:
+                for token_id, bias in logit_bias.items():
+                    logits[i, token_id] += bias
+        return logits
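On the TODO above: one possible direction, sketched here outside the commit and untested, is to flatten the per-request dicts into index/value tensors and apply them with a single indexed add rather than one GPU update per biased token:

    import torch

    def apply_logits_bias_flat(logits: torch.Tensor, logit_bias_list) -> torch.Tensor:
        # Collect (row, token_id, bias) triples from the per-request dicts.
        rows, cols, vals = [], [], []
        for i, logit_bias in enumerate(logit_bias_list):
            if logit_bias:
                for token_id, bias in logit_bias.items():
                    rows.append(i)
                    cols.append(token_id)
                    vals.append(bias)
        if rows:
            device = logits.device
            # Single advanced-indexing add; (row, token_id) pairs are unique
            # because dict keys are unique per request, so no write conflicts.
            logits[torch.tensor(rows, device=device),
                   torch.tensor(cols, device=device)] += torch.tensor(
                       vals, dtype=logits.dtype, device=device)
        return logits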
