vllm_ascend/worker/model_runner_v1.py: 62 additions, 1 deletion
@@ -38,7 +38,7 @@
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, cdiv)
+                        LayerBlockType, LazyLoader, cdiv)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
@@ -52,7 +52,10 @@
 from vllm_ascend.platform import NPUPlatform

 if TYPE_CHECKING:
+    import xgrammar as xgr  # type: ignore[import-untyped]
     from vllm.v1.core.sched.output import SchedulerOutput
+else:
+    xgr = LazyLoader("xgr", globals(), "xgrammar")


 class NPUModelRunner:
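
The TYPE_CHECKING/LazyLoader split above keeps xgrammar off the import path until it is actually needed: static type checkers see the real import, while at runtime the module is only loaded on first attribute access. A minimal sketch of the pattern (illustrative only, not the actual vllm.utils.LazyLoader implementation):

import importlib
import types

class _LazyModule(types.ModuleType):
    """Stand-in module that imports the real one on first attribute access."""

    def __init__(self, local_name: str, parent_globals: dict, name: str):
        super().__init__(name)
        self._local_name = local_name
        self._parent_globals = parent_globals

    def _load(self):
        module = importlib.import_module(self.__name__)
        # Swap ourselves out of the caller's namespace so later lookups
        # go straight to the real module.
        self._parent_globals[self._local_name] = module
        self.__dict__.update(module.__dict__)
        return module

    def __getattr__(self, item):
        return getattr(self._load(), item)

# Usage mirroring the diff: xgr = _LazyModule("xgr", globals(), "xgrammar")
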
@@ -493,6 +496,60 @@ def _process_reqs(

         return hidden_states[sample_indices]

+    def apply_grammar_bitmask(
+        self,
+        scheduler_output: "SchedulerOutput",
+        logits: torch.Tensor,
+    ) -> torch.Tensor:
+        # Serialization of np.ndarray is much more efficient than a tensor,
+        # so we receive it in that format.
+        grammar_bitmask = scheduler_output.grammar_bitmask
+        if grammar_bitmask is None:
+            return
+
+        # We receive the structured output bitmask from the scheduler, but the
+        # indices of the requests in the batch may not match the indices of
+        # the bitmask since the scheduler doesn't know how the gpu runner is
+        # ordering the requests in the batch. We need to sort the bitmask to
+        # match the order of the requests used here.

[Review] Collaborator: GPU comment
[Review] Author: OK, I will modify it later.

+        struct_out_req_batch_indices: dict[str, int] = {}
+        indices_match = True
+        for req_id in self.input_batch.req_ids:
+            mask_index = scheduler_output.structured_output_request_ids.get(
+                req_id)
+            if mask_index is None:
+                # not a structured output request
+                continue
+            batch_index = self.input_batch.req_id_to_index[req_id]
+            if batch_index != mask_index:
+                indices_match = False
+            struct_out_req_batch_indices[req_id] = batch_index
+
+        if not indices_match:
+            # Sort the bitmask to match the order of the requests
+            sorted_bitmask = np.zeros_like(grammar_bitmask)
+            for req_id, batch_index in struct_out_req_batch_indices.items():
+                orig_index = scheduler_output.structured_output_request_ids[
+                    req_id]
+                sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
+            grammar_bitmask = sorted_bitmask
+
+        grammar_bitmask = torch.from_numpy(grammar_bitmask)
+
+        # TODO: compatibility with spec decode.
+        # NOTE:
+        # 1. XGrammar bitmask applying only supports CPU and GPU.

[Review] Collaborator: ditto

+        # 2. The logits and bitmask should be on the same device.
+        # 3. XGrammar logits on CPU only supports float32 dtype.
+        logits_dtype = logits.dtype
+        logits = logits.to("cpu").float()
+        xgr.apply_token_bitmask_inplace(
+            logits,
+            grammar_bitmask,
+            indices=list(struct_out_req_batch_indices.values()),
+        )
+        return logits.to(self.device).to(logits_dtype)
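
To make the reordering and masking concrete, here is a small self-contained sketch. It assumes xgrammar's packed layout, where bit i of int32 word w gates token w * 32 + i and a set bit means "allowed"; the request names and data are hypothetical, and the real kernel is xgr.apply_token_bitmask_inplace.

import numpy as np
import torch

# Hypothetical toy setup: a vocab of 8 tokens and two structured-output
# requests; one int32 word per row covers 8 tokens (ceil(8 / 32) == 1).
vocab_size = 8
grammar_bitmask = np.array([[0b0101],   # scheduler row 0: tokens 0, 2 allowed
                            [0b1010]],  # scheduler row 1: tokens 1, 3 allowed
                           dtype=np.int32)

# Suppose the runner placed these requests at the reverse batch positions,
# which is exactly what the sorting loop above corrects for.
struct_out_req_batch_indices = {"req-a": 1, "req-b": 0}  # batch order
scheduler_rows = {"req-a": 0, "req-b": 1}                # bitmask order

sorted_bitmask = np.zeros_like(grammar_bitmask)
for req_id, batch_index in struct_out_req_batch_indices.items():
    sorted_bitmask[batch_index] = grammar_bitmask[scheduler_rows[req_id]]

# Reference semantics of the masking (assumed bit layout): a cleared bit
# forces that token's logit to -inf so it can never be sampled.
logits = torch.zeros(2, vocab_size)
mask = torch.from_numpy(sorted_bitmask)
for row in range(logits.shape[0]):
    for tok in range(vocab_size):
        word, bit = divmod(tok, 32)
        if not (int(mask[row, word]) >> bit) & 1:
            logits[row, tok] = float("-inf")
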

     @torch.inference_mode()
     def execute_model(
         self,
@@ -507,6 +564,10 @@ def execute_model(
                                            intermediate_tensors)
         logits = self.model.compute_logits(hidden_states, None)

+        # Apply structured output bitmasks if present
+        if scheduler_output.grammar_bitmask is not None:
+            logits = self.apply_grammar_bitmask(scheduler_output, logits)
+
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
         sampler_output = self.model.sample(
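
A note on the device round-trip inside apply_grammar_bitmask: because the xgrammar kernel only runs on CPU or GPU, and its CPU path only handles float32, the NPU logits are copied to CPU, widened to float32, masked in place, and then copied back with their original dtype restored. A sketch of that pattern in isolation (the function name and callable parameter are illustrative):

import torch

def mask_on_cpu(logits: torch.Tensor, apply_mask) -> torch.Tensor:
    # Remember where the logits live and their dtype (e.g. npu / float16).
    device, dtype = logits.device, logits.dtype
    cpu_logits = logits.to("cpu").float()  # xgrammar's CPU path is float32-only
    apply_mask(cpu_logits)                 # in-place bitmask application
    return cpu_logits.to(device).to(dtype)
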