@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import multiprocessing
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import Future, ThreadPoolExecutor
 from typing import TYPE_CHECKING, Optional
 
 from vllm.config import VllmConfig
@@ -40,6 +40,17 @@ def __init__(self, vllm_config: VllmConfig):
         self._grammar_bitmask: Optional[torch.Tensor] = None
         self._full_mask = torch.tensor(-1, dtype=torch.int32)
 
+        max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
+        self.fill_bitmask_parallel_threshold = 128
+        if self.fill_bitmask_parallel_threshold < max_batch_size:
+            self.fill_bitmask_parallel_batch_size = 16
+            # Use:
+            # - at least 1 worker
+            # - at most half the number of CPUs or 8, whichever is less
+            max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8))
+            self.executor_for_fillmask = ThreadPoolExecutor(
+                max_workers=max_workers)
+
         if not self.vllm_config.model_config.skip_tokenizer_init:
             # The default max_workers if not specified is the number of
             # CPUs * 5, which is way too high since these tasks are CPU-bound,
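For reference, here is the sizing rule above in isolation (the helper name is illustrative, not part of this change); the pool is kept deliberately small because bitmask filling is CPU-bound:

```python
import multiprocessing
from concurrent.futures import ThreadPoolExecutor

def make_fillmask_executor() -> ThreadPoolExecutor:
    # At least one worker; at most half the CPUs, capped at 8.
    max_workers = max(1, min(multiprocessing.cpu_count() // 2, 8))
    return ThreadPoolExecutor(max_workers=max_workers)
```

On a 4-core machine this yields 2 workers; on a 64-core machine it stays at 8, bounding contention with the main scheduling thread.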
@@ -120,6 +131,26 @@ def _async_create_grammar(
         assert self.backend is not None
         return self.backend.compile_grammar(request_type, grammar_spec)
 
+    def _fill_bitmasks(
+        self,
+        batch: list[tuple[StructuredOutputGrammar, int, bool]],
+    ) -> None:
+        assert self._grammar_bitmask is not None
+        for grammar, index, apply_bitmask in batch:
+            if apply_bitmask and not grammar.is_terminated():
+                grammar.fill_bitmask(self._grammar_bitmask, index)
+            else:
+                # Note that for thinking support, we will need to
+                # reset the relevant part of the bitmask for subsequent
+                # requests here.
+                self._grammar_bitmask[index].fill_(self._full_mask)
+
+    def _async_submit_fill_bitmask(
+        self,
+        batch: list[tuple[StructuredOutputGrammar, int, bool]],
+    ) -> Future:
+        return self.executor_for_fillmask.submit(self._fill_bitmasks, batch)
+
     def grammar_bitmask(
         self,
         requests: dict[str, Request],
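These two helpers implement a scatter/gather pattern: the caller chunks rows into small batches, each batch becomes one Future on the shared pool, and the caller later blocks on all of them. A minimal self-contained sketch of the same pattern, with a stand-in fill function rather than vLLM's grammar API:

```python
from concurrent.futures import Future, ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=4)
rows = [0] * 128  # stand-in for the shared bitmask rows

def fill_rows(batch: list[int]) -> None:
    # Each task writes a disjoint set of rows, so the workers can
    # share one output buffer without locking.
    for index in batch:
        rows[index] = index * 2

promises: list[Future] = []
for start in range(0, len(rows), 16):  # batch size of 16
    promises.append(executor.submit(fill_rows, list(range(start, start + 16))))

for promise in promises:
    promise.result()  # blocks; re-raises any worker exception
```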
@@ -146,7 +177,6 @@ def grammar_bitmask(
                 self.backend.allocate_token_bitmask(
                     max_batch_size * (1 + max_num_spec_tokens))
 
-        bitmask_tensor = self._grammar_bitmask
         # Generate a batched bitmask for all structured output requests.
         # When speculative decoding is enabled, we need to include multiple
         # masks for each request, one for each possible bonus token position.
@@ -155,47 +185,61 @@ def grammar_bitmask(
         ordered_seq = sorted(structured_output_request_ids.items(),
                              key=lambda x: x[1])
 
-        # Note that for thinking support, we will need to
-        # reset the relevant part of the bitmask for consequent
-        # request here.
-        bitmask_tensor[:(len(ordered_seq) * (1 + max_num_spec_tokens))].fill_(
-            self._full_mask)
-
-        # NOTE: This outer loop can likely be parallelized to improve
-        # performance of bitmask generation for large batches.
-        for req_id, _ in ordered_seq:
-            request = requests[req_id]
-            structured_output_request = request.structured_output_request
-
-            if TYPE_CHECKING:
-                assert structured_output_request is not None
-                assert structured_output_request.grammar is not None
-            apply_bitmask: bool = True
-            if self.reasoner is not None:
-                if structured_output_request.reasoning_ended is None:
-                    structured_output_request.reasoning_ended = \
-                        self.reasoner.is_reasoning_end(request.prompt_token_ids)
-                apply_bitmask = structured_output_request.reasoning_ended
-
-            state_advancements = 0
-            req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
-            for i, token in enumerate(req_tokens):
-                if apply_bitmask and not \
-                    structured_output_request.grammar.is_terminated():
-                    structured_output_request.grammar.fill_bitmask(
-                        bitmask_tensor, cumulative_index)
-                if token is not None:
-                    # In order to generate the correct bitmask for each
-                    # position in the speculative sequence, we advance
-                    # the FSM state for each speculative token and rollback
-                    # to restore the previous state when we are finished.
+        # Fill bitmasks in parallel for large batches when speculative
+        # decoding is disabled.
+        if len(ordered_seq) > self.fill_bitmask_parallel_threshold and \
+                max_num_spec_tokens == 0:
+            promises = []
+            batch = []
+            for req_id, _ in ordered_seq:
+                request = requests[req_id]
+                structured_output_request = request.structured_output_request
+                if TYPE_CHECKING:
+                    assert structured_output_request is not None
+                    assert structured_output_request.grammar is not None
+
+                apply_bitmask = self.should_fill_bitmask(request)
+                batch.append((structured_output_request.grammar,
+                              cumulative_index, apply_bitmask))
+                if len(batch) == self.fill_bitmask_parallel_batch_size:
+                    promises.append(self._async_submit_fill_bitmask(batch))
+                    batch = []
+
+                cumulative_index += 1
+            if batch:
+                promises.append(self._async_submit_fill_bitmask(batch))
+
+            # Wait for all bitmask filling tasks to complete.
+            for promise in promises:
+                promise.result()
+        else:
+            # Fall back to serial bitmask filling for small batches.
+            for req_id, _ in ordered_seq:
+                request = requests[req_id]
+                structured_output_request = request.structured_output_request
+
+                if TYPE_CHECKING:
+                    assert structured_output_request is not None
+                    assert structured_output_request.grammar is not None
+                apply_bitmask = self.should_fill_bitmask(request)
+
+                state_advancements = 0
+                req_tokens = scheduled_spec_decode_tokens.get(req_id, [])
+                for i, token in enumerate(req_tokens + [None]):
+                    self._fill_bitmasks([(structured_output_request.grammar,
+                                          cumulative_index, apply_bitmask)])
+
+                    if apply_bitmask and token is not None and \
+                            not structured_output_request.grammar.is_terminated():
                         assert structured_output_request.grammar.accept_tokens(
                             req_id, [token])
                         state_advancements += 1
-                cumulative_index += 1
-            if state_advancements > 0:
-                structured_output_request.grammar.rollback(state_advancements)
+                    cumulative_index += 1
+                if state_advancements > 0:
+                    structured_output_request.grammar.rollback(
+                        state_advancements)
 
+        bitmask_tensor = self._grammar_bitmask
         if cumulative_index < bitmask_tensor.shape[0]:
             bitmask_tensor = bitmask_tensor[:cumulative_index]
 
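The serial branch's inner loop is worth unpacking: each bitmask position must reflect the grammar state after the preceding speculative tokens, so the FSM is advanced once per token and rolled back afterwards. A toy sketch of that advance-and-rollback pattern (ToyGrammar is a hypothetical stand-in, not vLLM's grammar class):

```python
class ToyGrammar:
    """Hypothetical grammar FSM supporting rollback, for illustration."""

    def __init__(self) -> None:
        self.state = 0
        self._history: list[int] = []

    def accept_token(self, token: int) -> None:
        self._history.append(self.state)
        self.state += token  # pretend state transition

    def rollback(self, num_tokens: int) -> None:
        for _ in range(num_tokens):
            self.state = self._history.pop()

grammar = ToyGrammar()
advanced = 0
for token in [3, 7] + [None]:  # two speculative tokens plus the bonus slot
    # A real implementation would fill the bitmask for grammar.state here.
    if token is not None:
        grammar.accept_token(token)
        advanced += 1
if advanced:
    grammar.rollback(advanced)  # restore the pre-speculation state
assert grammar.state == 0
```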
@@ -204,6 +248,15 @@ def grammar_bitmask(
         # and deserialization when sending this to the GPU workers.
         return bitmask_tensor.numpy()
 
+    def should_fill_bitmask(self, request: Request) -> bool:
+        if self.reasoner is not None:
+            assert request.structured_output_request is not None
+            if request.structured_output_request.reasoning_ended is None:
+                request.structured_output_request.reasoning_ended = \
+                    self.reasoner.is_reasoning_end(request.prompt_token_ids)
+            return request.structured_output_request.reasoning_ended
+        return True
+
     def should_advance(self, request: Request) -> bool:
         if not request.use_structured_output:
             return False
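Note the lazy-caching shape of should_fill_bitmask: with a reasoning model, the bitmask is only applied once reasoning has ended, and that check runs at most once per request before being cached on it. A toy version of the same pattern (ToyRequest and END_THINK_TOKEN are illustrative, not vLLM API):

```python
from typing import Optional

END_THINK_TOKEN = 2  # hypothetical token id marking the end of reasoning

class ToyRequest:
    def __init__(self, prompt_token_ids: list[int]) -> None:
        self.prompt_token_ids = prompt_token_ids
        # Tri-state cache: None means "not determined yet".
        self.reasoning_ended: Optional[bool] = None

def should_fill_bitmask(request: ToyRequest) -> bool:
    if request.reasoning_ended is None:
        # Computed at most once per request, then reused on later steps.
        request.reasoning_ended = END_THINK_TOKEN in request.prompt_token_ids
    return request.reasoning_ended
```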