|
54 | 54 | from vllm.sequence import IntermediateTensors, PoolerOutput |
55 | 55 | from vllm.tasks import GenerationTask, PoolingTask, SupportedTask |
56 | 56 | from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, |
57 | | - GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size, |
| 57 | + GiB_bytes, check_use_alibi, get_dtype_size, |
58 | 58 | is_pin_memory_available, round_up, supports_dynamo) |
59 | 59 | from vllm.v1.attention.backends.flash_attn import AttentionMetadata |
60 | 60 | from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder |
|
85 | 85 | from vllm.v1.spec_decode.medusa import MedusaProposer |
86 | 86 | from vllm.v1.spec_decode.metadata import SpecDecodeMetadata |
87 | 87 | from vllm.v1.spec_decode.ngram_proposer import NgramProposer |
| 88 | +from vllm.v1.structured_output.utils import apply_grammar_bitmask |
88 | 89 | from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext |
89 | 90 | from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch |
90 | 91 | from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper |
|
101 | 102 | scatter_mm_placeholders) |
102 | 103 |
|
103 | 104 | if TYPE_CHECKING: |
104 | | - import xgrammar as xgr |
105 | | - |
106 | 105 | from vllm.model_executor.model_loader.tensorizer import TensorizerConfig |
107 | 106 | from vllm.v1.core.sched.output import SchedulerOutput |
108 | | -else: |
109 | | - xgr = LazyLoader("xgr", globals(), "xgrammar") |
110 | 107 |
|
111 | 108 | logger = init_logger(__name__) |
112 | 109 |
|
@@ -1617,71 +1614,6 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]: |
1617 | 1614 |
|
1618 | 1615 | return tuple(tasks) |
1619 | 1616 |
|
1620 | | - def apply_grammar_bitmask( |
1621 | | - self, |
1622 | | - scheduler_output: "SchedulerOutput", |
1623 | | - logits: torch.Tensor, |
1624 | | - ): |
1625 | | - grammar_bitmask = scheduler_output.grammar_bitmask |
1626 | | - if grammar_bitmask is None: |
1627 | | - return |
1628 | | - |
1629 | | - # We receive the structured output bitmask from the scheduler, |
1630 | | - # compacted to contain bitmasks only for structured output requests. |
1631 | | - # The order of the requests in the bitmask is not guaranteed to be the |
1632 | | - # same as the order of the requests in the gpu runner's batch. We need |
1633 | | - # to sort the bitmask to match the order of the requests used here. |
1634 | | - |
1635 | | - # Get the batch indices of the structured output requests. |
1636 | | - # Keep track of the number of speculative tokens scheduled for every |
1637 | | - # request in the batch, as the logit indices are offset by this amount. |
1638 | | - struct_out_req_batch_indices: dict[str, int] = {} |
1639 | | - cumulative_offset = 0 |
1640 | | - seq = sorted(self.input_batch.req_id_to_index.items(), |
1641 | | - key=lambda x: x[1]) |
1642 | | - for req_id, batch_index in seq: |
1643 | | - logit_index = batch_index + cumulative_offset |
1644 | | - cumulative_offset += len( |
1645 | | - scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) |
1646 | | - if req_id in scheduler_output.structured_output_request_ids: |
1647 | | - struct_out_req_batch_indices[req_id] = logit_index |
1648 | | - |
1649 | | - out_indices = [] |
1650 | | - |
1651 | | - # Reorder the bitmask to match the order of the requests in the batch. |
1652 | | - sorted_bitmask = np.full(shape=(logits.shape[0], |
1653 | | - grammar_bitmask.shape[1]), |
1654 | | - fill_value=-1, |
1655 | | - dtype=grammar_bitmask.dtype) |
1656 | | - cumulative_index = 0 |
1657 | | - seq = sorted(scheduler_output.structured_output_request_ids.items(), |
1658 | | - key=lambda x: x[1]) |
1659 | | - for req_id, _ in seq: |
1660 | | - logit_index = struct_out_req_batch_indices[req_id] |
1661 | | - num_spec_tokens = len( |
1662 | | - scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])) |
1663 | | - for i in range(1 + num_spec_tokens): |
1664 | | - sorted_bitmask[logit_index + i] = \ |
1665 | | - grammar_bitmask[cumulative_index + i] |
1666 | | - out_indices.append(logit_index + i) |
1667 | | - cumulative_index += 1 + num_spec_tokens |
1668 | | - grammar_bitmask = sorted_bitmask |
1669 | | - |
1670 | | - # If the length of out indices and the logits have the same shape |
1671 | | - # we don't need to pass indices to the kernel, |
1672 | | - # since the bitmask is already aligned with the logits. |
1673 | | - skip_out_indices = len(out_indices) == logits.shape[0] |
1674 | | - |
1675 | | - # Serialization of np.ndarray is much more efficient than a tensor, |
1676 | | - # so we receive it in that format. |
1677 | | - grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous() |
1678 | | - |
1679 | | - xgr.apply_token_bitmask_inplace( |
1680 | | - logits, |
1681 | | - grammar_bitmask.to(self.device, non_blocking=True), |
1682 | | - indices=out_indices if not skip_out_indices else None, |
1683 | | - ) |
1684 | | - |
1685 | 1617 | def sync_and_slice_intermediate_tensors( |
1686 | 1618 | self, num_tokens: int, intermediate_tensors: IntermediateTensors, |
1687 | 1619 | sync_self: bool) -> IntermediateTensors: |
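
The method deleted above is relocated as a module-level helper in `vllm/v1/structured_output/utils.py`, matching the import added at new line 88 and the updated call site in the hunk below. The utility's own definition is not part of this diff, so the following is only a sketch of what it presumably looks like, reconstructed from the removed method and the four-argument call `apply_grammar_bitmask(scheduler_output, input_batch, logits, device)`; the eager `import xgrammar` is likewise an assumption (the runner previously lazy-loaded it).

```python
# Sketch of the relocated helper (reconstructed, not taken from this diff).
from typing import TYPE_CHECKING

import numpy as np
import torch
import xgrammar as xgr  # the real module may lazy-load this instead

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput
    from vllm.v1.worker.gpu_input_batch import InputBatch


def apply_grammar_bitmask(
    scheduler_output: "SchedulerOutput",
    input_batch: "InputBatch",
    logits: torch.Tensor,
    device: torch.device,
) -> None:
    grammar_bitmask = scheduler_output.grammar_bitmask
    if grammar_bitmask is None:
        return

    # Map each structured-output request to its logit index, accounting for
    # the speculative tokens scheduled for requests earlier in the batch.
    struct_out_indices: dict[str, int] = {}
    offset = 0
    for req_id, batch_index in sorted(input_batch.req_id_to_index.items(),
                                      key=lambda x: x[1]):
        logit_index = batch_index + offset
        offset += len(
            scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
        if req_id in scheduler_output.structured_output_request_ids:
            struct_out_indices[req_id] = logit_index

    # Reorder the compacted bitmask to match the runner's batch order.
    sorted_bitmask = np.full((logits.shape[0], grammar_bitmask.shape[1]),
                             fill_value=-1,
                             dtype=grammar_bitmask.dtype)
    out_indices: list[int] = []
    cumulative_index = 0
    for req_id, _ in sorted(
            scheduler_output.structured_output_request_ids.items(),
            key=lambda x: x[1]):
        logit_index = struct_out_indices[req_id]
        num_spec = len(
            scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
        for i in range(1 + num_spec):
            sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index
                                                              + i]
            out_indices.append(logit_index + i)
        cumulative_index += 1 + num_spec

    # If every logit row already has a bitmask row, the kernel needs no
    # explicit index list.
    skip_indices = len(out_indices) == logits.shape[0]

    # The scheduler sends the bitmask as np.ndarray (cheaper to serialize).
    bitmask_tensor = torch.from_numpy(sorted_bitmask).contiguous()
    xgr.apply_token_bitmask_inplace(
        logits,
        bitmask_tensor.to(device, non_blocking=True),
        indices=None if skip_indices else out_indices,
    )
```
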
@@ -2232,7 +2164,8 @@ def execute_model( |
2232 | 2164 |
|
2233 | 2165 | # Apply structured output bitmasks if present |
2234 | 2166 | if scheduler_output.grammar_bitmask is not None: |
2235 | | - self.apply_grammar_bitmask(scheduler_output, logits) |
| 2167 | + apply_grammar_bitmask(scheduler_output, self.input_batch, |
| 2168 | + logits, self.device) |
2236 | 2169 |
|
2237 | 2170 | with record_function_or_nullcontext("Sample"): |
2238 | 2171 | sampler_output = self._sample(logits, spec_decode_metadata) |
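
A side effect of lifting the logic out of `GPUModelRunner` is that any runner or test holding an input batch, a logits tensor, and a target device can apply the bitmask without instantiating the full runner. A minimal, purely hypothetical reuse sketch (the wrapper name is illustrative):

```python
# Hypothetical reuse outside GPUModelRunner; only the imported helper and the
# argument names come from the diff above.
from vllm.v1.structured_output.utils import apply_grammar_bitmask


def mask_structured_output_logits(scheduler_output, input_batch, logits,
                                  device):
    # No-op when no grammar bitmask was scheduled for this step.
    if scheduler_output.grammar_bitmask is not None:
        apply_grammar_bitmask(scheduler_output, input_batch, logits, device)
```
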
|