Commit 7ae9887

[V1] Logits processor docs (#22919)
Signed-off-by: Andrew Feldman <afeldman@redhat.com>
Signed-off-by: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com>
Co-authored-by: Joseph Marinier <Joseph.Marinier@gmail.com>
1 parent e3db5eb commit 7ae9887

File tree

7 files changed: +1065 −16 lines changed


docs/design/logits_processors.md

Lines changed: 559 additions & 0 deletions
Large diffs are not rendered by default.

docs/features/custom_arguments.md

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@ (new file)

# Custom Arguments

You can use vLLM *custom arguments* to pass in arguments which are not part of the vLLM `SamplingParams` and REST API specifications. Adding or removing a vLLM custom argument does not require recompiling vLLM, since the custom arguments are passed in as a dictionary.

Custom arguments can be useful if, for example, you want to use a [custom logits processor](./custom_logitsprocs.md) without modifying the vLLM source code.

## Offline Custom Arguments

Custom arguments passed to `SamplingParams.extra_args` as a `dict` will be visible to any code which has access to `SamplingParams`:

```python
SamplingParams(extra_args={"your_custom_arg_name": 67})
```

This allows arguments which are not already part of `SamplingParams` to be passed into `LLM` as part of a request.

## Online Custom Arguments

The vLLM REST API allows custom arguments to be passed to the vLLM server via `vllm_xargs`. The example below integrates custom arguments into a vLLM REST API request:

```bash
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
  "model": "Qwen/Qwen2.5-1.5B-Instruct",
  ...
  "vllm_xargs": {"your_custom_arg": 67}
}'
```

Furthermore, OpenAI SDK users can access `vllm_xargs` via the `extra_body` argument:

```python
batch = await client.completions.create(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    ...,
    extra_body={
        "vllm_xargs": {
            "your_custom_arg": 67
        }
    }
)
```

!!! note
    `vllm_xargs` is assigned to `SamplingParams.extra_args` under the hood, so code which uses `SamplingParams.extra_args` is compatible with both offline and online scenarios.
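
Because both paths land in `SamplingParams.extra_args`, downstream code can read custom arguments the same way in either scenario. The sketch below illustrates the lookup pattern with a stand-in dataclass (not the real `vllm.SamplingParams`); `read_custom_arg` and the argument name are hypothetical placeholders:

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class SamplingParams:  # stand-in for vllm.SamplingParams, extra_args only
    extra_args: Optional[dict[str, Any]] = None


def read_custom_arg(params: SamplingParams, name: str, default: Any = None) -> Any:
    """Fetch a custom argument, tolerating requests that pass no extra_args."""
    return (params.extra_args or {}).get(name, default)


# Offline-style request with a custom argument:
offline = SamplingParams(extra_args={"your_custom_arg": 67})
# Online-style request where vllm_xargs was omitted entirely:
online = SamplingParams()

assert read_custom_arg(offline, "your_custom_arg", 0) == 67
assert read_custom_arg(online, "your_custom_arg", 0) == 0
```

Guarding with `params.extra_args or {}` matters because `extra_args` defaults to `None` when a request passes no custom arguments.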

docs/features/custom_logitsprocs.md

Lines changed: 445 additions & 0 deletions
Large diffs are not rendered by default.

examples/offline_inference/logits_processor/custom.py

Lines changed: 4 additions & 6 deletions
@@ -56,7 +56,6 @@ def __init__(
         self.req_info: dict[int, int] = {}
 
     def is_argmax_invariant(self) -> bool:
-        """Never impacts greedy sampling"""
         return False
 
     def update_state(self, batch_update: Optional[BatchUpdate]):
@@ -75,13 +74,12 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
         return logits
 
         # Save target values before modification
-        rows_list = list(self.req_info.keys())
         cols = torch.tensor(
-            [self.req_info[i] for i in rows_list],
-            dtype=torch.long,
-            device=logits.device,
+            list(self.req_info.values()), dtype=torch.long, device=logits.device
+        )
+        rows = torch.tensor(
+            list(self.req_info.keys()), dtype=torch.long, device=logits.device
         )
-        rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device)
         values_to_keep = logits[rows, cols].clone()
 
         # Mask all but target tokens
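
The refactor above drops the intermediate `rows_list` and builds the row and column index tensors directly from the dict. This is safe because Python dicts preserve insertion order, so `keys()` and `values()` iterate in lockstep. A pure-Python sketch of the equivalence (the `req_info` values are made up; `torch.tensor` would simply wrap these lists):

```python
# Hypothetical batch state: batch index -> target token id.
req_info = {3: 101, 0: 205, 7: 42}

# Old approach: materialize the keys, then an indexed lookup per row.
rows_list = list(req_info.keys())
cols_old = [req_info[i] for i in rows_list]

# New approach: two aligned views of the same dict, no temporary needed.
rows_new = list(req_info.keys())
cols_new = list(req_info.values())

# Same pairing either way, since dict iteration order is stable.
assert rows_list == rows_new
assert cols_old == cols_new
```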

tests/v1/logits_processors/utils.py

Lines changed: 4 additions & 3 deletions
@@ -69,11 +69,12 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
         return logits
 
         # Save target values before modification
-        rows_list = list(self.req_info.keys())
-        cols = torch.tensor([self.req_info[i] for i in rows_list],
+        cols = torch.tensor(list(self.req_info.values()),
+                            dtype=torch.long,
+                            device=logits.device)
+        rows = torch.tensor(list(self.req_info.keys()),
                             dtype=torch.long,
                             device=logits.device)
-        rows = torch.tensor(rows_list, dtype=torch.long, device=logits.device)
         values_to_keep = logits[rows, cols].clone()
 
         # Mask all but target tokens

vllm/v1/sample/logits_processor/interface.py

Lines changed: 3 additions & 3 deletions
@@ -21,6 +21,9 @@ class MoveDirectionality(Enum):
     SWAP = auto()
 
 
+# Batch indices of any removed requests.
+RemovedRequest = int
+
 # (index, params, prompt_tok_ids, output_tok_ids) tuples for new
 # requests added to the batch.
 AddedRequest = tuple[int, SamplingParams, list[int], list[int]]
@@ -29,9 +32,6 @@ class MoveDirectionality(Enum):
 # one-way moves or two-way swaps of requests in batch
 MovedRequest = tuple[int, int, MoveDirectionality]
 
-# Batch indices of any removed requests.
-RemovedRequest = int
-
 
 @dataclass(frozen=True)
 class BatchUpdate:
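
The move here only reorders the alias definitions (removed, then added, then moved) to match the order used elsewhere. As a hedged sketch of the shapes these aliases name, with vLLM's `SamplingParams` stubbed out as `None` so the example stays self-contained:

```python
from enum import Enum, auto


class MoveDirectionality(Enum):  # mirrors the enum shown in the diff above
    SWAP = auto()


RemovedRequest = int  # batch index of a removed request
AddedRequest = tuple  # (index, params, prompt_tok_ids, output_tok_ids)
MovedRequest = tuple  # (from_index, to_index, directionality)

removed: RemovedRequest = 5                            # slot 5 was vacated
added: AddedRequest = (2, None, [101, 7], [])          # new request at slot 2, params stubbed
moved: MovedRequest = (4, 1, MoveDirectionality.SWAP)  # slots 4 and 1 swapped
```

All indices and token ids above are made-up illustration values, not vLLM defaults.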

vllm/v1/sample/logits_processor/state.py

Lines changed: 4 additions & 4 deletions
@@ -36,18 +36,18 @@ class BatchUpdateBuilder:
 
     _removed: list[RemovedRequest]
     _is_removed_sorted: bool
-    moved: list[MovedRequest]
     added: list[AddedRequest]
+    moved: list[MovedRequest]
 
     def __init__(
         self,
        removed: Optional[list[RemovedRequest]] = None,
-        moved: Optional[list[MovedRequest]] = None,
         added: Optional[list[AddedRequest]] = None,
+        moved: Optional[list[MovedRequest]] = None,
     ) -> None:
         self._removed = removed or []
-        self.moved = moved or []
         self.added = added or []
+        self.moved = moved or []
         self._is_removed_sorted = False
 
         # Used to track changes in the pooling case
@@ -107,8 +107,8 @@ def reset(self) -> bool:
         """Returns True if there were any changes to the batch."""
         self._is_removed_sorted = False
         self._removed.clear()
-        self.moved.clear()
         self.added.clear()
+        self.moved.clear()
         batch_changed = self.batch_changed
         self.batch_changed = False
         return batch_changed
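
The `reset()` hunk shows the builder's clear-and-report pattern: drain all pending update lists, then return whether any change had been recorded since the last reset. A minimal stand-alone sketch of the same pattern (a simplified stand-in, not the real `BatchUpdateBuilder`):

```python
class MiniBuilder:
    """Simplified stand-in illustrating BatchUpdateBuilder's reset() contract."""

    def __init__(self) -> None:
        self._removed: list[int] = []
        self.added: list[tuple] = []
        self.moved: list[tuple] = []
        self.batch_changed = False

    def removed_append(self, idx: int) -> None:
        self._removed.append(idx)
        self.batch_changed = True

    def reset(self) -> bool:
        """Clear all pending updates; return True if any were recorded."""
        self._removed.clear()
        self.added.clear()
        self.moved.clear()
        changed = self.batch_changed
        self.batch_changed = False
        return changed


b = MiniBuilder()
b.removed_append(3)
assert b.reset() is True   # one change was recorded since construction
assert b.reset() is False  # nothing new since the last reset
```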

0 commit comments