@@ -1283,42 +1283,52 @@ def apply_grammar_bitmask(
12831283        scheduler_output : "SchedulerOutput" ,
12841284        logits : torch .Tensor ,
12851285    ) ->  torch .Tensor :
1286-         # Serialization of np.ndarray is much more efficient than a tensor, 
1287-         # so we receive it in that format. 
12881286        grammar_bitmask  =  scheduler_output .grammar_bitmask 
1289-         if  grammar_bitmask  is  None :
1290-             return 
12911287
1292-         # We receive the structured output bitmask from the scheduler, but the 
1293-         # indices of the requests in the batch may not match the indices of 
1294-         # the bitmask since the scheduler doesn't know how the gpu runner is 
1295-         # ordering the requests in the batch. We need to sort the bitmask to 
1296-         # match the order of the requests used here. 
1288+         # We receive the structured output bitmask from the scheduler, 
1289+         # compacted to contain bitmasks only for structured output requests. 
1290+         # The order of the requests in the bitmask is not guaranteed to be the 
1291+         # same as the order of the requests in the gpu runner's batch. We need 
1292+         # to sort the bitmask to match the order of the requests used here. 
1293+ 
1294+         # Get the batch indices of the structured output requests. 
1295+         # Keep track of the number of speculative tokens scheduled for every 
1296+         # request in the batch, as the logit indices are offset by this amount. 
12971297        struct_out_req_batch_indices : dict [str , int ] =  {}
1298-         indices_match  =  True 
1299-         for  req_id  in  self .input_batch .req_ids :
1300-             mask_index  =  scheduler_output .structured_output_request_ids .get (
1301-                 req_id )
1302-             if  mask_index  is  None :
1303-                 # not a structured output request 
1304-                 continue 
1305-             batch_index  =  self .input_batch .req_id_to_index [req_id ]
1306-             if  batch_index  !=  mask_index :
1307-                 indices_match  =  False 
1308-             struct_out_req_batch_indices [req_id ] =  batch_index 
1309- 
1310-         if  not  indices_match :
1311-             # Sort the bitmask to match the order of the requests 
1312-             sorted_bitmask  =  np .zeros_like (grammar_bitmask )
1313-             for  req_id , batch_index  in  struct_out_req_batch_indices .items ():
1314-                 orig_index  =  scheduler_output .structured_output_request_ids [
1315-                     req_id ]
1316-                 sorted_bitmask [batch_index ] =  grammar_bitmask [orig_index ]
1317-             grammar_bitmask  =  sorted_bitmask 
1298+         cumulative_offset  =  0 
1299+         seq  =  sorted (self .input_batch .req_id_to_index .items (),
1300+                      key = lambda  x : x [1 ])
1301+         for  req_id , batch_index  in  seq :
1302+             logit_index  =  batch_index  +  cumulative_offset 
1303+             cumulative_offset  +=  len (
1304+                 scheduler_output .scheduled_spec_decode_tokens .get (req_id , []))
1305+             if  req_id  in  scheduler_output .structured_output_request_ids :
1306+                 struct_out_req_batch_indices [req_id ] =  logit_index 
1307+ 
1308+         out_indices  =  []
1309+ 
1310+         # Reorder the bitmask to match the order of the requests in the batch. 
1311+         sorted_bitmask  =  np .zeros_like (grammar_bitmask ,
1312+                                        shape = (logits .shape [0 ],
1313+                                               grammar_bitmask .shape [1 ]))
1314+         cumulative_index  =  0 
1315+         seq  =  sorted (scheduler_output .structured_output_request_ids .items (),
1316+                      key = lambda  x : x [1 ])
1317+         for  req_id , _  in  seq :
1318+             logit_index  =  struct_out_req_batch_indices [req_id ]
1319+             num_spec_tokens  =  len (
1320+                 scheduler_output .scheduled_spec_decode_tokens .get (req_id , []))
1321+             for  i  in  range (1  +  num_spec_tokens ):
1322+                 sorted_bitmask [logit_index  +  i ] =  \
1323+                     grammar_bitmask [cumulative_index  +  i ]
1324+                 out_indices .append (logit_index  +  i )
1325+             cumulative_index  +=  1  +  num_spec_tokens 
1326+         grammar_bitmask  =  sorted_bitmask 
13181327
1328+         # Serialization of np.ndarray is much more efficient than a tensor, 
1329+         # so we receive it in that format. 
13191330        grammar_bitmask  =  torch .from_numpy (grammar_bitmask )
13201331
1321-         # TODO: compatibility with spec decode. 
13221332        # NOTE: 
13231333        # 1. XGrammar bitmask applying only supports CPU and GPU. 
13241334        # 2. The logits and bitmask should be on the same device. 
@@ -1328,7 +1338,7 @@ def apply_grammar_bitmask(
13281338        xgr .apply_token_bitmask_inplace (
13291339            logits ,
13301340            grammar_bitmask ,
1331-             indices = list ( struct_out_req_batch_indices . values ()) ,
1341+             indices = out_indices ,
13321342        )
13331343        return  logits .to (self .device ).to (logits_dtype )
13341344
0 commit comments