
Commit 3f5a4b6

[Bugfix] Validate custom logits processor xargs for online serving (#27560)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
1 parent 6cae1e5 commit 3f5a4b6

18 files changed: 232 additions, 49 deletions

docs/design/logits_processors.md

Lines changed: 13 additions & 1 deletion
@@ -254,7 +254,15 @@ The previous sections alluded to the interfaces which vLLM logits processors mus
         changes to the batch makeup.
         """
         raise NotImplementedError
-
+
+    @classmethod
+    def validate_params(cls, sampling_params: SamplingParams):
+        """Validate sampling params for this logits processor.
+
+        Raise ValueError for invalid ones.
+        """
+        return None
+
 ```
 
 A vLLM logits processor must subclass `LogitsProcessor` and define (at minimum) the following methods:
@@ -279,6 +287,10 @@ A vLLM logits processor must subclass `LogitsProcessor` and define (at minimum)
     * Use the `BatchUpdate` members to update logits processor internal state
     * **Note:** the batch update data structure may be `None`, signaling no change to the batch constituents. In this case, the LogitsProcessor might still want to update its state based on the updated `output_token_ids` lists that it could have retained when they were added.
 
+* `validate_params(cls, sampling_params: SamplingParams)`:
+    * Raise `ValueError` if `SamplingParams` contains invalid arguments (especially custom arguments) used by the logits processor.
+    * When a request is sent to the entrypoint, `validate_params()` validates its `SamplingParams` and rejects the request if they are invalid.
+
 ### `BatchUpdate` data structure
 
 The `BatchUpdate` abstraction models the persistent batch as a list of requests, supporting the following operations to change batch state (note that the order in which the operations are mentioned below reflects the order in which they should be processed in `update_state()`):
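
To make the interface added above concrete, here is a minimal sketch of a custom logits processor that implements `validate_params()` alongside the other required methods. The `ForbiddenTokenLogitsProcessor` class and its `forbidden_token` custom argument are hypothetical illustrations, not part of this commit; the import paths follow the docs above.

```python
# Hypothetical illustration (not part of this commit): a processor that
# masks a per-request "forbidden_token" passed via SamplingParams.extra_args.
import torch

from vllm import SamplingParams
from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor


class ForbiddenTokenLogitsProcessor(LogitsProcessor):
    """Ban one token id per request."""

    @classmethod
    def validate_params(cls, sampling_params: SamplingParams):
        # Reject bad requests at the entrypoint instead of failing later
        # inside an engine step.
        tok = sampling_params.extra_args and sampling_params.extra_args.get(
            "forbidden_token"
        )
        if tok is not None and not isinstance(tok, int):
            raise ValueError(f"forbidden_token value {tok} is not int")

    def __init__(
        self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
    ):
        self.req_info: dict[int, int] = {}

    def is_argmax_invariant(self) -> bool:
        # Masking a token can change which token has the highest logit.
        return False

    def update_state(self, batch_update: BatchUpdate | None):
        # Simplified: a real implementation must also process
        # batch_update.removed and batch_update.moved.
        if batch_update is None:
            return
        for index, params, _, _ in batch_update.added:
            assert params is not None
            self.validate_params(params)
            if params.extra_args and (
                tok := params.extra_args.get("forbidden_token")
            ):
                self.req_info[index] = tok

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        for row, tok in self.req_info.items():
            logits[row, tok] = float("-inf")
        return logits
```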

docs/features/custom_arguments.md

Lines changed: 3 additions & 0 deletions
@@ -4,6 +4,9 @@ You can use vLLM *custom arguments* to pass in arguments which are not part of t
 
 Custom arguments can be useful if, for example, you want to use a [custom logits processor](./custom_logitsprocs.md) without modifying the vLLM source code.
 
+!!! note
+    Make sure your custom logits processor implements `validate_params` for its custom arguments. Otherwise, invalid custom arguments can cause unexpected behaviour.
+
 ## Offline Custom Arguments
 
 Custom arguments passed to `SamplingParams.extra_args` as a `dict` will be visible to any code which has access to `SamplingParams`:
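
As a quick illustration of that mechanism (a sketch, not part of this commit), a custom argument round-trips through `SamplingParams.extra_args` like this; `target_token` assumes the `DummyLogitsProcessor` example from the custom logits processor docs:

```python
from vllm import SamplingParams

# Custom arguments ride along in extra_args as a plain dict; any code that
# can see the SamplingParams (e.g. a logits processor) can read them back.
params = SamplingParams(temperature=0.0, extra_args={"target_token": 128})
assert params.extra_args is not None
print(params.extra_args.get("target_token"))  # -> 128
```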

docs/features/custom_logitsprocs.md

Lines changed: 33 additions & 9 deletions
@@ -18,6 +18,11 @@ In vLLM, logits processors operate at batch granularity. During a given engine s
 
 Custom logits processors must subclass `vllm.v1.sample.logits_processor.LogitsProcessor` and define (at minimum) the following methods:
 
+* `validate_params(cls, sampling_params: SamplingParams)`:
+    * Raise `ValueError` if `SamplingParams` contains invalid arguments (especially custom arguments) used by the logits processor.
+    * When a request is sent to the entrypoint, `validate_params()` validates its `SamplingParams` and rejects the request if they are invalid.
+    * **Note:** it is important to implement `validate_params()` so that invalid parameters cannot reach the custom logits processor, where they could otherwise cause unexpected behaviour.
+
 * `__init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool)`
     * `vllm_config`: engine configuration data structure
     * `device`: hardware accelerator device info
@@ -103,6 +108,14 @@ The contrived example below implements a custom logits processor which consumes
     class DummyLogitsProcessor(LogitsProcessor):
         """Fake logit processor to support unit testing and examples"""
 
+        @classmethod
+        def validate_params(cls, params: SamplingParams):
+            target_token: int | None = params.extra_args and params.extra_args.get(
+                "target_token"
+            )
+            if target_token is not None and not isinstance(target_token, int):
+                raise ValueError(f"target_token value {target_token} is not int")
+
         def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                      is_pin_memory: bool):
             self.req_info: dict[int, int] = {}
@@ -118,6 +131,7 @@ The contrived example below implements a custom logits processor which consumes
             # Process added requests.
             for index, params, _, _ in batch_update.added:
                 assert params is not None
+                self.validate_params(params)
                 if params.extra_args and (target_token :=
                                           params.extra_args.get("target_token")):
                     self.req_info[index] = target_token
@@ -157,6 +171,7 @@ The contrived example below implements a custom logits processor which consumes
             logits[rows, cols] = values_to_keep
 
             return logits
+
     ```
 
 In the rest of this document, we will use `DummyLogitsProcessor` as an example of a custom logits processor.
@@ -180,7 +195,13 @@ RequestLogitsProcessor = Union[
 
 While request-level logits processors are explicitly *not* supported in the vLLM engine, vLLM *does* provide a convenient process to wrap an existing `Callable` request-level logits processor and create a batch-level logits processor that is compatible with vLLM. The `Callable` must conform to the type annotation above; if your request-level logits processor has a different interface, then in order to wrap it, you may need to modify it or implement an additional wrapper layer to comply with the interface specification above.
 
-You can wrap the request-level logits processor by subclassing `AdapterLogitsProcessor` as shown in the example below (in this example, `DummyPerReqLogitsProcessor` is a stand-in for your request-level logits processor which needs to be wrapped.) Override `AdapterLogitsProcessor.is_argmax_invariant(self)` to accurately reflect whether your request-level logits processor may impact which token has the highest-value logit. Override `AdapterLogitsProcessor.new_req_logits_processor(self,params)` to create a new request-level logits processor instance from a `SamplingParams` instance:
+You can wrap the request-level logits processor by subclassing `AdapterLogitsProcessor` as shown in the example below (in this example, `DummyPerReqLogitsProcessor` is a stand-in for your request-level logits processor which needs to be wrapped):
+
+* Override `AdapterLogitsProcessor.validate_params(cls, params)` to validate each request's sampling parameters.
+
+* Override `AdapterLogitsProcessor.is_argmax_invariant(self)` to accurately reflect whether your request-level logits processor may impact which token has the highest-value logit.
+
+* Override `AdapterLogitsProcessor.new_req_logits_processor(self, params)` to create a new request-level logits processor instance from a `SamplingParams` instance:
 
 ??? code "Example of Wrapping a Request-Level Logits Processor"
 
@@ -220,6 +241,16 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro
         """Example of wrapping a fake request-level logit processor to create a
         batch-level logits processor"""
 
+        @classmethod
+        def validate_params(cls, params: SamplingParams):
+            target_token: Any | None = params.extra_args and params.extra_args.get(
+                "target_token"
+            )
+            if target_token is not None and not isinstance(target_token, int):
+                raise ValueError(
+                    f"target_token value {target_token} is not int"
+                )
+
         def is_argmax_invariant(self) -> bool:
             return False
 
@@ -240,18 +271,11 @@ You can wrap the request-level logits processor by subclassing `AdapterLogitsPro
             Returns:
                 `Callable` request logits processor, or None
             """
-            target_token: Optional[Any] = params.extra_args and params.extra_args.get(
+            target_token: Any | None = params.extra_args and params.extra_args.get(
                 "target_token"
             )
             if target_token is None:
                 return None
-            if not isinstance(target_token, int):
-                logger.warning(
-                    "target_token value %s is not int; not applying logits"
-                    " processor to request.",
-                    target_token,
-                )
-                return None
             return DummyPerReqLogitsProcessor(target_token)
     ```
 
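For online serving, the validation added in this commit surfaces through the OpenAI-compatible API. The sketch below (the model name and server address are assumptions) shows both sides of the behaviour: custom arguments travel in `extra_body["vllm_xargs"]`, and a request whose `target_token` fails `validate_params()` is rejected at the entrypoint instead of reaching the engine:

```python
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Valid request: target_token is an int, so validate_params() accepts it.
client.completions.create(
    model="meta-llama/Llama-3.2-1B",
    prompt="Hello, my name is",
    extra_body={"vllm_xargs": {"target_token": 128}},
)

# Invalid request: a string target_token raises ValueError server-side,
# which comes back to the client as an API error.
try:
    client.completions.create(
        model="meta-llama/Llama-3.2-1B",
        prompt="Hello, my name is",
        extra_body={"vllm_xargs": {"target_token": "not-an-int"}},
    )
except openai.OpenAIError as exc:
    print(exc)  # message includes "is not int"
```
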
examples/offline_inference/logits_processor/custom.py

Lines changed: 17 additions & 2 deletions
@@ -33,6 +33,8 @@ class object.
 ------------------------------------------------------------
 """
 
+from typing import Any
+
 import torch
 
 from vllm import LLM, SamplingParams
@@ -48,6 +50,16 @@ class object.
 class DummyLogitsProcessor(LogitsProcessor):
     """Fake logit processor to support unit testing and examples"""
 
+    @classmethod
+    def validate_params(cls, params: SamplingParams):
+        target_token: Any | None = params.extra_args and params.extra_args.get(
+            "target_token"
+        )
+        if target_token is not None and not isinstance(target_token, int):
+            raise ValueError(
+                f"target_token value {target_token} {type(target_token)} is not int"
+            )
+
     def __init__(
         self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
     ):
@@ -57,14 +69,17 @@ def is_argmax_invariant(self) -> bool:
         return False
 
     def update_state(self, batch_update: BatchUpdate | None):
+        def extract_extra_arg(params: SamplingParams) -> int | None:
+            self.validate_params(params)
+            return params.extra_args and params.extra_args.get("target_token")
+
         process_dict_updates(
             self.req_info,
             batch_update,
             # This function returns the LP's per-request state based on the
             # request details, or None if this LP does not apply to the
             # request.
-            lambda params, _, __: params.extra_args
-            and (params.extra_args.get("target_token")),
+            lambda params, _, __: extract_extra_arg(params),
         )
 
     def apply(self, logits: torch.Tensor) -> torch.Tensor:
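
For context, a sketch of how this example class is typically driven offline (the model name is an assumption; the example file itself constructs the `LLM` with `logits_processors=[DummyLogitsProcessor]`):

```python
# Sketch: wire DummyLogitsProcessor into offline inference and trigger the
# new validation. Assumes the class definition above is importable.
llm = LLM(model="facebook/opt-125m", logits_processors=[DummyLogitsProcessor])

# A valid int target_token passes validate_params() and steers decoding.
ok = SamplingParams(temperature=0.0, extra_args={"target_token": 128})
print(llm.generate(["Hello, my name is"], ok)[0].outputs[0].text)

# A non-int target_token now fails fast with ValueError instead of being
# silently dropped by the old logger.warning() path.
bad = SamplingParams(temperature=0.0, extra_args={"target_token": "oops"})
try:
    llm.generate(["Hello, my name is"], bad)
except ValueError as exc:
    print(exc)  # "target_token value oops <class 'str'> is not int"
```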

examples/offline_inference/logits_processor/custom_req.py

Lines changed: 8 additions & 7 deletions
@@ -76,6 +76,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
     """Example of wrapping a fake request-level logit processor to create a
     batch-level logits processor"""
 
+    @classmethod
+    def validate_params(cls, params: SamplingParams):
+        target_token: Any | None = params.extra_args and params.extra_args.get(
+            "target_token"
+        )
+        if target_token is not None and not isinstance(target_token, int):
+            raise ValueError(f"target_token value {target_token} is not int")
+
     def is_argmax_invariant(self) -> bool:
         return False
 
@@ -101,13 +109,6 @@ def new_req_logits_processor(
         )
         if target_token is None:
             return None
-        if not isinstance(target_token, int):
-            logger.warning(
-                "target_token value %s is not int; not applying logits"
-                " processor to request.",
-                target_token,
-            )
-            return None
         return DummyPerReqLogitsProcessor(target_token)
 
 
examples/offline_inference/logits_processor/custom_req_init.py

Lines changed: 8 additions & 7 deletions
@@ -77,6 +77,14 @@ class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
     """Example of overriding the wrapper class `__init__()` in order to utilize
     info about the device type"""
 
+    @classmethod
+    def validate_params(cls, params: SamplingParams):
+        target_token = params.extra_args and params.extra_args.get("target_token")
+        if target_token is not None and not isinstance(target_token, int):
+            raise ValueError(
+                f"`target_token` has to be an integer, got {target_token}."
+            )
+
     def __init__(
         self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
     ):
@@ -113,13 +121,6 @@ def new_req_logits_processor(
             is None
         ):
             return None
-        if not isinstance(target_token, int):
-            logger.warning(
-                "target_token value %s is not int; not applying logits"
-                " processor to request.",
-                target_token,
-            )
-            return None
         return DummyPerReqLogitsProcessor(target_token)
 
 

tests/entrypoints/openai/test_lora_resolvers.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ class MockModelConfig:
     tokenizer_revision: str | None = None
     multimodal_config: MultiModalConfig = field(default_factory=MultiModalConfig)
     hf_config: MockHFConfig = field(default_factory=MockHFConfig)
+    logits_processors: list[str] | None = None
     logits_processor_pattern: str | None = None
     diff_sampling_param: dict | None = None
     allowed_local_media_path: str = ""

tests/entrypoints/openai/test_serving_chat.py

Lines changed: 1 addition & 0 deletions
@@ -353,6 +353,7 @@ class MockModelConfig:
     tokenizer_revision = None
     multimodal_config = MultiModalConfig()
     hf_config = MockHFConfig()
+    logits_processors: list[str] | None = None
     logits_processor_pattern = None
     diff_sampling_param: dict | None = None
     allowed_local_media_path: str = ""

tests/v1/logits_processors/test_custom_online.py

Lines changed: 29 additions & 0 deletions
@@ -177,3 +177,32 @@ async def test_custom_logitsprocs(client: openai.AsyncOpenAI, model_name: str):
 
         # Alternate whether to activate dummy logitproc for each request
         use_dummy_logitproc = not use_dummy_logitproc
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME],
+)
+async def test_invalid_custom_logitsproc_arg(
+    client: openai.AsyncOpenAI, model_name: str
+):
+    """Test that a request with an invalid custom logitsproc argument is rejected"""
+
+    prompt = "Hello, my name is"
+    # Pass invalid (non-int) target_token value to dummy logits processor
+    request_keyword_args: dict[str, Any] = {
+        **api_keyword_args,
+        "extra_body": {
+            "vllm_xargs": {DUMMY_LOGITPROC_ARG: "invalid_target_token_value"}
+        },
+    }
+
+    with pytest.raises(openai.OpenAIError) as exc_info:
+        await client.completions.create(
+            model=model_name,
+            prompt=prompt,
+            **request_keyword_args,
+        )
+
+    assert "is not int" in str(exc_info.value)
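
A happy-path companion (a sketch, not part of this commit) would exercise the same plumbing with a valid integer value, reusing this file's `MODEL_NAME`, `api_keyword_args`, and `DUMMY_LOGITPROC_ARG` fixtures:

```python
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_valid_custom_logitsproc_arg(
    client: openai.AsyncOpenAI, model_name: str
):
    """Sketch: a valid int target_token passes validate_params()."""
    request_keyword_args: dict[str, Any] = {
        **api_keyword_args,
        "extra_body": {"vllm_xargs": {DUMMY_LOGITPROC_ARG: 128}},
    }
    completion = await client.completions.create(
        model=model_name,
        prompt="Hello, my name is",
        **request_keyword_args,
    )
    assert completion.choices[0].text is not None
```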

tests/v1/logits_processors/utils.py

Lines changed: 15 additions & 2 deletions
@@ -52,6 +52,16 @@ class CustomLogitprocSource(Enum):
 class DummyLogitsProcessor(LogitsProcessor):
     """Fake logit processor to support unit testing and examples"""
 
+    @classmethod
+    def validate_params(cls, params: SamplingParams):
+        target_token: int | None = params.extra_args and params.extra_args.get(
+            "target_token"
+        )
+        if target_token is not None and not isinstance(target_token, int):
+            raise ValueError(
+                f"target_token value {target_token} {type(target_token)} is not int"
+            )
+
     def __init__(
         self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
     ):
@@ -62,11 +72,14 @@ def is_argmax_invariant(self) -> bool:
         return False
 
     def update_state(self, batch_update: BatchUpdate | None):
+        def extract_extra_arg(params: SamplingParams) -> int | None:
+            self.validate_params(params)
+            return params.extra_args and params.extra_args.get("target_token")
+
         process_dict_updates(
             self.req_info,
             batch_update,
-            lambda params, _, __: params.extra_args
-            and (params.extra_args.get("target_token")),
+            lambda params, _, __: extract_extra_arg(params),
        )
 
     def apply(self, logits: torch.Tensor) -> torch.Tensor:
