
Commit a1c8265

Replace apply_grammar_bitmask() method with that in vllm

Signed-off-by: shen-shanshan <467638484@qq.com>

Parent: 0f81e03

File tree

1 file changed: +10 −65 lines

vllm_ascend/worker/model_runner_v1.py

Lines changed: 10 additions & 65 deletions
@@ -71,6 +71,7 @@
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
+from vllm.v1.structured_output.utils import apply_grammar_bitmask
 from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorOutput
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.utils import (bind_kv_cache, gather_mm_placeholders,
@@ -1404,70 +1405,6 @@ def _calc_spec_decode_metadata(
         )
         return metadata
 
-    def apply_grammar_bitmask(
-        self,
-        scheduler_output: "SchedulerOutput",
-        logits: torch.Tensor,
-    ) -> torch.Tensor:
-        grammar_bitmask = scheduler_output.grammar_bitmask
-
-        # We receive the structured output bitmask from the scheduler,
-        # compacted to contain bitmasks only for structured output requests.
-        # The order of the requests in the bitmask is not guaranteed to be the
-        # same as the order of the requests in the gpu runner's batch. We need
-        # to sort the bitmask to match the order of the requests used here.
-
-        # Get the batch indices of the structured output requests.
-        # Keep track of the number of speculative tokens scheduled for every
-        # request in the batch, as the logit indices are offset by this amount.
-        struct_out_req_batch_indices: dict[str, int] = {}
-        cumulative_offset = 0
-        seq = sorted(self.input_batch.req_id_to_index.items(),
-                     key=lambda x: x[1])
-        for req_id, batch_index in seq:
-            logit_index = batch_index + cumulative_offset
-            cumulative_offset += len(
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
-            if req_id in scheduler_output.structured_output_request_ids:
-                struct_out_req_batch_indices[req_id] = logit_index
-
-        out_indices = []
-
-        # Reorder the bitmask to match the order of the requests in the batch.
-        sorted_bitmask = np.zeros_like(grammar_bitmask,
-                                       shape=(logits.shape[0],
-                                              grammar_bitmask.shape[1]))
-        cumulative_index = 0
-        seq = sorted(scheduler_output.structured_output_request_ids.items(),
-                     key=lambda x: x[1])
-        for req_id, _ in seq:
-            logit_index = struct_out_req_batch_indices[req_id]
-            num_spec_tokens = len(
-                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
-            for i in range(1 + num_spec_tokens):
-                sorted_bitmask[logit_index + i] = \
-                    grammar_bitmask[cumulative_index + i]
-                out_indices.append(logit_index + i)
-            cumulative_index += 1 + num_spec_tokens
-        grammar_bitmask = sorted_bitmask
-
-        # Serialization of np.ndarray is much more efficient than a tensor,
-        # so we receive it in that format.
-        grammar_bitmask = torch.from_numpy(grammar_bitmask)
-
-        # NOTE:
-        # 1. XGrammar bitmask applying only supports CPU and GPU.
-        # 2. The logits and bitmask should be on the same device.
-        # 3. XGrammar logits on CPU only supports float32 dtype.
-        logits_dtype = logits.dtype
-        logits = logits.to("cpu").float()
-        xgr.apply_token_bitmask_inplace(
-            logits,
-            grammar_bitmask,
-            indices=out_indices,
-        )
-        return logits.to(self.device).to(logits_dtype)
-
     def propose_draft_token_ids(
         self,
         valid_sampled_token_ids: list[list[int]],
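The method deleted above does one piece of bookkeeping: the scheduler ships a compact bitmask with one row per structured-output request (plus one row per scheduled speculative token) in scheduler order, and the runner must scatter those rows into the batch's logit-row order before masking. The upstream vllm helper that replaces it performs the same reordering, which is presumably why this copy can be dropped. A toy sketch of the reordering with made-up data (two hypothetical requests, one draft token; the real inputs come from SchedulerOutput and the runner's input batch):

import numpy as np

width = 4                                  # real width is ceil(vocab_size / 32)
grammar_bitmask = np.array([[1] * width,   # req_b (no spec tokens)
                            [2] * width,   # req_a, sampled token
                            [3] * width])  # req_a, 1st speculative token

req_id_to_index = {"req_a": 0, "req_b": 1}                # batch order
structured_output_request_ids = {"req_b": 0, "req_a": 1}  # scheduler order
spec_tokens = {"req_a": [101]}                            # one draft for req_a

# Logit row of a request = batch index + spec tokens of earlier requests.
logit_row, offset = {}, 0
for req_id, batch_index in sorted(req_id_to_index.items(), key=lambda kv: kv[1]):
    logit_row[req_id] = batch_index + offset
    offset += len(spec_tokens.get(req_id, []))

# Scatter compact rows (scheduler order) into logit-row order.
sorted_bitmask = np.zeros((3, width), dtype=grammar_bitmask.dtype)
out_indices, cursor = [], 0
for req_id, _ in sorted(structured_output_request_ids.items(),
                        key=lambda kv: kv[1]):
    rows = 1 + len(spec_tokens.get(req_id, []))
    for i in range(rows):
        sorted_bitmask[logit_row[req_id] + i] = grammar_bitmask[cursor + i]
        out_indices.append(logit_row[req_id] + i)
    cursor += rows

# req_a occupies logit rows 0-1, req_b row 2 -> out_indices == [2, 0, 1]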
@@ -1657,7 +1594,15 @@ def execute_model(
 
         # Apply structured output bitmasks if present
         if scheduler_output.grammar_bitmask is not None:
-            logits = self.apply_grammar_bitmask(scheduler_output, logits)
+            # NOTE:
+            # 1. XGrammar bitmask applying only supports CPU and GPU.
+            # 2. The logits and bitmask should be on the same device.
+            # 3. XGrammar logits on CPU only supports float32 dtype.
+            logits_dtype = logits.dtype
+            logits = logits.to("cpu").float()
+            apply_grammar_bitmask(scheduler_output, self.input_batch,
+                                  logits, self.device)
+            logits = logits.to(self.device).to(logits_dtype)
 
         # Sample the next token and get logprobs if needed.
         sampling_metadata = self.input_batch.sampling_metadata
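Because XGrammar's bitmask kernel only runs on CPU and GPU, the new execute_model() code keeps a thin device round-trip around the shared vllm helper. A minimal standalone sketch of that pattern under the same constraints (the function name and inputs are illustrative, not vllm API; xgr.apply_token_bitmask_inplace is the call the removed method used):

import torch
import xgrammar as xgr

def apply_bitmask_via_cpu(logits: torch.Tensor,
                          bitmask: torch.Tensor,
                          out_indices: list[int],
                          device: torch.device) -> torch.Tensor:
    logits_dtype = logits.dtype
    # Detour through host memory and upcast: XGrammar's CPU kernel
    # only supports float32 logits.
    cpu_logits = logits.to("cpu").float()
    # Mask disallowed tokens in place; `out_indices` selects the logit
    # rows that belong to structured-output requests.
    xgr.apply_token_bitmask_inplace(cpu_logits, bitmask, indices=out_indices)
    # Restore the original device (e.g. an Ascend NPU) and dtype
    # before sampling.
    return cpu_logits.to(device).to(logits_dtype)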
