@@ -71,6 +71,7 @@
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+from vllm.v1.structured_output.utils import apply_grammar_bitmask
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
@@ -1404,70 +1405,6 @@ def _calc_spec_decode_metadata( |
         )
         return metadata
 
-    def apply_grammar_bitmask(
-        self,
-        scheduler_output: "SchedulerOutput",
-        logits: torch.Tensor,
-    ) -> torch.Tensor:
-        grammar_bitmask = scheduler_output.grammar_bitmask
-
-        # We receive the structured output bitmask from the scheduler,
-        # compacted to contain bitmasks only for structured output requests.
-        # The order of the requests in the bitmask is not guaranteed to be the
-        # same as the order of the requests in the gpu runner's batch. We need
-        # to sort the bitmask to match the order of the requests used here.
-
-        # Get the batch indices of the structured output requests.
-        # Keep track of the number of speculative tokens scheduled for every
-        # request in the batch, as the logit indices are offset by this amount.
-        struct_out_req_batch_indices: dict[str, int] = {}
-        cumulative_offset = 0
-        seq = sorted(self.input_batch.req_id_to_index.items(),
-                     key=lambda x: x[1])
-        for req_id, batch_index in seq:
-            logit_index = batch_index + cumulative_offset
-            cumulative_offset += len(
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
-            if req_id in scheduler_output.structured_output_request_ids:
-                struct_out_req_batch_indices[req_id] = logit_index
-
-        out_indices = []
-
-        # Reorder the bitmask to match the order of the requests in the batch.
-        sorted_bitmask = np.zeros_like(grammar_bitmask,
-                                       shape=(logits.shape[0],
-                                              grammar_bitmask.shape[1]))
-        cumulative_index = 0
-        seq = sorted(scheduler_output.structured_output_request_ids.items(),
-                     key=lambda x: x[1])
-        for req_id, _ in seq:
-            logit_index = struct_out_req_batch_indices[req_id]
-            num_spec_tokens = len(
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
-            for i in range(1 + num_spec_tokens):
-                sorted_bitmask[logit_index + i] = \
-                    grammar_bitmask[cumulative_index + i]
-                out_indices.append(logit_index + i)
-            cumulative_index += 1 + num_spec_tokens
-        grammar_bitmask = sorted_bitmask
-
-        # Serialization of np.ndarray is much more efficient than a tensor,
-        # so we receive it in that format.
-        grammar_bitmask = torch.from_numpy(grammar_bitmask)
-
-        # NOTE:
-        # 1. XGrammar bitmask applying only supports CPU and GPU.
-        # 2. The logits and bitmask should be on the same device.
-        # 3. XGrammar logits on CPU only supports float32 dtype.
-        logits_dtype = logits.dtype
-        logits = logits.to("cpu").float()
-        xgr.apply_token_bitmask_inplace(
-            logits,
-            grammar_bitmask,
-            indices=out_indices,
-        )
-        return logits.to(self.device).to(logits_dtype)
-
     def propose_draft_token_ids(
         self,
         valid_sampled_token_ids: list[list[int]],
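The body of the relocated helper is not shown in this diff. Below is a minimal sketch of what `apply_grammar_bitmask` in `vllm/v1/structured_output/utils.py` presumably looks like, reconstructed from the method removed above and the call site below: the signature is inferred from the call, the CPU/float32 handling now lives at the call site, and the in-place application with no return value is an assumption.

```python
# Hypothetical reconstruction; a sketch, not the actual module contents.
import numpy as np
import torch
import xgrammar as xgr


def apply_grammar_bitmask(scheduler_output, input_batch, logits, device):
    grammar_bitmask = scheduler_output.grammar_bitmask

    # Map each structured output request to its logit row. Rows are offset
    # by the speculative tokens scheduled for the preceding requests.
    struct_out_req_batch_indices: dict[str, int] = {}
    cumulative_offset = 0
    for req_id, batch_index in sorted(input_batch.req_id_to_index.items(),
                                      key=lambda x: x[1]):
        logit_index = batch_index + cumulative_offset
        cumulative_offset += len(
            scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
        if req_id in scheduler_output.structured_output_request_ids:
            struct_out_req_batch_indices[req_id] = logit_index

    # Scatter the compacted bitmask rows into batch order, one row per
    # sampled position (the base position plus each speculative token).
    out_indices: list[int] = []
    sorted_bitmask = np.zeros((logits.shape[0], grammar_bitmask.shape[1]),
                              dtype=grammar_bitmask.dtype)
    cumulative_index = 0
    for req_id, _ in sorted(
            scheduler_output.structured_output_request_ids.items(),
            key=lambda x: x[1]):
        logit_index = struct_out_req_batch_indices[req_id]
        num_spec_tokens = len(
            scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
        for i in range(1 + num_spec_tokens):
            sorted_bitmask[logit_index + i] = \
                grammar_bitmask[cumulative_index + i]
            out_indices.append(logit_index + i)
        cumulative_index += 1 + num_spec_tokens

    # The scheduler ships the bitmask as an np.ndarray (cheaper to
    # serialize than a tensor); convert it and mask the logits in place.
    # `device` is accepted for parity with the runner; this sketch simply
    # keeps the bitmask on the same device as the logits.
    bitmask_tensor = torch.from_numpy(sorted_bitmask).to(logits.device)
    xgr.apply_token_bitmask_inplace(logits, bitmask_tensor,
                                    indices=out_indices)
```

As a concrete reading of the offset bookkeeping: with requests A and B at batch indices 0 and 1, where A has two scheduled speculative tokens, A's bitmask rows land at logit rows 0-2 and B's single row at row 3, matching how the logits are laid out for speculative verification.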
@@ -1657,7 +1594,15 @@ def execute_model( |
 
         # Apply structured output bitmasks if present
         if scheduler_output.grammar_bitmask is not None:
-            logits = self.apply_grammar_bitmask(scheduler_output, logits)
+            # NOTE:
+            # 1. XGrammar bitmask applying only supports CPU and GPU.
+            # 2. The logits and bitmask should be on the same device.
+            # 3. XGrammar logits on CPU only supports float32 dtype.
+            logits_dtype = logits.dtype
+            logits = logits.to("cpu").float()
+            apply_grammar_bitmask(scheduler_output, self.input_batch,
+                                  logits, self.device)
+            logits = logits.to(self.device).to(logits_dtype)
 
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
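For reference, `xgr.apply_token_bitmask_inplace` treats each bitmask row as a bitset packed into int32 words over the vocabulary and forces every token whose bit is 0 to `-inf`, so only grammar-legal tokens can be sampled. A toy example of the convention on CPU float32 logits; the vocabulary size and allowed token ids are arbitrary, and the bit layout (bit `i` of word `w` covering token `w * 32 + i`) is assumed:

```python
import torch
import xgrammar as xgr

vocab_size = 64                      # toy vocabulary
logits = torch.randn(1, vocab_size)  # one float32 logit row on CPU

# One int32 word covers 32 tokens; a set bit keeps the token legal.
bitmask = torch.zeros(1, (vocab_size + 31) // 32, dtype=torch.int32)
for token_id in (3, 40):             # allow only tokens 3 and 40
    bitmask[0, token_id // 32] |= 1 << (token_id % 32)

# Restrict row 0 of the logits to the two allowed tokens.
xgr.apply_token_bitmask_inplace(logits, bitmask, indices=[0])
assert torch.isinf(logits[0, 0]) and not torch.isinf(logits[0, 3])
```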