@@ -27,6 +27,7 @@ class StructuredOutputManager:
     def __init__(self, vllm_config: VllmConfig):
         self.backend: Optional[StructuredOutputBackend] = None
         self.vllm_config = vllm_config
+
         self._grammar_bitmask: Optional[torch.Tensor] = None
 
         # The default max_workers if not specified is the number of CPUs * 5,
@@ -80,28 +81,60 @@ def grammar_bitmask(
         self,
         requests: dict[str, Request],
         structured_output_request_ids: dict[str, int],
-        batch_len: int,
+        scheduled_spec_decode_tokens: dict[str, list[int]],
     ) -> Optional[npt.NDArray[np.int32]]:
         # Prepare the structured output bitmask for this batch.
         if not structured_output_request_ids:
             return None
 
         if self._grammar_bitmask is None:
             assert self.backend is not None
-            self._grammar_bitmask = self.backend.allocate_token_bitmask(
-                self.vllm_config.scheduler_config.max_num_seqs)
-
-        # Fill the bitmask using the index of each request equal to its
-        # position in the batch. Resize the bitmask down to the size of
-        # the batch.
-        bitmask_tensor = self._grammar_bitmask
-        for req_id, batch_index in structured_output_request_ids.items():
+            max_batch_size = self.vllm_config.scheduler_config.max_num_seqs
+            if self.vllm_config.speculative_config is not None:
+                max_num_spec_tokens = self.vllm_config.\
+                    speculative_config.num_speculative_tokens
+            else:
+                max_num_spec_tokens = 0
+
+            # Allocate a bitmask for each token needing to be checked:
+            # one for each speculative position, and one more for the
+            # bonus token / non-speculative token.
+            self._grammar_bitmask = \
+                self.backend.allocate_token_bitmask(
+                    max_batch_size * (1 + max_num_spec_tokens))
+
+        # Generate a batched bitmask for all structured output requests.
+        # When speculative decoding is enabled, we need to include multiple
+        # masks for each request, one for each possible bonus token position.
+        # These are stored inline in the tensor and unpacked by the gpu runner.
+        cumulative_index = 0
+        ordered_seq = sorted(structured_output_request_ids.items(),
+                             key=lambda x: x[1])
+        # NOTE: This outer loop can likely be parallelized to improve
+        # performance of bitmask generation for large batches.
+        for req_id, _ in ordered_seq:
             request = requests[req_id].structured_output_request
             assert request is not None and request.grammar is not None
-            if not request.grammar.is_terminated():
-                request.grammar.fill_bitmask(bitmask_tensor, batch_index)
-        if batch_len < self._grammar_bitmask.shape[0]:
-            bitmask_tensor = self._grammar_bitmask[:batch_len]
+            state_advancements = 0
+            req_tokens = scheduled_spec_decode_tokens.get(req_id, []) + [None]
+            for i, token in enumerate(req_tokens):
+                if not request.grammar.is_terminated():
+                    request.grammar.fill_bitmask(self._grammar_bitmask,
+                                                 cumulative_index)
+                if token is not None:
+                    # In order to generate the correct bitmask for each
+                    # position in the speculative sequence, we advance
+                    # the FSM state for each speculative token and rollback
+                    # to restore the previous state when we are finished.
+                    assert request.grammar.accept_tokens(req_id, [token])
+                    state_advancements += 1
+                cumulative_index += 1
+            if state_advancements > 0:
+                request.grammar.rollback(state_advancements)
+
+        bitmask_tensor = self._grammar_bitmask
+        if cumulative_index < self._grammar_bitmask.shape[0]:
+            bitmask_tensor = self._grammar_bitmask[:cumulative_index]
 
         # After finishing with the xgrammar operations, we convert to
         # np.ndarray, because that is much more efficient for serialization
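
For context, here is a minimal standalone sketch of the advance-then-rollback pattern the new loop uses: one mask is produced per scheduled speculative token plus one for the bonus token, and the grammar state is advanced to validate each speculative position and then rolled back so the request's real state is left untouched. ToyGrammar below is a hypothetical stand-in for illustration only, not vLLM's backend grammar API.

# Minimal sketch (illustration only): one mask per scheduled speculative
# token plus one for the bonus token, produced by advancing a toy grammar
# state and rolling it back afterwards. ToyGrammar is a hypothetical
# stand-in, not vLLM's StructuredOutputBackend grammar class.
class ToyGrammar:
    def __init__(self) -> None:
        # Accepted tokens double as the FSM state in this toy example.
        self._history: list[int] = []

    def allowed_tokens(self) -> set[int]:
        # Toy rule: token ids must be strictly increasing, drawn from 0..4.
        last = self._history[-1] if self._history else -1
        return {t for t in range(5) if t > last}

    def accept_tokens(self, tokens: list[int]) -> bool:
        for t in tokens:
            if t not in self.allowed_tokens():
                return False
            self._history.append(t)
        return True

    def rollback(self, num_tokens: int) -> None:
        # Restore the state from before the last `num_tokens` advances.
        if num_tokens:
            del self._history[-num_tokens:]


grammar = ToyGrammar()
spec_tokens = [1, 3]  # draft tokens scheduled for this request
masks: list[set[int]] = []
advanced = 0
for token in spec_tokens + [None]:  # trailing None = bonus-token position
    masks.append(grammar.allowed_tokens())
    if token is not None:
        assert grammar.accept_tokens([token])
        advanced += 1
if advanced:
    grammar.rollback(advanced)  # grammar state is unchanged afterwards

print(masks)  # [{0, 1, 2, 3, 4}, {2, 3, 4}, {4}]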