
Commit 470484a

[Structured Output][Refactor] Move apply_grammar_bitmask() method from ModelRunner to structured output utils (#21999)
Signed-off-by: shen-shanshan <467638484@qq.com>
1 parent 21da733 commit 470484a

File tree

2 files changed: +84, -71 lines changed


vllm/v1/structured_output/utils.py
Lines changed: 80 additions & 0 deletions

@@ -8,7 +8,9 @@
 import os
 from typing import TYPE_CHECKING
 
+import numpy as np
 import regex as re
+import torch
 from cachetools import LRUCache
 from diskcache import Cache
 
@@ -20,9 +22,13 @@
     import outlines_core as oc
     import transformers.file_utils as file_utils
     import transformers.models.gpt2.tokenization_gpt2 as tokenization_gpt2
+    import xgrammar as xgr
 
     from vllm.transformers_utils.tokenizer import AnyTokenizer
+    from vllm.v1.core.sched.output import SchedulerOutput
+    from vllm.v1.worker.gpu_input_batch import InputBatch
 else:
+    xgr = LazyLoader("xgr", globals(), "xgrammar")
     oc = LazyLoader("oc", globals(), "outlines_core")
     file_utils = LazyLoader("file_utils", globals(), "transformers.file_utils")
     tokenization_gpt2 = LazyLoader(
@@ -36,6 +42,80 @@
 CACHE = None
 
 
+def apply_grammar_bitmask(
+    scheduler_output: SchedulerOutput,
+    input_batch: InputBatch,
+    logits: torch.Tensor,
+    device: torch.device,
+) -> None:
+    """
+    Apply the grammar bitmask to the model's output logits using xgrammar.
+
+    Args:
+        scheduler_output (SchedulerOutput): The result of engine scheduling.
+        input_batch (InputBatch): The input batch of the model runner.
+        logits (torch.Tensor): The output logits of the model forward pass.
+        device (torch.device): The device the model runner is running on.
+    """
+    grammar_bitmask = scheduler_output.grammar_bitmask
+    if grammar_bitmask is None:
+        return
+
+    # We receive the structured output bitmask from the scheduler,
+    # compacted to contain bitmasks only for structured output requests.
+    # The order of the requests in the bitmask is not guaranteed to be the
+    # same as the order of the requests in the gpu runner's batch. We need
+    # to sort the bitmask to match the order of the requests used here.
+
+    # Get the batch indices of the structured output requests.
+    # Keep track of the number of speculative tokens scheduled for every
+    # request in the batch, as the logit indices are offset by this amount.
+    struct_out_req_batch_indices: dict[str, int] = {}
+    cumulative_offset = 0
+    seq = sorted(input_batch.req_id_to_index.items(), key=lambda x: x[1])
+    for req_id, batch_index in seq:
+        logit_index = batch_index + cumulative_offset
+        cumulative_offset += len(
+            scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+        if req_id in scheduler_output.structured_output_request_ids:
+            struct_out_req_batch_indices[req_id] = logit_index
+
+    out_indices = []
+
+    # Reorder the bitmask to match the order of the requests in the batch.
+    sorted_bitmask = np.full(shape=(logits.shape[0], grammar_bitmask.shape[1]),
+                             fill_value=-1,
+                             dtype=grammar_bitmask.dtype)
+    cumulative_index = 0
+    seq = sorted(scheduler_output.structured_output_request_ids.items(),
+                 key=lambda x: x[1])
+    for req_id, _ in seq:
+        logit_index = struct_out_req_batch_indices[req_id]
+        num_spec_tokens = len(
+            scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+        for i in range(1 + num_spec_tokens):
+            sorted_bitmask[logit_index + i] = \
+                grammar_bitmask[cumulative_index + i]
+            out_indices.append(logit_index + i)
+        cumulative_index += 1 + num_spec_tokens
+    grammar_bitmask = sorted_bitmask
+
+    # If the length of out indices and the logits have the same shape
+    # we don't need to pass indices to the kernel,
+    # since the bitmask is already aligned with the logits.
+    skip_out_indices = len(out_indices) == logits.shape[0]
+
+    # Serialization of np.ndarray is much more efficient than a tensor,
+    # so we receive it in that format.
+    grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()
+
+    xgr.apply_token_bitmask_inplace(
+        logits,
+        grammar_bitmask.to(device, non_blocking=True),
+        indices=out_indices if not skip_out_indices else None,
+    )
+
+
 class OutlinesVocabulary:
     """
     Wrapper class for `outlines_core.Vocabulary`,
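
The final xgr.apply_token_bitmask_inplace call is where the mask actually hits the logits. As a rough mental model (a minimal pure-PyTorch sketch, not xgrammar's fused kernel, and assuming its packed 32-bit, least-significant-bit-first bitmask layout), each vocabulary token owns one bit of its row's bitmask, and a cleared bit pins that token's logit to -inf so the sampler can never pick it:

import torch

def apply_token_bitmask_reference(logits: torch.Tensor,
                                  bitmask: torch.Tensor) -> None:
    # logits: (num_rows, vocab_size) float; bitmask: (num_rows, ceil(vocab/32)) int32.
    # Bit j of word k guards token k * 32 + j; a 0 bit means "token disallowed".
    num_rows, vocab_size = logits.shape
    shifts = torch.arange(32, dtype=torch.int32)
    bits = (bitmask.unsqueeze(-1) >> shifts) & 1            # (num_rows, words, 32)
    allowed = bits.reshape(num_rows, -1)[:, :vocab_size].bool()
    logits.masked_fill_(~allowed, float("-inf"))

Note that a row filled with -1 (the fill_value used for sorted_bitmask above) has every bit set, so the padding rows inserted for non-structured-output requests leave their logits untouched.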

vllm/v1/worker/gpu_model_runner.py
Lines changed: 4 additions & 71 deletions

@@ -54,7 +54,7 @@
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
+                        GiB_bytes, check_use_alibi, get_dtype_size,
                         is_pin_memory_available, round_up, supports_dynamo)
 from vllm.v1.attention.backends.flash_attn import AttentionMetadata
 from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
@@ -85,6 +85,7 @@
 from vllm.v1.spec_decode.medusa import MedusaProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+from vllm.v1.structured_output.utils import apply_grammar_bitmask
 from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
@@ -101,12 +102,8 @@
     scatter_mm_placeholders)
 
 if TYPE_CHECKING:
-    import xgrammar as xgr
-
     from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
     from vllm.v1.core.sched.output import SchedulerOutput
-else:
-    xgr = LazyLoader("xgr", globals(), "xgrammar")
 
 logger = init_logger(__name__)
 
@@ -1617,71 +1614,6 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
 
         return tuple(tasks)
 
-    def apply_grammar_bitmask(
-        self,
-        scheduler_output: "SchedulerOutput",
-        logits: torch.Tensor,
-    ):
-        grammar_bitmask = scheduler_output.grammar_bitmask
-        if grammar_bitmask is None:
-            return
-
-        # We receive the structured output bitmask from the scheduler,
-        # compacted to contain bitmasks only for structured output requests.
-        # The order of the requests in the bitmask is not guaranteed to be the
-        # same as the order of the requests in the gpu runner's batch. We need
-        # to sort the bitmask to match the order of the requests used here.
-
-        # Get the batch indices of the structured output requests.
-        # Keep track of the number of speculative tokens scheduled for every
-        # request in the batch, as the logit indices are offset by this amount.
-        struct_out_req_batch_indices: dict[str, int] = {}
-        cumulative_offset = 0
-        seq = sorted(self.input_batch.req_id_to_index.items(),
-                     key=lambda x: x[1])
-        for req_id, batch_index in seq:
-            logit_index = batch_index + cumulative_offset
-            cumulative_offset += len(
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
-            if req_id in scheduler_output.structured_output_request_ids:
-                struct_out_req_batch_indices[req_id] = logit_index
-
-        out_indices = []
-
-        # Reorder the bitmask to match the order of the requests in the batch.
-        sorted_bitmask = np.full(shape=(logits.shape[0],
-                                        grammar_bitmask.shape[1]),
-                                 fill_value=-1,
-                                 dtype=grammar_bitmask.dtype)
-        cumulative_index = 0
-        seq = sorted(scheduler_output.structured_output_request_ids.items(),
-                     key=lambda x: x[1])
-        for req_id, _ in seq:
-            logit_index = struct_out_req_batch_indices[req_id]
-            num_spec_tokens = len(
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
-            for i in range(1 + num_spec_tokens):
-                sorted_bitmask[logit_index + i] = \
-                    grammar_bitmask[cumulative_index + i]
-                out_indices.append(logit_index + i)
-            cumulative_index += 1 + num_spec_tokens
-        grammar_bitmask = sorted_bitmask
-
-        # If the length of out indices and the logits have the same shape
-        # we don't need to pass indices to the kernel,
-        # since the bitmask is already aligned with the logits.
-        skip_out_indices = len(out_indices) == logits.shape[0]
-
-        # Serialization of np.ndarray is much more efficient than a tensor,
-        # so we receive it in that format.
-        grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()
-
-        xgr.apply_token_bitmask_inplace(
-            logits,
-            grammar_bitmask.to(self.device, non_blocking=True),
-            indices=out_indices if not skip_out_indices else None,
-        )
-
     def sync_and_slice_intermediate_tensors(
         self, num_tokens: int, intermediate_tensors: IntermediateTensors,
         sync_self: bool) -> IntermediateTensors:
@@ -2232,7 +2164,8 @@ def execute_model(
 
         # Apply structured output bitmasks if present
         if scheduler_output.grammar_bitmask is not None:
-            self.apply_grammar_bitmask(scheduler_output, logits)
+            apply_grammar_bitmask(scheduler_output, self.input_batch,
+                                  logits, self.device)
 
         with record_function_or_nullcontext("Sample"):
            sampler_output = self._sample(logits, spec_decode_metadata)
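
The relocated helper keeps the same call-site contract: each request contributes 1 + num_spec_tokens rows of logits, which is why a request's batch index alone is not enough and a cumulative offset must be tracked. A toy walk-through of that index math (hypothetical request IDs and spec-token lists, mirroring the first loop in apply_grammar_bitmask):

# Request "A" sits at batch index 0 with 2 speculative tokens scheduled;
# request "B" sits at batch index 1 with none. The logits tensor then has
# 4 rows: A's rows 0-2 (1 regular + 2 speculative) and B's row 3.
req_id_to_index = {"A": 0, "B": 1}               # hypothetical batch order
scheduled_spec_decode_tokens = {"A": [11, 12]}   # hypothetical spec tokens

cumulative_offset = 0
logit_indices: dict[str, int] = {}
for req_id, batch_index in sorted(req_id_to_index.items(), key=lambda x: x[1]):
    logit_indices[req_id] = batch_index + cumulative_offset
    cumulative_offset += len(scheduled_spec_decode_tokens.get(req_id, []))

print(logit_indices)  # {'A': 0, 'B': 3}, not {'A': 0, 'B': 1}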
