11# SPDX-License-Identifier: Apache-2.0
22"""A layer that computes logits from hidden_states."""
33import inspect
4+ from concurrent .futures import ThreadPoolExecutor
45from typing import Optional
56
67import torch
1516from vllm .model_executor .sampling_metadata import SamplingMetadata
1617from vllm .platforms import current_platform
1718
# Optional worker pool for running per-sequence logits processors
# concurrently. It is created only when VLLM_LOGITS_PROCESSOR_THREADS is
# set; when it stays None, logits processors are applied inline on the
# calling thread.
_logits_processor_threadpool: Optional[ThreadPoolExecutor] = None
if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None:
    _logits_processor_threadpool = ThreadPoolExecutor(
        envs.VLLM_LOGITS_PROCESSOR_THREADS)
23+
1824
1925class LogitsProcessor (nn .Module ):
2026 """Process logits and apply logits processors from sampling metadata.
@@ -135,6 +141,7 @@ def _apply_logits_processors(
135141) -> torch .Tensor :
136142 found_logits_processors = False
137143 logits_processed = 0
144+ logits_row_ids_and_logits_row_futures = []
138145 for seq_group in sampling_metadata .seq_groups :
139146 seq_ids = seq_group .seq_ids
140147 sampling_params = seq_group .sampling_params
@@ -148,22 +155,39 @@ def _apply_logits_processors(
148155 past_tokens_ids = seq_group .seq_data [seq_id ].output_token_ids
149156 prompt_tokens_ids = seq_group .seq_data [seq_id ].prompt_token_ids
150157
151- for logits_processor in logits_processors :
152- parameters = inspect .signature (logits_processor ).parameters
153- if len (parameters ) == 3 :
154- logits_row = logits_processor (prompt_tokens_ids ,
155- past_tokens_ids ,
156- logits_row )
157- else :
158- logits_row = logits_processor (past_tokens_ids ,
159- logits_row )
160-
161- logits [logits_row_idx ] = logits_row
158+ if _logits_processor_threadpool is not None :
159+ logits_row_ids_and_logits_row_futures .append (
160+ (logits_row_idx ,
161+ _logits_processor_threadpool .submit (
162+ _apply_logits_processors_single_seq , logits_row ,
163+ logits_processors , past_tokens_ids ,
164+ prompt_tokens_ids )))
165+ else :
166+ logits [logits_row_idx ] = \
167+ _apply_logits_processors_single_seq (
168+ logits_row , logits_processors , past_tokens_ids ,
169+ prompt_tokens_ids )
162170
163171 logits_processed += len (seq_group .sample_indices ) + len (
164172 seq_group .prompt_logprob_indices )
165173
174+ for logits_row_idx , future in logits_row_ids_and_logits_row_futures :
175+ logits [logits_row_idx ] = future .result ()
176+
166177 if found_logits_processors :
167178 # verifies that no rows in logits were missed unexpectedly
168179 assert logits_processed == logits .shape [0 ]
169180 return logits
181+
182+
183+ def _apply_logits_processors_single_seq (logits_row , logits_processors ,
184+ past_tokens_ids ,
185+ prompt_tokens_ids ) -> torch .Tensor :
186+ for logits_processor in logits_processors :
187+ parameters = inspect .signature (logits_processor ).parameters
188+ if len (parameters ) == 3 :
189+ logits_row = logits_processor (prompt_tokens_ids , past_tokens_ids ,
190+ logits_row )
191+ else :
192+ logits_row = logits_processor (past_tokens_ids , logits_row )
193+ return logits_row
0 commit comments