 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
-from vllm.utils import DeviceMemoryProfiler, LayerBlockType, cdiv
+from vllm.utils import (DeviceMemoryProfiler, LayerBlockType, cdiv,
+                        LazyLoader)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm_ascend.platform import NPUPlatform
 
 if TYPE_CHECKING:
+    import xgrammar as xgr
     from vllm.v1.core.sched.output import SchedulerOutput
+else:
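+    # LazyLoader defers the xgrammar import until its first attribute
+    # access, so the dependency is only loaded when structured output
+    # is actually used.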
+    xgr = LazyLoader("xgr", globals(), "xgrammar")
 
 NPU_PAGED_ATTENTION_MASK_VALUE = -10000
 
@@ -474,6 +477,59 @@ def _process_reqs(
 
         return hidden_states[cu_num_tokens - 1]
 
+    def apply_grammar_bitmask(
+        self,
+        scheduler_output: "SchedulerOutput",
+        logits: torch.Tensor,
+    ) -> torch.Tensor:
+        # Serialization of np.ndarray is much more efficient than a tensor,
+        # so we receive it in that format.
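+        # The bitmask is a packed int32 array with one bit per vocab token,
+        # i.e. one row of ceil(vocab_size / 32) words per structured request.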
+        grammar_bitmask = scheduler_output.grammar_bitmask
+        if grammar_bitmask is None:
+            return logits
+
+        # We receive the structured output bitmask from the scheduler, but
+        # the indices of the requests in the batch may not match the indices
+        # of the bitmask, since the scheduler does not know how the model
+        # runner orders the requests in the batch. We need to sort the
+        # bitmask to match the order of the requests used here.
+        struct_out_req_batch_indices: dict[str, int] = {}
+        indices_match = True
+        for req_id in self.input_batch.req_ids:
+            mask_index = scheduler_output.structured_output_request_ids.get(
+                req_id)
+            if mask_index is None:
+                # Not a structured output request.
+                continue
+            batch_index = self.input_batch.req_id_to_index[req_id]
+            if batch_index != mask_index:
+                indices_match = False
+            struct_out_req_batch_indices[req_id] = batch_index
+
+        if not indices_match:
+            # Sort the bitmask to match the order of the requests.
+            sorted_bitmask = np.zeros_like(grammar_bitmask)
+            for req_id, batch_index in struct_out_req_batch_indices.items():
+                orig_index = scheduler_output.structured_output_request_ids[
+                    req_id]
+                sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
+            grammar_bitmask = sorted_bitmask
+
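+        # xgr.apply_token_bitmask_inplace expects a torch tensor, so convert
+        # the (possibly re-ordered) numpy bitmask here.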
+        grammar_bitmask = torch.from_numpy(grammar_bitmask)
+
+        # TODO: compatibility with spec decode.
+        # NOTE:
+        # 1. The logits and the bitmask have to be on the same device.
+        # 2. XGrammar on CPU only supports float32 logits.
+        logits_dtype = logits.dtype
+        logits = logits.to("cpu").float()
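+        # The bitmask is applied in place on the float32 CPU copy: the logits
+        # of tokens the grammar disallows are set to -inf for each row listed
+        # in `indices`.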
+        xgr.apply_token_bitmask_inplace(
+            logits,
+            grammar_bitmask,
+            indices=list(struct_out_req_batch_indices.values()),
+        )
+        # `logits` was rebound to a CPU copy above, so the masked tensor must
+        # be returned (and reassigned by the caller) rather than relied on as
+        # an in-place update of the original device tensor.
+        return logits.to(self.device).to(logits_dtype)
+
     @torch.inference_mode()
     def execute_model(
         self,
@@ -488,6 +544,10 @@ def execute_model(
                                            intermediate_tensors)
         logits = self.model.compute_logits(hidden_states, None)
 
+        # Apply structured output bitmasks if present.
+        if scheduler_output.grammar_bitmask is not None:
+            logits = self.apply_grammar_bitmask(scheduler_output, logits)
+
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
         sampler_output = self.model.sample(