Commit e3c1ac8

[Structured Output] Replace apply_grammar_bitmask() method with that in vllm to avoid maintenance (#2524)
### What this PR does / why we need it?

Replace the local `apply_grammar_bitmask()` method with the implementation in vLLM, to avoid maintaining a duplicate copy.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: shen-shanshan <467638484@qq.com>
1 parent 9434f24 commit e3c1ac8
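
For background, a grammar bitmask packs one allow/deny bit per vocabulary token into int32 words (32 tokens per word); applying it forces the logits of grammar-forbidden tokens to -inf so sampling can only produce grammar-valid continuations. Below is a minimal illustrative sketch of that operation, not vLLM or xgrammar code, assuming the xgrammar convention that a set bit marks an allowed token:

```python
import torch

def apply_bitmask_sketch(logits: torch.Tensor, bitmask: torch.Tensor) -> None:
    # Hypothetical helper: mask logits in place from a packed bitmask.
    #   logits:  (num_seqs, vocab_size), floating point
    #   bitmask: (num_seqs, ceil(vocab_size / 32)), int32; bit set => allowed
    vocab_size = logits.shape[-1]
    token_ids = torch.arange(vocab_size, device=logits.device)
    # Gather the int32 word holding each token's bit, then extract the bit.
    words = bitmask[:, token_ids // 32]            # (num_seqs, vocab_size)
    allowed = (words >> (token_ids % 32)) & 1      # 1 = token allowed
    logits.masked_fill_(allowed == 0, float("-inf"))
```

The real implementations (xgrammar's `apply_token_bitmask_inplace`, and the vLLM helper this commit switches to) do the same thing with optimized kernels plus per-request index bookkeeping.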

File tree: 1 file changed (+12, −66 lines)

vllm_ascend/worker/model_runner_v1.py

Lines changed: 12 additions & 66 deletions
```diff
@@ -21,7 +21,6 @@
 import gc
 import itertools
 import math
-import re
 import time
 from collections import defaultdict
 from collections.abc import Iterator
@@ -34,6 +33,7 @@
 
 import numpy as np
 import numpy.typing as npt
+import regex as re
 import torch
 import torch._dynamo.cache_size
 import torch.distributed as dist
@@ -92,6 +92,7 @@
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+from vllm.v1.structured_output.utils import apply_grammar_bitmask
 from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -1699,70 +1700,6 @@ def _calc_spec_decode_metadata(
         )
         return metadata
 
-    def apply_grammar_bitmask(
-        self,
-        scheduler_output: "SchedulerOutput",
-        logits: torch.Tensor,
-    ) -> torch.Tensor:
-        grammar_bitmask = scheduler_output.grammar_bitmask
-
-        # We receive the structured output bitmask from the scheduler,
-        # compacted to contain bitmasks only for structured output requests.
-        # The order of the requests in the bitmask is not guaranteed to be the
-        # same as the order of the requests in the gpu runner's batch. We need
-        # to sort the bitmask to match the order of the requests used here.
-
-        # Get the batch indices of the structured output requests.
-        # Keep track of the number of speculative tokens scheduled for every
-        # request in the batch, as the logit indices are offset by this amount.
-        struct_out_req_batch_indices: dict[str, int] = {}
-        cumulative_offset = 0
-        seq = sorted(self.input_batch.req_id_to_index.items(),
-                     key=lambda x: x[1])
-        for req_id, batch_index in seq:
-            logit_index = batch_index + cumulative_offset
-            cumulative_offset += len(
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
-            if req_id in scheduler_output.structured_output_request_ids:
-                struct_out_req_batch_indices[req_id] = logit_index
-
-        out_indices = []
-
-        # Reorder the bitmask to match the order of the requests in the batch.
-        sorted_bitmask = np.zeros_like(grammar_bitmask,
-                                       shape=(logits.shape[0],
-                                              grammar_bitmask.shape[1]))
-        cumulative_index = 0
-        seq = sorted(scheduler_output.structured_output_request_ids.items(),
-                     key=lambda x: x[1])
-        for req_id, _ in seq:
-            logit_index = struct_out_req_batch_indices[req_id]
-            num_spec_tokens = len(
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
-            for i in range(1 + num_spec_tokens):
-                sorted_bitmask[logit_index + i] = \
-                    grammar_bitmask[cumulative_index + i]
-                out_indices.append(logit_index + i)
-            cumulative_index += 1 + num_spec_tokens
-        grammar_bitmask = sorted_bitmask
-
-        # Serialization of np.ndarray is much more efficient than a tensor,
-        # so we receive it in that format.
-        grammar_bitmask = torch.from_numpy(grammar_bitmask)
-
-        # NOTE:
-        # 1. XGrammar bitmask applying only supports CPU and GPU.
-        # 2. The logits and bitmask should be on the same device.
-        # 3. XGrammar logits on CPU only supports float32 dtype.
-        logits_dtype = logits.dtype
-        logits = logits.to("cpu").float()
-        xgr.apply_token_bitmask_inplace(
-            logits,
-            grammar_bitmask,
-            indices=out_indices,
-        )
-        return logits.to(self.device).to(logits_dtype)
-
     def propose_draft_token_ids(
         self,
         valid_sampled_token_ids: list[list[int]],
```
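The cumulative-offset bookkeeping in the deleted method is easiest to see with concrete numbers. In this standalone sketch (hypothetical request IDs and token counts, not runner code), each request owns `1 + num_spec_tokens` logit rows, so its first row index is its batch index shifted by the speculative tokens scheduled for all earlier requests:

```python
# Batch order and per-request speculative token counts (made-up values).
batch = ["req-a", "req-b", "req-c"]                # batch indices 0, 1, 2
spec_tokens = {"req-a": 2, "req-b": 0, "req-c": 1}

cumulative_offset = 0
for batch_index, req_id in enumerate(batch):
    logit_index = batch_index + cumulative_offset
    rows = list(range(logit_index, logit_index + 1 + spec_tokens[req_id]))
    print(req_id, "-> logit rows", rows)
    cumulative_offset += spec_tokens[req_id]
# req-a -> logit rows [0, 1, 2]
# req-b -> logit rows [3]
# req-c -> logit rows [4, 5]
```

This is the mapping the compacted bitmask rows must be sorted into before they can be applied to the batched logits.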
```diff
@@ -2011,7 +1948,16 @@ def execute_model(
 
         # Apply structured output bitmasks if present
         if scheduler_output.grammar_bitmask is not None:
-            logits = self.apply_grammar_bitmask(scheduler_output, logits)
+            assert logits is not None
+            # NOTE:
+            # 1. XGrammar bitmask applying only supports CPU and GPU.
+            # 2. The logits and bitmask should be on the same device.
+            # 3. XGrammar logits on CPU only supports float32 dtype.
+            logits_dtype = logits.dtype
+            logits = logits.to("cpu").float()
+            apply_grammar_bitmask(scheduler_output, self.input_batch,
+                                  logits, torch.device("cpu"))
+            logits = logits.to(self.device).to(logits_dtype)
 
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
```
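
The new code path keeps a device round-trip around the shared helper because, as the NOTE says, XGrammar can only apply bitmasks on CPU or GPU, and its CPU path only handles float32, while these logits live on an Ascend NPU, possibly in half precision. A generalized sketch of that pattern, with a hypothetical `mask_fn` callback standing in for the masking call:

```python
import torch

def mask_logits_on_cpu(logits: torch.Tensor, mask_fn) -> torch.Tensor:
    # Run a CPU-only, float32-only in-place masking function over logits
    # that may live on an accelerator in a lower-precision dtype.
    device, dtype = logits.device, logits.dtype
    cpu_logits = logits.to("cpu", dtype=torch.float32)  # copy to CPU as float32
    mask_fn(cpu_logits)                                 # mutates cpu_logits in place
    return cpu_logits.to(device=device, dtype=dtype)    # restore device and dtype
```

The extra copies are the price of reusing vLLM's helper unchanged on hardware the masking kernels do not target.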
