def apply_grammar_bitmask(
    self,
    scheduler_output: "SchedulerOutput",
    logits: torch.Tensor,
) -> torch.Tensor:
    """Apply the scheduler's structured-output grammar bitmask to ``logits``.

    The bitmask arrives compacted: it contains rows only for structured
    output requests, ordered by the scheduler's own request ordering, with
    one row per sampled token (1 + number of scheduled speculative tokens
    per request). This method scatters those rows to the logit rows used by
    this runner's batch ordering and applies them in place via XGrammar.

    Args:
        scheduler_output: Scheduler state for this step; provides
            ``grammar_bitmask`` (np.ndarray or None),
            ``structured_output_request_ids`` (req_id -> scheduler-side
            mask index) and ``scheduled_spec_decode_tokens``.
        logits: Logits for every sampled token position in the batch,
            shape ``(num_logits, vocab_size)``.

    Returns:
        The masked logits (moved back to ``self.device`` with the original
        dtype), or None when no grammar bitmask was scheduled this step.
        NOTE(review): the declared ``-> torch.Tensor`` annotation only
        covers the bitmask-present path; callers must tolerate None.
    """
    grammar_bitmask = scheduler_output.grammar_bitmask
    if grammar_bitmask is None:
        # Nothing to apply this step.
        return None

    # The order of the requests in the bitmask is not guaranteed to match
    # the order of the requests in this runner's batch, so we must map each
    # structured-output request to the logit row it owns here.
    #
    # First pass: walk the batch in batch-index order and compute, for each
    # request, the index of its first logit row. Each preceding request
    # contributes (1 + its speculative token count) rows, so we accumulate
    # that offset as we go.
    struct_out_req_batch_indices: dict[str, int] = {}
    cumulative_offset = 0
    seq = sorted(self.input_batch.req_id_to_index.items(),
                 key=lambda x: x[1])
    for req_id, batch_index in seq:
        logit_index = batch_index + cumulative_offset
        cumulative_offset += len(
            scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
        if req_id in scheduler_output.structured_output_request_ids:
            struct_out_req_batch_indices[req_id] = logit_index

    out_indices = []

    # Second pass: expand the compact bitmask into a full-batch-sized one,
    # walking the compact rows in the scheduler's order (sorted by mask
    # index) and scattering each request's (1 + num_spec_tokens) rows to
    # its logit rows. Rows for non-structured requests stay all-zero but
    # are never applied (only `out_indices` rows are).
    sorted_bitmask = np.zeros_like(grammar_bitmask,
                                   shape=(logits.shape[0],
                                          grammar_bitmask.shape[1]))
    cumulative_index = 0
    seq = sorted(scheduler_output.structured_output_request_ids.items(),
                 key=lambda x: x[1])
    for req_id, _ in seq:
        logit_index = struct_out_req_batch_indices[req_id]
        num_spec_tokens = len(
            scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
        for i in range(1 + num_spec_tokens):
            sorted_bitmask[logit_index + i] = \
                grammar_bitmask[cumulative_index + i]
            out_indices.append(logit_index + i)
        cumulative_index += 1 + num_spec_tokens
    grammar_bitmask = sorted_bitmask

    # Serialization of np.ndarray is much more efficient than a tensor,
    # so we receive it in that format.
    grammar_bitmask = torch.from_numpy(grammar_bitmask)

    # NOTE:
    # 1. XGrammar bitmask applying only supports CPU device.
    # 2. The logits and bitmask should be on the same device.
    # 3. XGrammar logits on CPU only supports float32 dtype.
    logits_dtype = logits.dtype
    logits = logits.to("cpu").float()
    xgr.apply_token_bitmask_inplace(logits, grammar_bitmask,
                                    indices=out_indices)
    # Restore the original device and dtype before returning to the caller.
    return logits.to(self.device).to(logits_dtype)
629639
630640 @torch .inference_mode ()
0 commit comments