@@ -1348,40 +1348,52 @@ def apply_grammar_bitmask(
13481348 scheduler_output : "SchedulerOutput" ,
13491349 logits : torch .Tensor ,
13501350 ) -> torch .Tensor :
1351- # Serialization of np.ndarray is much more efficient than a tensor,
1352- # so we receive it in that format.
13531351 grammar_bitmask = scheduler_output .grammar_bitmask
13541352
1355- # We receive the structured output bitmask from the scheduler, but the
1356- # indices of the requests in the batch may not match the indices of
1357- # the bitmask since the scheduler doesn't know how the gpu runner is
1358- # ordering the requests in the batch. We need to sort the bitmask to
1359- # match the order of the requests used here.
1353+ # We receive the structured output bitmask from the scheduler,
1354+ # compacted to contain bitmasks only for structured output requests.
1355+ # The order of the requests in the bitmask is not guaranteed to be the
1356+ # same as the order of the requests in the gpu runner's batch. We need
1357+ # to sort the bitmask to match the order of the requests used here.
1358+
1359+ # Get the batch indices of the structured output requests.
1360+ # Keep track of the number of speculative tokens scheduled for every
1361+ # request in the batch, as the logit indices are offset by this amount.
13601362 struct_out_req_batch_indices : dict [str , int ] = {}
1361- indices_match = True
1362- for req_id in self .input_batch .req_ids :
1363- mask_index = scheduler_output .structured_output_request_ids .get (
1364- req_id )
1365- if mask_index is None :
1366- # not a structured output request
1367- continue
1368- batch_index = self .input_batch .req_id_to_index [req_id ]
1369- if batch_index != mask_index :
1370- indices_match = False
1371- struct_out_req_batch_indices [req_id ] = batch_index
1372-
1373- if not indices_match :
1374- # Sort the bitmask to match the order of the requests
1375- sorted_bitmask = np .zeros_like (grammar_bitmask )
1376- for req_id , batch_index in struct_out_req_batch_indices .items ():
1377- orig_index = scheduler_output .structured_output_request_ids [
1378- req_id ]
1379- sorted_bitmask [batch_index ] = grammar_bitmask [orig_index ]
1380- grammar_bitmask = sorted_bitmask
1363+ cumulative_offset = 0
1364+ seq = sorted (self .input_batch .req_id_to_index .items (),
1365+ key = lambda x : x [1 ])
1366+ for req_id , batch_index in seq :
1367+ logit_index = batch_index + cumulative_offset
1368+ cumulative_offset += len (
1369+ scheduler_output .scheduled_spec_decode_tokens .get (req_id , []))
1370+ if req_id in scheduler_output .structured_output_request_ids :
1371+ struct_out_req_batch_indices [req_id ] = logit_index
1372+
1373+ out_indices = []
1374+
1375+ # Reorder the bitmask to match the order of the requests in the batch.
1376+ sorted_bitmask = np .zeros_like (grammar_bitmask ,
1377+ shape = (logits .shape [0 ],
1378+ grammar_bitmask .shape [1 ]))
1379+ cumulative_index = 0
1380+ seq = sorted (scheduler_output .structured_output_request_ids .items (),
1381+ key = lambda x : x [1 ])
1382+ for req_id , _ in seq :
1383+ logit_index = struct_out_req_batch_indices [req_id ]
1384+ num_spec_tokens = len (
1385+ scheduler_output .scheduled_spec_decode_tokens .get (req_id , []))
1386+ for i in range (1 + num_spec_tokens ):
1387+ sorted_bitmask [logit_index + i ] = \
1388+ grammar_bitmask [cumulative_index + i ]
1389+ out_indices .append (logit_index + i )
1390+ cumulative_index += 1 + num_spec_tokens
1391+ grammar_bitmask = sorted_bitmask
13811392
1393+ # Serialization of np.ndarray is much more efficient than a tensor,
1394+ # so we receive it in that format.
13821395 grammar_bitmask = torch .from_numpy (grammar_bitmask )
13831396
1384- # TODO: compatibility with spec decode.
13851397 # NOTE:
13861398 # 1. XGrammar bitmask applying only supports CPU and GPU.
13871399 # 2. The logits and bitmask should be on the same device.
@@ -1391,7 +1403,7 @@ def apply_grammar_bitmask(
13911403 xgr .apply_token_bitmask_inplace (
13921404 logits ,
13931405 grammar_bitmask ,
1394- indices = list ( struct_out_req_batch_indices . values ()) ,
1406+ indices = out_indices ,
13951407 )
13961408 return logits .to (self .device ).to (logits_dtype )
13971409
0 commit comments