
Commit 65c1f45

[V1][Structured Output] Add apply_grammar_bitmask() method to model runner (#555)

### What this PR does / why we need it?

Add an `apply_grammar_bitmask()` method to the model runner. This method is necessary for `xgrammar` structured output.

Signed-off-by: shen-shanshan <467638484@qq.com>
1 parent 2c903bc commit 65c1f45

File tree

1 file changed: +62 −1 lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 62 additions & 1 deletion
@@ -38,7 +38,7 @@
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, cdiv)
+                        LayerBlockType, LazyLoader, cdiv)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
@@ -52,7 +52,10 @@
 from vllm_ascend.platform import NPUPlatform
 
 if TYPE_CHECKING:
+    import xgrammar as xgr  # type: ignore[import-untyped]
     from vllm.v1.core.sched.output import SchedulerOutput
+else:
+    xgr = LazyLoader("xgr", globals(), "xgrammar")
 
 
 class NPUModelRunner:
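The `TYPE_CHECKING`/`LazyLoader` split above lets type checkers see the real `xgrammar` import while deferring the runtime import until `xgr` is first touched, so `xgrammar` stays an optional dependency unless structured output is actually used. Below is a minimal, hypothetical sketch of how such a lazy loader can work; it is not vllm's `LazyLoader` implementation, just the pattern the diff relies on.

```python
import importlib
import types


class LazyLoaderSketch(types.ModuleType):
    """Defer a module import until the first attribute access."""

    def __init__(self, local_name: str, parent_globals: dict, name: str):
        super().__init__(name)
        self._local_name = local_name
        self._parent_globals = parent_globals

    def _load(self) -> types.ModuleType:
        # Import the real module and swap it into the caller's namespace so
        # subsequent lookups bypass this placeholder entirely.
        module = importlib.import_module(self.__name__)
        self._parent_globals[self._local_name] = module
        self.__dict__.update(module.__dict__)
        return module

    def __getattr__(self, item: str):
        return getattr(self._load(), item)


# Mirrors the diff above: nothing is imported at module load time.
xgr = LazyLoaderSketch("xgr", globals(), "xgrammar")
# The first attribute access (e.g. xgr.apply_token_bitmask_inplace) imports
# xgrammar for real and replaces this placeholder.
```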
@@ -493,6 +496,60 @@ def _process_reqs(
 
         return hidden_states[sample_indices]
 
+    def apply_grammar_bitmask(
+        self,
+        scheduler_output: "SchedulerOutput",
+        logits: torch.Tensor,
+    ) -> torch.Tensor:
+        # Serialization of np.ndarray is much more efficient than a tensor,
+        # so we receive it in that format.
+        grammar_bitmask = scheduler_output.grammar_bitmask
+        if grammar_bitmask is None:
+            return
+
+        # We receive the structured output bitmask from the scheduler, but the
+        # indices of the requests in the batch may not match the indices of
+        # the bitmask since the scheduler doesn't know how the gpu runner is
+        # ordering the requests in the batch. We need to sort the bitmask to
+        # match the order of the requests used here.
+        struct_out_req_batch_indices: dict[str, int] = {}
+        indices_match = True
+        for req_id in self.input_batch.req_ids:
+            mask_index = scheduler_output.structured_output_request_ids.get(
+                req_id)
+            if mask_index is None:
+                # not a structured output request
+                continue
+            batch_index = self.input_batch.req_id_to_index[req_id]
+            if batch_index != mask_index:
+                indices_match = False
+            struct_out_req_batch_indices[req_id] = batch_index
+
+        if not indices_match:
+            # Sort the bitmask to match the order of the requests
+            sorted_bitmask = np.zeros_like(grammar_bitmask)
+            for req_id, batch_index in struct_out_req_batch_indices.items():
+                orig_index = scheduler_output.structured_output_request_ids[
+                    req_id]
+                sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
+            grammar_bitmask = sorted_bitmask
+
+        grammar_bitmask = torch.from_numpy(grammar_bitmask)
+
+        # TODO: compatibility with spec decode.
+        # NOTE:
+        # 1. XGrammar bitmask applying only supports CPU and GPU.
+        # 2. The logits and bitmask should be on the same device.
+        # 3. XGrammar logits on CPU only supports float32 dtype.
+        logits_dtype = logits.dtype
+        logits = logits.to("cpu").float()
+        xgr.apply_token_bitmask_inplace(
+            logits,
+            grammar_bitmask,
+            indices=list(struct_out_req_batch_indices.values()),
+        )
+        return logits.to(self.device).to(logits_dtype)
+
     @torch.inference_mode()
     def execute_model(
         self,
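The core bookkeeping in `apply_grammar_bitmask()` is a row permutation: the scheduler orders bitmask rows by its own `structured_output_request_ids` mapping, while the runner may place those requests at different batch indices. Here is a standalone NumPy sketch of that reordering; the dict names mirror the fields used above, but the request IDs, indices, and mask values are made up for illustration.

```python
import numpy as np

structured_output_request_ids = {"req-a": 0, "req-b": 1}  # scheduler's row order
req_id_to_index = {"req-b": 0, "req-a": 2}                # runner's batch order

# Rows follow the scheduler's order: row 0 is req-a's mask, row 1 is req-b's.
grammar_bitmask = np.array([[0b1010], [0b0111]], dtype=np.int32)

# Permute rows into batch order (the batch has 3 slots in this toy example).
sorted_bitmask = np.zeros((3, 1), dtype=np.int32)
for req_id, mask_index in structured_output_request_ids.items():
    sorted_bitmask[req_id_to_index[req_id]] = grammar_bitmask[mask_index]

# Row 0 now holds req-b's mask and row 2 holds req-a's. Row 1 belongs to a
# request without structured output; it stays all-zero and is skipped later
# via the `indices=` argument to the bitmask kernel.
print(sorted_bitmask)
```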
@@ -507,6 +564,10 @@ def execute_model
                                           intermediate_tensors)
         logits = self.model.compute_logits(hidden_states, None)
 
+        # Apply structured output bitmasks if present
+        if scheduler_output.grammar_bitmask is not None:
+            logits = self.apply_grammar_bitmask(scheduler_output, logits)
+
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
         sampler_output = self.model.sample(
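For context on what the masking step computes: `xgr.apply_token_bitmask_inplace` forces the logits of tokens the grammar currently disallows to negative infinity, and the diff routes logits through CPU/float32 because XGrammar's kernel does not run on the NPU (hence the `NOTE` in the method). The sketch below shows the equivalent dense computation, assuming xgrammar's packed layout of 32 vocabulary tokens per int32 word with a set bit marking an allowed token; it is an illustration with a toy vocabulary, not the library's kernel.

```python
import torch

vocab_size = 8
logits = torch.zeros(1, vocab_size)                        # one structured request
bitmask = torch.tensor([[0b00000110]], dtype=torch.int32)  # allow tokens 1 and 2

# Unpack the packed bitmask into a per-token "allowed" vector.
token_ids = torch.arange(vocab_size)
words = bitmask[0, token_ids // 32]        # the int32 word holding each token's bit
allowed = (words >> (token_ids % 32)) & 1  # extract each token's bit

# Disallowed tokens get -inf, so softmax assigns them zero probability.
logits[0, allowed == 0] = float("-inf")
print(logits)  # only positions 1 and 2 remain finite
```

In `execute_model` above, this happens between `compute_logits` and sampling, so the sampler can only ever pick tokens the grammar permits.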
