Skip to content

Commit b3a0d01

Browse files
authored
[Core] add and implement VLLM_LOGITS_PROCESSOR_THREADS (#12368)
Signed-off-by: Aviv Keshet <akeshet@scaledcognition.com>
1 parent 75e9430 commit b3a0d01

File tree

2 files changed

+44
-11
lines changed

2 files changed

+44
-11
lines changed

vllm/envs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
VLLM_LOGGING_LEVEL: str = "INFO"
3232
VLLM_LOGGING_PREFIX: str = ""
3333
VLLM_LOGGING_CONFIG_PATH: Optional[str] = None
34+
VLLM_LOGITS_PROCESSOR_THREADS: Optional[int] = None
3435
VLLM_TRACE_FUNCTION: int = 0
3536
VLLM_ATTENTION_BACKEND: Optional[str] = None
3637
VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
@@ -282,6 +283,14 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
282283
"VLLM_LOGGING_PREFIX":
283284
lambda: os.getenv("VLLM_LOGGING_PREFIX", ""),
284285

286+
# If set, vllm will call logits processors in a thread pool with this many
287+
# threads. This is useful when using custom logits processors that either
288+
# (a) launch additional CUDA kernels or (b) do significant CPU-bound work
289+
# while not holding the Python GIL, or both.
290+
"VLLM_LOGITS_PROCESSOR_THREADS":
291+
lambda: int(os.getenv("VLLM_LOGITS_PROCESSOR_THREADS", "0"))
292+
if "VLLM_LOGITS_PROCESSOR_THREADS" in os.environ else None,
293+
285294
# Trace function calls
286295
# If set to 1, vllm will trace function calls
287296
# Useful for debugging

vllm/model_executor/layers/logits_processor.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
"""A layer that computes logits from hidden_states."""
33
import inspect
4+
from concurrent.futures import ThreadPoolExecutor
45
from typing import Optional
56

67
import torch
@@ -15,6 +16,11 @@
1516
from vllm.model_executor.sampling_metadata import SamplingMetadata
1617
from vllm.platforms import current_platform
1718

19+
# Module-wide executor used to run logits processors concurrently. It is
# built once, at import time, and only when VLLM_LOGITS_PROCESSOR_THREADS
# is set (i.e. not None); otherwise it stays None and logits processors
# are applied inline by the caller.
_logits_processor_threadpool: Optional[ThreadPoolExecutor] = (
    ThreadPoolExecutor(envs.VLLM_LOGITS_PROCESSOR_THREADS)
    if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None else None)
23+
1824

1925
class LogitsProcessor(nn.Module):
2026
"""Process logits and apply logits processors from sampling metadata.
@@ -135,6 +141,7 @@ def _apply_logits_processors(
135141
) -> torch.Tensor:
136142
found_logits_processors = False
137143
logits_processed = 0
144+
logits_row_ids_and_logits_row_futures = []
138145
for seq_group in sampling_metadata.seq_groups:
139146
seq_ids = seq_group.seq_ids
140147
sampling_params = seq_group.sampling_params
@@ -148,22 +155,39 @@ def _apply_logits_processors(
148155
past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
149156
prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids
150157

151-
for logits_processor in logits_processors:
152-
parameters = inspect.signature(logits_processor).parameters
153-
if len(parameters) == 3:
154-
logits_row = logits_processor(prompt_tokens_ids,
155-
past_tokens_ids,
156-
logits_row)
157-
else:
158-
logits_row = logits_processor(past_tokens_ids,
159-
logits_row)
160-
161-
logits[logits_row_idx] = logits_row
158+
if _logits_processor_threadpool is not None:
159+
logits_row_ids_and_logits_row_futures.append(
160+
(logits_row_idx,
161+
_logits_processor_threadpool.submit(
162+
_apply_logits_processors_single_seq, logits_row,
163+
logits_processors, past_tokens_ids,
164+
prompt_tokens_ids)))
165+
else:
166+
logits[logits_row_idx] = \
167+
_apply_logits_processors_single_seq(
168+
logits_row, logits_processors, past_tokens_ids,
169+
prompt_tokens_ids)
162170

163171
logits_processed += len(seq_group.sample_indices) + len(
164172
seq_group.prompt_logprob_indices)
165173

174+
for logits_row_idx, future in logits_row_ids_and_logits_row_futures:
175+
logits[logits_row_idx] = future.result()
176+
166177
if found_logits_processors:
167178
# verifies that no rows in logits were missed unexpectedly
168179
assert logits_processed == logits.shape[0]
169180
return logits
181+
182+
183+
def _apply_logits_processors_single_seq(logits_row, logits_processors,
184+
past_tokens_ids,
185+
prompt_tokens_ids) -> torch.Tensor:
186+
for logits_processor in logits_processors:
187+
parameters = inspect.signature(logits_processor).parameters
188+
if len(parameters) == 3:
189+
logits_row = logits_processor(prompt_tokens_ids, past_tokens_ids,
190+
logits_row)
191+
else:
192+
logits_row = logits_processor(past_tokens_ids, logits_row)
193+
return logits_row

0 commit comments

Comments
 (0)