from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, cdiv)
+                        LayerBlockType, LazyLoader, cdiv)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheSpec)
from vllm_ascend.platform import NPUPlatform

if TYPE_CHECKING:
+    import xgrammar as xgr  # type: ignore[import-untyped]
    from vllm.v1.core.sched.output import SchedulerOutput
+else:
+    xgr = LazyLoader("xgr", globals(), "xgrammar")
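The guard above keeps `xgrammar` off the critical import path: under `TYPE_CHECKING` the real module is imported for type checkers only, while at runtime `xgr` is a placeholder that imports `xgrammar` on first attribute access. Below is a minimal sketch of that pattern; `vllm.utils.LazyLoader`'s actual implementation may differ in details, and the class name here is only illustrative.

import importlib
import types


class LazyLoader(types.ModuleType):
    """Defer importing `name` until an attribute is first accessed."""

    def __init__(self, local_name: str, parent_globals: dict, name: str):
        super().__init__(name)
        self._local_name = local_name
        self._parent_globals = parent_globals

    def _load(self) -> types.ModuleType:
        module = importlib.import_module(self.__name__)
        # Swap the placeholder for the real module in the caller's globals
        # so subsequent lookups bypass __getattr__ entirely.
        self._parent_globals[self._local_name] = module
        self.__dict__.update(module.__dict__)
        return module

    def __getattr__(self, item: str):
        return getattr(self._load(), item)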


class NPUModelRunner:
@@ -493,6 +496,60 @@ def _process_reqs(

        return hidden_states[sample_indices]

+    def apply_grammar_bitmask(
+        self,
+        scheduler_output: "SchedulerOutput",
+        logits: torch.Tensor,
+    ) -> torch.Tensor:
+        # Serialization of an np.ndarray is much more efficient than that of
+        # a tensor, so the scheduler sends the bitmask in that format.
+        grammar_bitmask = scheduler_output.grammar_bitmask
+        if grammar_bitmask is None:
+            # Nothing to apply; return the logits unchanged.
+            return logits
+
+        # We receive the structured output bitmask from the scheduler, but
+        # the indices of the requests in the batch may not match the indices
+        # of the bitmask, since the scheduler does not know how the model
+        # runner orders the requests in the batch. Sort the bitmask to match
+        # the request order used here.
+        struct_out_req_batch_indices: dict[str, int] = {}
+        indices_match = True
+        for req_id in self.input_batch.req_ids:
+            mask_index = scheduler_output.structured_output_request_ids.get(
+                req_id)
+            if mask_index is None:
+                # Not a structured output request.
+                continue
+            batch_index = self.input_batch.req_id_to_index[req_id]
+            if batch_index != mask_index:
+                indices_match = False
+            struct_out_req_batch_indices[req_id] = batch_index
+
+        if not indices_match:
+            # Sort the bitmask rows to match the batch order of the requests.
+            sorted_bitmask = np.zeros_like(grammar_bitmask)
+            for req_id, batch_index in struct_out_req_batch_indices.items():
+                orig_index = scheduler_output.structured_output_request_ids[
+                    req_id]
+                sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
+            grammar_bitmask = sorted_bitmask
+
+        grammar_bitmask = torch.from_numpy(grammar_bitmask)
+
+        # TODO: compatibility with spec decode.
+        # NOTE:
+        # 1. XGrammar's bitmask application only supports CPU and GPU, so on
+        #    NPU the logits must take a round trip through the CPU.
+        # 2. The logits and the bitmask must be on the same device.
+        # 3. XGrammar on CPU only supports float32 logits.
+        logits_dtype = logits.dtype
+        logits = logits.to("cpu").float()
+        xgr.apply_token_bitmask_inplace(
+            logits,
+            grammar_bitmask,
+            indices=list(struct_out_req_batch_indices.values()),
+        )
+        return logits.to(self.device).to(logits_dtype)
+
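For readers unfamiliar with XGrammar's bitmask format, here is a rough, unoptimized reference of what `xgr.apply_token_bitmask_inplace` computes above. It assumes the layout XGrammar documents: an int32 tensor of shape `(batch, ceil(vocab_size / 32))` where bit `j` of word `i` marks token `i * 32 + j` as allowed; disallowed tokens are forced to `-inf` so sampling can never pick them. The real kernel is fused and far faster; this sketch is illustrative only.

import torch

def apply_token_bitmask_reference(logits: torch.Tensor,
                                  bitmask: torch.Tensor,
                                  indices: list[int]) -> None:
    # Illustrative stand-in for xgr.apply_token_bitmask_inplace. After the
    # sorting above, bitmask row i corresponds to logits row i, and
    # `indices` lists the batch rows that have structured output requests.
    vocab_size = logits.shape[-1]
    shifts = torch.arange(32, dtype=torch.int32)
    for i in indices:
        words = bitmask[i].unsqueeze(-1)         # (num_words, 1), int32
        bits = ((words >> shifts) & 1).bool()    # (num_words, 32)
        allowed = bits.reshape(-1)[:vocab_size]  # one flag per vocab token
        logits[i][~allowed] = float("-inf")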
    @torch.inference_mode()
    def execute_model(
        self,
@@ -507,6 +564,10 @@ def execute_model(
            intermediate_tensors)
        logits = self.model.compute_logits(hidden_states, None)

+        # Apply structured output bitmasks if present.
+        if scheduler_output.grammar_bitmask is not None:
+            logits = self.apply_grammar_bitmask(scheduler_output, logits)
+
        # Sample the next token and get logprobs if needed.
        sampling_metadata = self.input_batch.sampling_metadata
        sampler_output = self.model.sample(
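For context on where `scheduler_output.grammar_bitmask` comes from: on the scheduler side, vLLM's structured output manager builds the bitmask with XGrammar and ships it as an `np.ndarray`, which matches the serialization note in `apply_grammar_bitmask`. A condensed sketch of that producer path follows; `tokenizer`, `vocab_size`, and `request_schemas` are assumed inputs, and the exact wiring in vLLM differs.

import numpy as np
import xgrammar as xgr

# Compile one grammar per structured output request (e.g. a JSON schema).
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer,
                                                    vocab_size=vocab_size)
compiler = xgr.GrammarCompiler(tokenizer_info)
matchers = [
    xgr.GrammarMatcher(compiler.compile_json_schema(schema))
    for schema in request_schemas
]

# Each scheduling step, fill one bitmask row per structured output request
# and hand the batch to the model runner as an np.ndarray.
bitmask = xgr.allocate_token_bitmask(len(matchers), vocab_size)
for row, matcher in enumerate(matchers):
    matcher.fill_next_token_bitmask(bitmask, row)
grammar_bitmask: np.ndarray = bitmask.numpy()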