Commit dcb597e

Enable Speculative Decoding with Structured Outputs
Signed-off-by: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com>
1 parent d2ead05 commit dcb597e

1 file changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 44 additions & 34 deletions
@@ -578,53 +578,63 @@ def apply_grammar_bitmask(
         scheduler_output: "SchedulerOutput",
         logits: torch.Tensor,
     ) -> torch.Tensor:
-        # Serialization of np.ndarray is much more efficient than a tensor,
-        # so we receive it in that format.
         grammar_bitmask = scheduler_output.grammar_bitmask
         if grammar_bitmask is None:
             return

-        # We receive the structured output bitmask from the scheduler, but the
-        # indices of the requests in the batch may not match the indices of
-        # the bitmask since the scheduler doesn't know how the gpu runner is
-        # ordering the requests in the batch. We need to sort the bitmask to
-        # match the order of the requests used here.
+        # We receive the structured output bitmask from the scheduler,
+        # compacted to contain bitmasks only for structured output requests.
+        # The order of the requests in the bitmask is not guaranteed to be the
+        # same as the order of the requests in the gpu runner's batch. We need
+        # to sort the bitmask to match the order of the requests used here.
+
+        # Get the batch indices of the structured output requests.
+        # Keep track of the number of speculative tokens scheduled for every
+        # request in the batch, as the logit indices are offset by this amount.
         struct_out_req_batch_indices: dict[str, int] = {}
-        indices_match = True
-        for req_id in self.input_batch.req_ids:
-            mask_index = scheduler_output.structured_output_request_ids.get(
-                req_id)
-            if mask_index is None:
-                # not a structured output request
-                continue
-            batch_index = self.input_batch.req_id_to_index[req_id]
-            if batch_index != mask_index:
-                indices_match = False
-            struct_out_req_batch_indices[req_id] = batch_index
-
-        if not indices_match:
-            # Sort the bitmask to match the order of the requests
-            sorted_bitmask = np.zeros_like(grammar_bitmask)
-            for req_id, batch_index in struct_out_req_batch_indices.items():
-                orig_index = scheduler_output.structured_output_request_ids[
-                    req_id]
-                sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
-            grammar_bitmask = sorted_bitmask

+        cumulative_offset = 0
+        seq = sorted(self.input_batch.req_id_to_index.items(),
+                     key=lambda x: x[1])
+        for req_id, batch_index in seq:
+            logit_index = batch_index + cumulative_offset
+            cumulative_offset += len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+            if req_id in scheduler_output.structured_output_request_ids:
+                struct_out_req_batch_indices[req_id] = logit_index
+
+        out_indices = []
+
+        # Reorder the bitmask to match the order of the requests in the batch.
+        sorted_bitmask = np.zeros_like(grammar_bitmask,
+                                       shape=(logits.shape[0],
+                                              grammar_bitmask.shape[1]))
+        cumulative_index = 0
+        seq = sorted(scheduler_output.structured_output_request_ids.items(),
+                     key=lambda x: x[1])
+        for req_id, _ in seq:
+            logit_index = struct_out_req_batch_indices[req_id]
+            num_spec_tokens = len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+            for i in range(1 + num_spec_tokens):
+                sorted_bitmask[logit_index + i] = \
+                    grammar_bitmask[cumulative_index + i]
+                out_indices.append(logit_index + i)
+            cumulative_index += 1 + num_spec_tokens
+        grammar_bitmask = sorted_bitmask
+
+        # Serialization of np.ndarray is much more efficient than a tensor,
+        # so we receive it in that format.
         grammar_bitmask = torch.from_numpy(grammar_bitmask)

-        # TODO: compatibility with spec decode.
         # NOTE:
-        # 1. XGrammar bitmask applying only supports CPU and GPU.
+        # 1. XGrammar bitmask applying only supports CPU device.
         # 2. The logits and bitmask should be on the same device.
         # 3. XGrammar logits on CPU only supports float32 dtype.
         logits_dtype = logits.dtype
         logits = logits.to("cpu").float()
-        xgr.apply_token_bitmask_inplace(
-            logits,
-            grammar_bitmask,
-            indices=list(struct_out_req_batch_indices.values()),
-        )
+        xgr.apply_token_bitmask_inplace(logits, grammar_bitmask,
+                                        indices=out_indices)
         return logits.to(self.device).to(logits_dtype)

     @torch.inference_mode()
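
For readers following the change, here is a small self-contained sketch of the index bookkeeping introduced above. It is not part of the commit; the request IDs, batch order, and draft-token counts are invented for illustration, and only the two loops mirror the diff. With speculative decoding, each request occupies one logit row for its bonus token plus one row per scheduled draft token, so a request's first logit row is its batch index offset by the draft tokens of the requests ordered before it, and the scheduler's compact bitmask rows are scattered onto exactly those rows:

```python
import numpy as np

# Hypothetical batch (names mirror the fields used in the diff; values are invented).
req_id_to_index = {"req-a": 0, "req-b": 1, "req-c": 2}            # runner's batch order
structured_output_request_ids = {"req-c": 0, "req-a": 1}          # compact bitmask order
scheduled_spec_decode_tokens = {"req-a": [11, 12], "req-b": [7]}  # draft tokens per request

mask_words = 4       # toy bitmask width (the real one packs the vocab into int32 words)
num_logit_rows = 6   # req-a: 1+2 rows, req-b: 1+1 rows, req-c: 1+0 rows

# Compact bitmask from the scheduler: one row per (structured request, token position),
# ordered by the request's position in structured_output_request_ids.
num_mask_rows = sum(1 + len(scheduled_spec_decode_tokens.get(r, []))
                    for r in structured_output_request_ids)       # req-c: 1, req-a: 3
grammar_bitmask = np.arange(num_mask_rows * mask_words,
                            dtype=np.int32).reshape(num_mask_rows, mask_words)

# Step 1: logit index of each request = batch index + draft tokens of earlier requests.
struct_out_req_batch_indices: dict[str, int] = {}
cumulative_offset = 0
for req_id, batch_index in sorted(req_id_to_index.items(), key=lambda x: x[1]):
    logit_index = batch_index + cumulative_offset
    cumulative_offset += len(scheduled_spec_decode_tokens.get(req_id, []))
    if req_id in structured_output_request_ids:
        struct_out_req_batch_indices[req_id] = logit_index

# Step 2: scatter each compact row to the logit row it constrains
# (bonus token first, then one row per draft token).
out_indices: list[int] = []
sorted_bitmask = np.zeros((num_logit_rows, mask_words), dtype=grammar_bitmask.dtype)
cumulative_index = 0
for req_id, _ in sorted(structured_output_request_ids.items(), key=lambda x: x[1]):
    logit_index = struct_out_req_batch_indices[req_id]
    num_spec_tokens = len(scheduled_spec_decode_tokens.get(req_id, []))
    for i in range(1 + num_spec_tokens):
        sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i]
        out_indices.append(logit_index + i)
    cumulative_index += 1 + num_spec_tokens

print(struct_out_req_batch_indices)  # {'req-a': 0, 'req-c': 5}  (req-c shifted by 3 draft rows)
print(out_indices)                   # [5, 0, 1, 2]  -> rows to mask in the logits tensor
```

This is also why the final `xgr.apply_token_bitmask_inplace` call now receives `indices=out_indices`, one entry per bonus/draft position, rather than a single index per structured-output request as before.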
