|
21 | 21 | import gc |
22 | 22 | import itertools |
23 | 23 | import math |
24 | | -import re |
25 | 24 | import time |
26 | 25 | from collections import defaultdict |
27 | 26 | from collections.abc import Iterator |
|
34 | 33 |
|
35 | 34 | import numpy as np |
36 | 35 | import numpy.typing as npt |
| 36 | +import regex as re |
37 | 37 | import torch |
38 | 38 | import torch._dynamo.cache_size |
39 | 39 | import torch.distributed as dist |
|
92 | 92 | from vllm.v1.sample.metadata import SamplingMetadata |
93 | 93 | from vllm.v1.spec_decode.metadata import SpecDecodeMetadata |
94 | 94 | from vllm.v1.spec_decode.ngram_proposer import NgramProposer |
| 95 | +from vllm.v1.structured_output.utils import apply_grammar_bitmask |
95 | 96 | from vllm.v1.utils import CpuGpuBuffer |
96 | 97 | from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput |
97 | 98 | from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin |
@@ -1699,70 +1700,6 @@ def _calc_spec_decode_metadata( |
1699 | 1700 | ) |
1700 | 1701 | return metadata |
1701 | 1702 |
|
1702 | | - def apply_grammar_bitmask( |
1703 | | - self, |
1704 | | - scheduler_output: "SchedulerOutput", |
1705 | | - logits: torch.Tensor, |
1706 | | - ) -> torch.Tensor: |
1707 | | - grammar_bitmask = scheduler_output.grammar_bitmask |
1708 | | - |
1709 | | - # We receive the structured output bitmask from the scheduler, |
1710 | | - # compacted to contain bitmasks only for structured output requests. |
1711 | | - # The order of the requests in the bitmask is not guaranteed to be the |
1712 | | - # same as the order of the requests in the gpu runner's batch. We need |
1713 | | - # to sort the bitmask to match the order of the requests used here. |
1714 | | - |
1715 | | - # Get the batch indices of the structured output requests. |
1716 | | - # Keep track of the number of speculative tokens scheduled for every |
1717 | | - # request in the batch, as the logit indices are offset by this amount. |
1718 | | - struct_out_req_batch_indices: dict[str, int] = {} |
1719 | | - cumulative_offset = 0 |
1720 | | - seq = sorted(self.input_batch.req_id_to_index.items(), |
1721 | | - key=lambda x: x[1]) |
1722 | | - for req_id, batch_index in seq: |
1723 | | - logit_index = batch_index + cumulative_offset |
1724 | | - cumulative_offset += len( |
1725 | | - scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) |
1726 | | - if req_id in scheduler_output.structured_output_request_ids: |
1727 | | - struct_out_req_batch_indices[req_id] = logit_index |
1728 | | - |
1729 | | - out_indices = [] |
1730 | | - |
1731 | | - # Reorder the bitmask to match the order of the requests in the batch. |
1732 | | - sorted_bitmask = np.zeros_like(grammar_bitmask, |
1733 | | - shape=(logits.shape[0], |
1734 | | - grammar_bitmask.shape[1])) |
1735 | | - cumulative_index = 0 |
1736 | | - seq = sorted(scheduler_output.structured_output_request_ids.items(), |
1737 | | - key=lambda x: x[1]) |
1738 | | - for req_id, _ in seq: |
1739 | | - logit_index = struct_out_req_batch_indices[req_id] |
1740 | | - num_spec_tokens = len( |
1741 | | - scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) |
1742 | | - for i in range(1 + num_spec_tokens): |
1743 | | - sorted_bitmask[logit_index + i] = \ |
1744 | | - grammar_bitmask[cumulative_index + i] |
1745 | | - out_indices.append(logit_index + i) |
1746 | | - cumulative_index += 1 + num_spec_tokens |
1747 | | - grammar_bitmask = sorted_bitmask |
1748 | | - |
1749 | | - # Serialization of np.ndarray is much more efficient than a tensor, |
1750 | | - # so we receive it in that format. |
1751 | | - grammar_bitmask = torch.from_numpy(grammar_bitmask) |
1752 | | - |
1753 | | - # NOTE: |
1754 | | - # 1. XGrammar bitmask applying only supports CPU and GPU. |
1755 | | - # 2. The logits and bitmask should be on the same device. |
1756 | | - # 3. XGrammar logits on CPU only supports float32 dtype. |
1757 | | - logits_dtype = logits.dtype |
1758 | | - logits = logits.to("cpu").float() |
1759 | | - xgr.apply_token_bitmask_inplace( |
1760 | | - logits, |
1761 | | - grammar_bitmask, |
1762 | | - indices=out_indices, |
1763 | | - ) |
1764 | | - return logits.to(self.device).to(logits_dtype) |
1765 | | - |
1766 | 1703 | def propose_draft_token_ids( |
1767 | 1704 | self, |
1768 | 1705 | valid_sampled_token_ids: list[list[int]], |
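
For intuition on the index bookkeeping in the removed method (which the relocated helper presumably preserves): a request's first logits row is its batch index plus the total number of speculative tokens scheduled for the requests ordered before it, since each spec token contributes an extra logits row. A toy illustration with hypothetical request IDs and spec-token counts, not vLLM code:

```python
# Hypothetical batch: three requests, two of which have speculative tokens.
req_id_to_index = {"req-a": 0, "req-b": 1, "req-c": 2}
scheduled_spec_decode_tokens = {"req-a": [11, 12], "req-c": [7]}

cumulative_offset = 0
logit_index = {}
for req_id, batch_index in sorted(req_id_to_index.items(), key=lambda x: x[1]):
    logit_index[req_id] = batch_index + cumulative_offset
    cumulative_offset += len(scheduled_spec_decode_tokens.get(req_id, []))

# req-a owns rows 0-2 (itself plus its two spec tokens), req-b row 3,
# req-c rows 4-5.
print(logit_index)  # {'req-a': 0, 'req-b': 3, 'req-c': 4}
```
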
@@ -2011,7 +1948,16 @@ def execute_model( |
2011 | 1948 |
|
2012 | 1949 | # Apply structured output bitmasks if present |
2013 | 1950 | if scheduler_output.grammar_bitmask is not None: |
2014 | | - logits = self.apply_grammar_bitmask(scheduler_output, logits) |
| 1951 | + assert logits is not None |
| 1952 | + # NOTE: |
| 1953 | + # 1. XGrammar bitmask applying only supports CPU and GPU. |
| 1954 | + # 2. The logits and bitmask should be on the same device. |
| 1955 | + # 3. XGrammar logits on CPU only supports float32 dtype. |
| 1956 | + logits_dtype = logits.dtype |
| 1957 | + logits = logits.to("cpu").float() |
| 1958 | + apply_grammar_bitmask(scheduler_output, self.input_batch, |
| 1959 | + logits, torch.device("cpu")) |
| 1960 | + logits = logits.to(self.device).to(logits_dtype) |
2015 | 1961 |
|
2016 | 1962 | # Sample the next token and get logprobs if needed. |
2017 | 1963 | sampling_metadata = self.input_batch.sampling_metadata |
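
Taken together, the two hunks move the bitmask logic out of the runner: the method above is deleted and `execute_model` now delegates to `apply_grammar_bitmask` from `vllm.v1.structured_output.utils`, keeping only the CPU/float32 round trip required by XGrammar at the call site. Below is a rough sketch of what the relocated helper presumably looks like, reconstructed from the removed method body; the free-function signature (`scheduler_output`, `input_batch`, `logits`, `device`) is inferred from the new call site, and the real implementation in `vllm/v1/structured_output/utils.py` may differ, for example in how it uses the `device` argument:

```python
import numpy as np
import torch
import xgrammar as xgr


def apply_grammar_bitmask(scheduler_output, input_batch,
                          logits: torch.Tensor, device: torch.device) -> None:
    grammar_bitmask = scheduler_output.grammar_bitmask
    if grammar_bitmask is None:
        return

    # The scheduler sends a compacted bitmask covering only the structured
    # output requests, not necessarily in the runner's batch order, so map
    # each structured output request to its logits row, accounting for the
    # extra rows contributed by speculative tokens.
    struct_out_indices: dict[str, int] = {}
    cumulative_offset = 0
    for req_id, batch_index in sorted(input_batch.req_id_to_index.items(),
                                      key=lambda x: x[1]):
        logit_index = batch_index + cumulative_offset
        cumulative_offset += len(
            scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
        if req_id in scheduler_output.structured_output_request_ids:
            struct_out_indices[req_id] = logit_index

    # Scatter the compacted rows into a bitmask aligned with the logits rows,
    # remembering which rows actually need masking.
    sorted_bitmask = np.zeros_like(
        grammar_bitmask, shape=(logits.shape[0], grammar_bitmask.shape[1]))
    out_indices: list[int] = []
    cumulative_index = 0
    for req_id, _ in sorted(
            scheduler_output.structured_output_request_ids.items(),
            key=lambda x: x[1]):
        logit_index = struct_out_indices[req_id]
        num_spec_tokens = len(
            scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
        for i in range(1 + num_spec_tokens):
            sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i]
            out_indices.append(logit_index + i)
        cumulative_index += 1 + num_spec_tokens

    # Assumption: the device argument is where the bitmask tensor should live;
    # XGrammar requires the bitmask and the logits to be on the same device.
    bitmask_tensor = torch.from_numpy(sorted_bitmask).to(device)
    xgr.apply_token_bitmask_inplace(logits, bitmask_tensor, indices=out_indices)
```

Receiving the bitmask as an `np.ndarray` and converting it to a tensor only at this point is deliberate: as the removed comment notes, an ndarray serializes much more cheaply than a tensor when shipped from the scheduler.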
|