@@ -1317,40 +1317,52 @@ def apply_grammar_bitmask(
13171317 scheduler_output : "SchedulerOutput" ,
13181318 logits : torch .Tensor ,
13191319 ) -> torch .Tensor :
1320- # Serialization of np.ndarray is much more efficient than a tensor,
1321- # so we receive it in that format.
13221320 grammar_bitmask = scheduler_output .grammar_bitmask
13231321
1324- # We receive the structured output bitmask from the scheduler, but the
1325- # indices of the requests in the batch may not match the indices of
1326- # the bitmask since the scheduler doesn't know how the gpu runner is
1327- # ordering the requests in the batch. We need to sort the bitmask to
1328- # match the order of the requests used here.
1322+ # We receive the structured output bitmask from the scheduler,
1323+ # compacted to contain bitmasks only for structured output requests.
1324+ # The order of the requests in the bitmask is not guaranteed to be the
1325+ # same as the order of the requests in the gpu runner's batch. We need
1326+ # to sort the bitmask to match the order of the requests used here.
1327+
1328+ # Get the batch indices of the structured output requests.
1329+ # Keep track of the number of speculative tokens scheduled for every
1330+ # request in the batch, as the logit indices are offset by this amount.
13291331 struct_out_req_batch_indices : dict [str , int ] = {}
1330- indices_match = True
1331- for req_id in self .input_batch .req_ids :
1332- mask_index = scheduler_output .structured_output_request_ids .get (
1333- req_id )
1334- if mask_index is None :
1335- # not a structured output request
1336- continue
1337- batch_index = self .input_batch .req_id_to_index [req_id ]
1338- if batch_index != mask_index :
1339- indices_match = False
1340- struct_out_req_batch_indices [req_id ] = batch_index
1341-
1342- if not indices_match :
1343- # Sort the bitmask to match the order of the requests
1344- sorted_bitmask = np .zeros_like (grammar_bitmask )
1345- for req_id , batch_index in struct_out_req_batch_indices .items ():
1346- orig_index = scheduler_output .structured_output_request_ids [
1347- req_id ]
1348- sorted_bitmask [batch_index ] = grammar_bitmask [orig_index ]
1349- grammar_bitmask = sorted_bitmask
1332+ cumulative_offset = 0
1333+ seq = sorted (self .input_batch .req_id_to_index .items (),
1334+ key = lambda x : x [1 ])
1335+ for req_id , batch_index in seq :
1336+ logit_index = batch_index + cumulative_offset
1337+ cumulative_offset += len (
1338+ scheduler_output .scheduled_spec_decode_tokens .get (req_id , []))
1339+ if req_id in scheduler_output .structured_output_request_ids :
1340+ struct_out_req_batch_indices [req_id ] = logit_index
1341+
1342+ out_indices = []
1343+
1344+ # Reorder the bitmask to match the order of the requests in the batch.
1345+ sorted_bitmask = np .zeros_like (grammar_bitmask ,
1346+ shape = (logits .shape [0 ],
1347+ grammar_bitmask .shape [1 ]))
1348+ cumulative_index = 0
1349+ seq = sorted (scheduler_output .structured_output_request_ids .items (),
1350+ key = lambda x : x [1 ])
1351+ for req_id , _ in seq :
1352+ logit_index = struct_out_req_batch_indices [req_id ]
1353+ num_spec_tokens = len (
1354+ scheduler_output .scheduled_spec_decode_tokens .get (req_id , []))
1355+ for i in range (1 + num_spec_tokens ):
1356+ sorted_bitmask [logit_index + i ] = \
1357+ grammar_bitmask [cumulative_index + i ]
1358+ out_indices .append (logit_index + i )
1359+ cumulative_index += 1 + num_spec_tokens
1360+ grammar_bitmask = sorted_bitmask
13501361
1362+ # Serialization of np.ndarray is much more efficient than a tensor,
1363+ # so we receive it in that format.
13511364 grammar_bitmask = torch .from_numpy (grammar_bitmask )
13521365
1353- # TODO: compatibility with spec decode.
13541366 # NOTE:
13551367 # 1. XGrammar bitmask applying only supports CPU and GPU.
13561368 # 2. The logits and bitmask should be on the same device.
@@ -1360,7 +1372,7 @@ def apply_grammar_bitmask(
13601372 xgr .apply_token_bitmask_inplace (
13611373 logits ,
13621374 grammar_bitmask ,
1363- indices = list ( struct_out_req_batch_indices . values ()) ,
1375+ indices = out_indices ,
13641376 )
13651377 return logits .to (self .device ).to (logits_dtype )
13661378
0 commit comments