@@ -38,7 +38,7 @@
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, cdiv)
+                        LayerBlockType, LazyLoader, cdiv)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
@@ -52,7 +52,10 @@
 from vllm_ascend.platform import NPUPlatform
 
 if TYPE_CHECKING:
+    import xgrammar as xgr  # type: ignore[import-untyped]
     from vllm.v1.core.sched.output import SchedulerOutput
+else:
+    xgr = LazyLoader("xgr", globals(), "xgrammar")
 
 
 class NPUModelRunner:
@@ -493,6 +496,60 @@ def _process_reqs(
 
         return hidden_states[sample_indices]
 
+    def apply_grammar_bitmask(
+        self,
+        scheduler_output: "SchedulerOutput",
+        logits: torch.Tensor,
+    ) -> torch.Tensor:
+        # Serialization of np.ndarray is much more efficient than a tensor,
+        # so we receive the bitmask in that format.
+        grammar_bitmask = scheduler_output.grammar_bitmask
+        if grammar_bitmask is None:
+            return logits
+
+        # We receive the structured output bitmask from the scheduler, but
+        # the indices of the requests in the batch may not match the indices
+        # of the bitmask, since the scheduler doesn't know how the model
+        # runner orders the requests in the batch. Sort the bitmask to match
+        # the request order used here.
+        struct_out_req_batch_indices: dict[str, int] = {}
+        indices_match = True
+        for req_id in self.input_batch.req_ids:
+            mask_index = scheduler_output.structured_output_request_ids.get(
+                req_id)
+            if mask_index is None:
+                # Not a structured output request.
+                continue
+            batch_index = self.input_batch.req_id_to_index[req_id]
+            if batch_index != mask_index:
+                indices_match = False
+            struct_out_req_batch_indices[req_id] = batch_index
+
+        if not indices_match:
+            # Sort the bitmask to match the order of the requests.
+            sorted_bitmask = np.zeros_like(grammar_bitmask)
+            for req_id, batch_index in struct_out_req_batch_indices.items():
+                orig_index = scheduler_output.structured_output_request_ids[
+                    req_id]
+                sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
+            grammar_bitmask = sorted_bitmask
+
+        grammar_bitmask = torch.from_numpy(grammar_bitmask)
+
+        # TODO: compatibility with spec decode.
+        # NOTE:
+        # 1. XGrammar bitmask application only supports CPU and GPU.
+        # 2. The logits and the bitmask must be on the same device.
+        # 3. On CPU, XGrammar only supports float32 logits.
+        logits_dtype = logits.dtype
+        logits = logits.to("cpu").float()
+        xgr.apply_token_bitmask_inplace(
+            logits,
+            grammar_bitmask,
+            indices=list(struct_out_req_batch_indices.values()),
+        )
+        return logits.to(self.device).to(logits_dtype)
+
     @torch.inference_mode()
     def execute_model(
         self,
@@ -507,6 +564,10 @@ def execute_model(
                                            intermediate_tensors)
         logits = self.model.compute_logits(hidden_states, None)
 
+        # Apply structured output bitmasks if present.
+        if scheduler_output.grammar_bitmask is not None:
+            logits = self.apply_grammar_bitmask(scheduler_output, logits)
+
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
         sampler_output = self.model.sample(
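
To make the reordering step in apply_grammar_bitmask concrete, here is a minimal standalone sketch with hypothetical request IDs and toy mask values (none of these names come from the patch): the scheduler emits one bitmask row per structured-output request in its own order, and the runner permutes those rows into batch order before applying them.

import numpy as np

# Hypothetical inputs: the scheduler numbered the structured-output
# requests one way, while the runner's batch uses the opposite order.
structured_output_request_ids = {"req-a": 0, "req-b": 1}  # scheduler order
req_id_to_index = {"req-b": 0, "req-a": 1}                # batch order

grammar_bitmask = np.array([[0xA], [0xB]], dtype=np.int32)  # scheduler order

# The same permutation the helper performs when the two orders disagree.
sorted_bitmask = np.zeros_like(grammar_bitmask)
for req_id, mask_index in structured_output_request_ids.items():
    sorted_bitmask[req_id_to_index[req_id]] = grammar_bitmask[mask_index]

assert (sorted_bitmask == np.array([[0xB], [0xA]], dtype=np.int32)).all()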
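The CPU round-trip at the end of the helper follows from the NOTE: xgrammar's token-bitmask kernels target CPU and CUDA, not the NPU, so the logits are copied to the host, masked in float32, and copied back in their original dtype. Below is a minimal sketch of the application step itself, using the same xgr.apply_token_bitmask_inplace call as the patch; the vocabulary size, batch size, and mask values are toy assumptions. Each token owns one bit in a packed int32 word (bit j % 32 of word j // 32), and a cleared bit forces that token's logit to -inf.

import torch
import xgrammar as xgr

VOCAB_SIZE = 64   # toy vocabulary: 64 tokens -> two packed int32 words
BATCH_SIZE = 2

# An all-ones mask (-1 in int32 sets every bit) allows every token; then
# restrict request 0 so only token 0 survives in its first 32-token word.
bitmask = torch.full((BATCH_SIZE, (VOCAB_SIZE + 31) // 32), -1,
                     dtype=torch.int32)
bitmask[0, 0] = 0b1

# Per the NOTE in the patch: on CPU the logits must be float32 and live
# on the same device as the bitmask.
logits = torch.zeros(BATCH_SIZE, VOCAB_SIZE, dtype=torch.float32)
xgr.apply_token_bitmask_inplace(logits, bitmask, indices=[0])

assert logits[0, 1] == float("-inf")  # token 1 disallowed for request 0
assert logits[1, 1] == 0.0            # request 1 untouched (not in indices)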