
Commit eb8b5eb

[V1] Support bad_words in sampler (#13376)

Authored by 22quinn and njhill
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
1 parent 9513290 commit eb8b5eb

File tree

13 files changed: +266 additions, -28 deletions

tests/test_utils.py

Lines changed: 24 additions & 1 deletion
@@ -14,7 +14,7 @@
 from vllm.utils import (FlexibleArgumentParser, MemorySnapshot,
                         PlaceholderModule, StoreBoolean, bind_kv_cache,
                         deprecate_kwargs, get_open_port, memory_profiling,
-                        merge_async_iterators, supports_kw)
+                        merge_async_iterators, supports_kw, swap_dict_values)

 from .utils import error_on_warning, fork_new_process_for_each_test

@@ -449,3 +449,26 @@ def build_ctx():
     with build_ctx():
         # Test conflict with internal __module attribute
         _ = placeholder_attr.module
+
+
+@pytest.mark.parametrize(
+    "obj,key1,key2",
+    [
+        # Tests for both keys exist
+        ({1: "a", 2: "b"}, 1, 2),
+        # Tests for one key does not exist
+        ({1: "a", 2: "b"}, 1, 3),
+        # Tests for both keys do not exist
+        ({1: "a", 2: "b"}, 3, 4),
+    ])
+def test_swap_dict_values(obj, key1, key2):
+    original_obj = obj.copy()
+    swap_dict_values(obj, key1, key2)
+    if key1 in original_obj:
+        assert obj[key2] == original_obj[key1]
+    else:
+        assert key2 not in obj
+    if key2 in original_obj:
+        assert obj[key1] == original_obj[key2]
+    else:
+        assert key1 not in obj

tests/v1/sample/test_rejection_sampler.py

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ def create_sampling_metadata(spec_tokens: list[list[int]]) -> SamplingMetadata:
         min_tokens={},
         logit_bias=[None] * batch_size,
         allowed_token_ids_mask=None,
+        bad_words_token_ids={},
     )

tests/v1/sample/test_sampler.py

Lines changed: 76 additions & 0 deletions
@@ -77,6 +77,49 @@ def _create_allowed_token_ids(
     return mask


+def _create_bad_words_token_ids(
+        batch_size: int, vocab_size: int,
+        bad_words_lengths: list[tuple[int]]) -> dict[int, list[list[int]]]:
+    bad_words_token_ids = {}
+    for batch_idx in range(batch_size):
+        token_ids_single_batch = []
+        for bad_words_length in bad_words_lengths:
+            token_ids = np.random.choice(vocab_size,
+                                         size=bad_words_length,
+                                         replace=True).tolist()
+            token_ids_single_batch.append(token_ids)
+        bad_words_token_ids[batch_idx] = token_ids_single_batch
+    if batch_size >= 2:
+        # Test no bad_words for some batch
+        no_bad_words_batch_idx = np.random.choice(batch_size)
+        bad_words_token_ids.pop(no_bad_words_batch_idx, None)
+    return bad_words_token_ids
+
+
+def _update_output_token_ids_for_bad_words(
+        metadata: SamplingMetadata, vocab_size: int) -> dict[int, list[int]]:
+    bad_words_last_tokens = {}
+    for batch_idx, bad_words_token_ids in metadata.bad_words_token_ids.items():
+        output_token_ids = metadata.output_token_ids[batch_idx]
+        bad_words_last_token: list[int] = []
+        for i, bad_word_token_ids in enumerate(bad_words_token_ids):
+            if len(bad_word_token_ids) == 1:
+                # Single token id always affects logits
+                bad_words_last_token.append(bad_word_token_ids[0])
+            else:
+                prefix_length = len(bad_word_token_ids) - 1
+                has_bad_words = np.random.choice([True, False])
+                if has_bad_words:
+                    output_token_ids[-prefix_length:] = bad_word_token_ids[:-1]
+                    bad_words_last_token.append(bad_word_token_ids[-1])
+                    break  # Maximum one update to output_token_ids
+                else:  # Make sure no accidental match to bad words
+                    output_token_ids[-1] = (bad_word_token_ids[-2] +
+                                            1) % vocab_size
+        bad_words_last_tokens[batch_idx] = bad_words_last_token
+    return bad_words_last_tokens
+
+
 def _create_default_sampling_metadata(
     num_output_tokens: int,
     batch_size: int,
@@ -112,6 +155,7 @@ def _create_default_sampling_metadata(
         min_tokens={},
         logit_bias=[None] * batch_size,
         allowed_token_ids_mask=None,
+        bad_words_token_ids={},
     )
     return fake_sampling_metadata

@@ -467,3 +511,35 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int,
                     "inf"), f"{batch_idx}, {token_id}"
             else:
                 assert logits_for_req[token_id] != -float("inf")
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)])
+def test_sampler_bad_words(device: str, batch_size: int,
+                           bad_words_lengths: list[tuple[int]]):
+    """
+    Test to verify that when the bad words restriction is present, tokens
+    are penalized based on their match with the bad words.
+    """
+    torch.set_default_device(device)
+    # Create fake logits where each token is assigned the same
+    # logit value.
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+    sampling_metadata.bad_words_token_ids = _create_bad_words_token_ids(
+        batch_size, VOCAB_SIZE, bad_words_lengths)
+    bad_words_last_tokens = _update_output_token_ids_for_bad_words(
+        sampling_metadata, VOCAB_SIZE)
+    sampler = Sampler()
+    logits = sampler.apply_bad_words(fake_logits, sampling_metadata)
+    logits = logits.cpu()
+    for batch_idx in range(batch_size):
+        logits_for_req = logits[batch_idx]
+        for token_id in range(VOCAB_SIZE):
+            if (batch_idx in bad_words_last_tokens
+                    and token_id in bad_words_last_tokens[batch_idx]):
+                assert logits_for_req[token_id] == -float("inf")
+            else:
+                assert logits_for_req[token_id] != -float("inf")
tests/v1/sample/test_sampling_params_e2e.py

Lines changed: 16 additions & 2 deletions
@@ -120,8 +120,22 @@ def test_detokenize_false(model):
 def test_bad_words(model):
     """Check that we respect bad words."""

-    with pytest.raises(ValueError):
-        _ = model.generate(PROMPT, SamplingParams(bad_words=["Hello"]))
+    output = model.generate(PROMPT, SamplingParams(temperature=0))
+    split_text = output[0].outputs[0].text.split()
+
+    bad_words_1 = " ".join(split_text[:2])
+    params = SamplingParams(temperature=0, bad_words=[bad_words_1])
+    output = model.generate(PROMPT, params)
+    new_text = output[0].outputs[0].text
+    assert bad_words_1 not in new_text
+
+    bad_words_2 = new_text.split()[-1]
+    params = SamplingParams(temperature=0,
+                            bad_words=[bad_words_1, bad_words_2])
+    output = model.generate(PROMPT, params)
+    new_text = output[0].outputs[0].text
+    assert bad_words_1 not in new_text
+    assert bad_words_2 not in new_text


 def test_logits_processor(model):
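For reference, end-user usage mirrors the test above. This is a minimal sketch, assuming a locally available model; the model name and prompt are illustrative and not part of this commit:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # illustrative model choice
params = SamplingParams(temperature=0, bad_words=["Hello", "world peace"])
outputs = llm.generate(["The president said"], params)
# The generated text should not contain the token sequences that
# "Hello" / " Hello" or "world peace" / " world peace" encode to.
print(outputs[0].outputs[0].text)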

tests/v1/worker/test_gpu_input_batch.py

Lines changed: 6 additions & 0 deletions
@@ -100,6 +100,7 @@ def _construct_expected_sampling_metadata(
                                          VOCAB_SIZE,
                                          dtype=torch.bool,
                                          device=device)
+    bad_words_token_ids = {}
     for req in reqs:
         if req.req_id not in req_ids_retained:
             continue
@@ -123,6 +124,8 @@ def _construct_expected_sampling_metadata(
         if req.sampling_params.allowed_token_ids:
             allowed_token_ids_mask[index_in_input_batch][
                 req.sampling_params.allowed_token_ids] = True
+        bad_words_token_ids[
+            index_in_input_batch] = req.sampling_params.bad_words_token_ids

     return SamplingMetadata(
         temperature=torch.tensor(temperature, dtype=torch.float,
@@ -159,6 +162,7 @@ def _construct_expected_sampling_metadata(
                       and all(x == 1 for x in repetition_penalties)),
        logit_bias=logit_bias,
        allowed_token_ids_mask=allowed_token_ids_mask,
+       bad_words_token_ids=bad_words_token_ids,
    )


@@ -284,6 +288,8 @@ def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
         assert torch.allclose(
             expected_sampling_metadata.allowed_token_ids_mask,
             sampling_metadata.allowed_token_ids_mask)
+    assert expected_sampling_metadata.bad_words_token_ids == \
+        sampling_metadata.bad_words_token_ids


 @pytest.mark.parametrize("device", CUDA_DEVICES)

vllm/sampling_params.py

Lines changed: 51 additions & 1 deletion
@@ -11,6 +11,8 @@

 from vllm.logger import init_logger
 from vllm.logits_process import LogitsProcessor
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer

 logger = init_logger(__name__)

@@ -202,7 +204,6 @@ class SamplingParams(
     seed: Optional[int] = None
     stop: Optional[Union[str, list[str]]] = None
     stop_token_ids: Optional[list[int]] = None
-    bad_words: Optional[list[str]] = None
     ignore_eos: bool = False
     max_tokens: Optional[int] = 16
     min_tokens: int = 0
@@ -232,6 +233,10 @@ class SamplingParams(
     allowed_token_ids: Optional[list[int]] = None
     extra_args: Optional[dict[str, Any]] = None

+    # Fields used for bad words
+    bad_words: Optional[list[str]] = None
+    _bad_words_token_ids: list[list[int]] = msgspec.field(default_factory=list)
+
     @staticmethod
     def from_optional(
         n: Optional[int] = 1,
@@ -464,6 +469,46 @@ def update_from_generation_config(
                 eos_ids.update(self.stop_token_ids)
                 self.stop_token_ids = list(eos_ids)

+    def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
+        if self.bad_words is None:
+            return
+        for bad_word in self.bad_words:
+            # To prohibit words both at the beginning
+            # and in the middle of text
+            # (related to add_prefix_space tokenizer parameter)
+            for add_prefix_space in [False, True]:
+                prefix = " " if add_prefix_space else ""
+                prompt = prefix + bad_word.lstrip()
+
+                if isinstance(tokenizer, MistralTokenizer):
+                    # Mistral tokenizers should not add special tokens
+                    prompt_token_ids = tokenizer.encode(text=prompt)
+                else:
+                    prompt_token_ids = tokenizer.encode(
+                        text=prompt, add_special_tokens=False)
+
+                # If no space at the beginning
+                # or if prefix space produces a new word token
+                if (not add_prefix_space) or (
+                        add_prefix_space and prompt_token_ids[0]
+                        != self._bad_words_token_ids[-1][0]
+                        and len(prompt_token_ids) == len(
+                            self._bad_words_token_ids[-1])):
+                    self._bad_words_token_ids.append(prompt_token_ids)
+
+        invalid_token_ids = [
+            token_id for bad_words_token_ids in self._bad_words_token_ids
+            for token_id in bad_words_token_ids
+            if token_id < 0 or token_id > tokenizer.max_token_id
+        ]
+        if len(invalid_token_ids) > 0:
+            raise ValueError(
+                f"The model vocabulary size is {tokenizer.max_token_id+1},"
+                f" but the following tokens"
+                f" were specified as bad: {invalid_token_ids}."
+                f" All token id values should be integers satisfying:"
+                f" 0 <= token_id <= {tokenizer.max_token_id}.")
+
     @cached_property
     def sampling_type(self) -> SamplingType:
         if self.temperature < _SAMPLING_EPS:
@@ -476,6 +521,11 @@ def sampling_type(self) -> SamplingType:
     def all_stop_token_ids(self) -> set[int]:
         return self._all_stop_token_ids

+    @property
+    def bad_words_token_ids(self) -> list[list[int]]:
+        # For internal use only. Backward compatibility not guaranteed
+        return self._bad_words_token_ids
+
     def clone(self) -> "SamplingParams":
         """Deep copy, but maybe not the LogitsProcessor objects.

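A small sketch of what update_from_tokenizer ends up storing, assuming vLLM's get_tokenizer helper (which attaches max_token_id to the tokenizer) and a GPT-2-style BPE vocabulary; the tokenizer choice and printed ids are illustrative, not asserted by this commit:

from vllm import SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer

tokenizer = get_tokenizer("gpt2")  # illustrative tokenizer choice
params = SamplingParams(bad_words=["Hello"])
params.update_from_tokenizer(tokenizer)
# Two sequences are typically collected: the word-initial encoding of
# "Hello" and the encoding of " Hello", since BPE usually assigns them
# different token ids.
print(params.bad_words_token_ids)  # e.g. [[15496], [18435]]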
vllm/utils.py

Lines changed: 16 additions & 0 deletions
@@ -2361,3 +2361,19 @@ def __dir__(self) -> list[str]:
         if self._module is None:
             self._module = self._load()
         return dir(self._module)
+
+
+def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
+    """
+    Helper function to swap values for two keys
+    """
+    v1 = obj.get(key1)
+    v2 = obj.get(key2)
+    if v1 is not None:
+        obj[key2] = v1
+    else:
+        obj.pop(key2, None)
+    if v2 is not None:
+        obj[key1] = v2
+    else:
+        obj.pop(key1, None)
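A quick usage sketch of the new helper (the dictionary contents are illustrative):

from vllm.utils import swap_dict_values

d = {1: "a", 2: "b"}
swap_dict_values(d, 1, 3)   # key 3 is absent
# "a" moves under key 3; key 1 is removed rather than left as None.
assert d == {2: "b", 3: "a"}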

vllm/v1/engine/processor.py

Lines changed: 2 additions & 3 deletions
@@ -94,9 +94,6 @@ def _validate_supported_sampling_params(
         # Best of not yet supported.
         if params.best_of is not None and params.best_of > 1:
             raise ValueError("VLLM V1 does not yet support best_of.")
-        # Bad words not yet supported.
-        if params.bad_words:
-            raise ValueError("VLLM V1 does not yet support bad_words.")
         # Logits processors not supported.
         if params.logits_processors:
             raise ValueError("VLLM V1 does not support per request "
@@ -203,6 +200,8 @@ def process_inputs(
         sampling_params = params.clone()
         sampling_params.update_from_generation_config(
             self.generation_config_fields, eos_token_id)
+        sampling_params.update_from_tokenizer(
+            self.tokenizer.get_lora_tokenizer(lora_request))

         # Multimodal related.
         # Compute MM hashes (if enabled)

vllm/v1/sample/metadata.py

Lines changed: 3 additions & 0 deletions
@@ -38,3 +38,6 @@ class SamplingMetadata:
     # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
     # vocab size).
     allowed_token_ids_mask: Optional[torch.Tensor]
+
+    # req_index -> bad_words_token_ids
+    bad_words_token_ids: dict[int, list[list[int]]]
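As an illustration of the new field's shape (request indices and token ids invented for the example): a batch where request 0 bans the single token 7 and the two-token sequence [3, 9], while request 2 bans nothing, would carry:

bad_words_token_ids = {
    0: [[7], [3, 9]],
    # request 2 simply has no entry, rather than an empty list
}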

vllm/v1/sample/ops/bad_words.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+_SMALLEST_LOGIT = float("-inf")
+
+
+def _apply_bad_words_single_batch(
+    logits: torch.Tensor,
+    bad_words_token_ids: list[list[int]],
+    past_tokens_ids: list[int],
+) -> None:
+    for bad_word_ids in bad_words_token_ids:
+        if len(bad_word_ids) > len(past_tokens_ids) + 1:
+            continue
+
+        prefix_length = len(bad_word_ids) - 1
+        last_token_id = bad_word_ids[-1]
+        if prefix_length > 0:
+            actual_prefix = past_tokens_ids[-prefix_length:]
+        else:
+            actual_prefix = []
+        expected_prefix = bad_word_ids[:prefix_length]
+
+        assert len(actual_prefix) == len(expected_prefix)
+
+        if actual_prefix == expected_prefix:
+            logits[last_token_id] = _SMALLEST_LOGIT
+
+
+def apply_bad_words(
+    logits: torch.Tensor,
+    bad_words_token_ids: dict[int, list[list[int]]],
+    past_tokens_ids: list[list[int]],
+) -> None:
+    for i, bad_words_ids in bad_words_token_ids.items():
+        _apply_bad_words_single_batch(logits[i], bad_words_ids,
+                                      past_tokens_ids[i])
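A self-contained sketch of how the new op behaves on toy inputs (all numbers are illustrative):

import torch

from vllm.v1.sample.ops.bad_words import apply_bad_words

logits = torch.zeros(2, 10)            # 2 requests, vocab size 10
bad_words_token_ids = {
    0: [[4], [2, 7]],                  # request 0: ban 4, and 7 after a 2
    1: [[5, 6]],                       # request 1: ban 6 only right after 5
}
past_tokens_ids = [[1, 2], [3, 1]]     # tokens generated so far per request

apply_bad_words(logits, bad_words_token_ids, past_tokens_ids)

assert logits[0, 4] == float("-inf")   # single-token ban always applies
assert logits[0, 7] == float("-inf")   # prefix [2] matches the last token
assert logits[1, 6] == 0.0             # prefix [5] does not match, no mask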
