diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index f5438e3fce..3425ce76d4 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -38,7 +38,7 @@ from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, - LayerBlockType, cdiv) + LayerBlockType, LazyLoader, cdiv) from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec) @@ -52,7 +52,10 @@ from vllm_ascend.platform import NPUPlatform if TYPE_CHECKING: + import xgrammar as xgr # type: ignore[import-untyped] from vllm.v1.core.sched.output import SchedulerOutput +else: + xgr = LazyLoader("xgr", globals(), "xgrammar") class NPUModelRunner: @@ -493,6 +496,60 @@ def _process_reqs( return hidden_states[sample_indices] + def apply_grammar_bitmask( + self, + scheduler_output: "SchedulerOutput", + logits: torch.Tensor, + ) -> torch.Tensor: + # Serialization of np.ndarray is much more efficient than a tensor, + # so we receive it in that format. + grammar_bitmask = scheduler_output.grammar_bitmask + if grammar_bitmask is None: + return + + # We receive the structured output bitmask from the scheduler, but the + # indices of the requests in the batch may not match the indices of + # the bitmask since the scheduler doesn't know how the gpu runner is + # ordering the requests in the batch. We need to sort the bitmask to + # match the order of the requests used here. + struct_out_req_batch_indices: dict[str, int] = {} + indices_match = True + for req_id in self.input_batch.req_ids: + mask_index = scheduler_output.structured_output_request_ids.get( + req_id) + if mask_index is None: + # not a structured output request + continue + batch_index = self.input_batch.req_id_to_index[req_id] + if batch_index != mask_index: + indices_match = False + struct_out_req_batch_indices[req_id] = batch_index + + if not indices_match: + # Sort the bitmask to match the order of the requests + sorted_bitmask = np.zeros_like(grammar_bitmask) + for req_id, batch_index in struct_out_req_batch_indices.items(): + orig_index = scheduler_output.structured_output_request_ids[ + req_id] + sorted_bitmask[batch_index] = grammar_bitmask[orig_index] + grammar_bitmask = sorted_bitmask + + grammar_bitmask = torch.from_numpy(grammar_bitmask) + + # TODO: compatibility with spec decode. + # NOTE: + # 1. XGrammar bitmask applying only supports CPU and GPU. + # 2. The logits and bitmask should be on the same device. + # 3. XGrammar logits on CPU only supports float32 dtype. + logits_dtype = logits.dtype + logits = logits.to("cpu").float() + xgr.apply_token_bitmask_inplace( + logits, + grammar_bitmask, + indices=list(struct_out_req_batch_indices.values()), + ) + return logits.to(self.device).to(logits_dtype) + @torch.inference_mode() def execute_model( self, @@ -507,6 +564,10 @@ def execute_model( intermediate_tensors) logits = self.model.compute_logits(hidden_states, None) + # Apply structured output bitmasks if present + if scheduler_output.grammar_bitmask is not None: + logits = self.apply_grammar_bitmask(scheduler_output, logits) + # Sample the next token and get logprobs if needed. sampling_metadata = self.input_batch.sampling_metadata sampler_output = self.model.sample(