Skip to content

Commit b207e9e

Browse files
committed
add structured output mask apply to ModelRunner
Signed-off-by: shen-shanshan <467638484@qq.com>
1 parent 697908f commit b207e9e

File tree

1 file changed

+61
-1
lines changed

1 file changed

+61
-1
lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
3838
from vllm.sampling_params import SamplingType
3939
from vllm.sequence import IntermediateTensors
40-
from vllm.utils import DeviceMemoryProfiler, LayerBlockType, cdiv
40+
from vllm.utils import (DeviceMemoryProfiler, LayerBlockType, cdiv, LazyLoader)
4141
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
4242
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
4343
KVCacheSpec)
@@ -50,7 +50,10 @@
5050
from vllm_ascend.platform import NPUPlatform
5151

5252
if TYPE_CHECKING:
53+
import xgrammar as xgr
5354
from vllm.v1.core.sched.output import SchedulerOutput
55+
else:
56+
xgr = LazyLoader("xgr", globals(), "xgrammar")
5457

5558
NPU_PAGED_ATTENTION_MASK_VALUE = -10000
5659

@@ -474,6 +477,59 @@ def _process_reqs(
474477

475478
return hidden_states[cu_num_tokens - 1]
476479

480+
def apply_grammar_bitmask(
481+
self,
482+
scheduler_output: "SchedulerOutput",
483+
logits: torch.Tensor,
484+
):
485+
# Serialization of np.ndarray is much more efficient than a tensor,
486+
# so we receive it in that format.
487+
grammar_bitmask = scheduler_output.grammar_bitmask
488+
if grammar_bitmask is None:
489+
return
490+
491+
# We receive the structured output bitmask from the scheduler, but the
492+
# indices of the requests in the batch may not match the indices of
493+
# the bitmask since the scheduler doesn't know how the gpu runner is
494+
# ordering the requests in the batch. We need to sort the bitmask to
495+
# match the order of the requests used here.
496+
struct_out_req_batch_indices: dict[str, int] = {}
497+
indices_match = True
498+
for req_id in self.input_batch.req_ids:
499+
mask_index = scheduler_output.structured_output_request_ids.get(
500+
req_id)
501+
if mask_index is None:
502+
# not a structured output request
503+
continue
504+
batch_index = self.input_batch.req_id_to_index[req_id]
505+
if batch_index != mask_index:
506+
indices_match = False
507+
struct_out_req_batch_indices[req_id] = batch_index
508+
509+
if not indices_match:
510+
# Sort the bitmask to match the order of the requests
511+
sorted_bitmask = np.zeros_like(grammar_bitmask)
512+
for req_id, batch_index in struct_out_req_batch_indices.items():
513+
orig_index = scheduler_output.structured_output_request_ids[
514+
req_id]
515+
sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
516+
grammar_bitmask = sorted_bitmask
517+
518+
grammar_bitmask = torch.from_numpy(grammar_bitmask)
519+
520+
# TODO: compatibility with spec decode.
521+
# NOTE:
522+
# 1. The logits and bitmask should be on the same device.
523+
# 2. XGrammar on cpu only supports float32 logits.
524+
logits_dtype = logits.dtype
525+
logits = logits.to("cpu").float()
526+
xgr.apply_token_bitmask_inplace(
527+
logits,
528+
grammar_bitmask,
529+
indices=list(struct_out_req_batch_indices.values()),
530+
)
531+
logits = logits.to(self.device).to(logits_dtype)
532+
477533
@torch.inference_mode()
478534
def execute_model(
479535
self,
@@ -488,6 +544,10 @@ def execute_model(
488544
intermediate_tensors)
489545
logits = self.model.compute_logits(hidden_states, None)
490546

547+
# Apply structured output bitmasks if present
548+
if scheduler_output.grammar_bitmask is not None:
549+
self.apply_grammar_bitmask(scheduler_output, logits)
550+
491551
# Sample the next token and get logprobs if needed.
492552
sampling_metadata = self.input_batch.sampling_metadata
493553
sampler_output = self.model.sample(

0 commit comments

Comments
 (0)