Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 41 additions & 31 deletions vllm_ascend/worker/model_runner_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -1283,42 +1283,52 @@ def apply_grammar_bitmask(
scheduler_output: "SchedulerOutput",
logits: torch.Tensor,
) -> torch.Tensor:
# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask = scheduler_output.grammar_bitmask
if grammar_bitmask is None:
return

# We receive the structured output bitmask from the scheduler, but the
# indices of the requests in the batch may not match the indices of
# the bitmask since the scheduler doesn't know how the gpu runner is
# ordering the requests in the batch. We need to sort the bitmask to
# match the order of the requests used here.
# We receive the structured output bitmask from the scheduler,
# compacted to contain bitmasks only for structured output requests.
# The order of the requests in the bitmask is not guaranteed to be the
# same as the order of the requests in the gpu runner's batch. We need
# to sort the bitmask to match the order of the requests used here.

# Get the batch indices of the structured output requests.
# Keep track of the number of speculative tokens scheduled for every
# request in the batch, as the logit indices are offset by this amount.
struct_out_req_batch_indices: dict[str, int] = {}
indices_match = True
for req_id in self.input_batch.req_ids:
mask_index = scheduler_output.structured_output_request_ids.get(
req_id)
if mask_index is None:
# not a structured output request
continue
batch_index = self.input_batch.req_id_to_index[req_id]
if batch_index != mask_index:
indices_match = False
struct_out_req_batch_indices[req_id] = batch_index

if not indices_match:
# Sort the bitmask to match the order of the requests
sorted_bitmask = np.zeros_like(grammar_bitmask)
for req_id, batch_index in struct_out_req_batch_indices.items():
orig_index = scheduler_output.structured_output_request_ids[
req_id]
sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
grammar_bitmask = sorted_bitmask
cumulative_offset = 0
seq = sorted(self.input_batch.req_id_to_index.items(),
key=lambda x: x[1])
for req_id, batch_index in seq:
Comment on lines +1299 to +1301
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For clarity and to avoid confusion, consider renaming seq to something more descriptive, like requests_in_batch_order. This is especially helpful because seq is reused later with a different meaning.

Suggested change
seq = sorted(self.input_batch.req_id_to_index.items(),
key=lambda x: x[1])
for req_id, batch_index in seq:
requests_in_batch_order = sorted(self.input_batch.req_id_to_index.items(),
key=lambda x: x[1])
for req_id, batch_index in requests_in_batch_order:

logit_index = batch_index + cumulative_offset
cumulative_offset += len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
if req_id in scheduler_output.structured_output_request_ids:
struct_out_req_batch_indices[req_id] = logit_index

out_indices = []

# Reorder the bitmask to match the order of the requests in the batch.
sorted_bitmask = np.zeros_like(grammar_bitmask,
shape=(logits.shape[0],
grammar_bitmask.shape[1]))
cumulative_index = 0
seq = sorted(scheduler_output.structured_output_request_ids.items(),
key=lambda x: x[1])
for req_id, _ in seq:
Comment on lines +1315 to +1317
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The variable seq is reused from a previous loop. To improve readability and prevent potential bugs during future modifications, it's better to use a distinct and descriptive name, such as grammar_requests_in_mask_order.

Suggested change
seq = sorted(scheduler_output.structured_output_request_ids.items(),
key=lambda x: x[1])
for req_id, _ in seq:
grammar_requests_in_mask_order = sorted(scheduler_output.structured_output_request_ids.items(),
key=lambda x: x[1])
for req_id, _ in grammar_requests_in_mask_order:

logit_index = struct_out_req_batch_indices[req_id]
num_spec_tokens = len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
for i in range(1 + num_spec_tokens):
sorted_bitmask[logit_index + i] = \
grammar_bitmask[cumulative_index + i]
out_indices.append(logit_index + i)
cumulative_index += 1 + num_spec_tokens
grammar_bitmask = sorted_bitmask

# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask = torch.from_numpy(grammar_bitmask)

# TODO: compatibility with spec decode.
# NOTE:
# 1. XGrammar bitmask applying only supports CPU and GPU.
# 2. The logits and bitmask should be on the same device.
Expand All @@ -1328,7 +1338,7 @@ def apply_grammar_bitmask(
xgr.apply_token_bitmask_inplace(
logits,
grammar_bitmask,
indices=list(struct_out_req_batch_indices.values()),
indices=out_indices,
)
return logits.to(self.device).to(logits_dtype)

Expand Down