From 78a0479258abf37b7e84a1c34442b061ecdab298 Mon Sep 17 00:00:00 2001
From: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com>
Date: Mon, 5 May 2025 09:45:06 +0800
Subject: [PATCH 1/2] Enable Speculative Decoding with Structured Outputs

Signed-off-by: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com>
---
 vllm_ascend/worker/model_runner_v1.py | 78 +++++++++++++++------------
 1 file changed, 44 insertions(+), 34 deletions(-)

diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index bd508a4717..95593647f6 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -578,53 +578,63 @@ def apply_grammar_bitmask(
         scheduler_output: "SchedulerOutput",
         logits: torch.Tensor,
     ) -> torch.Tensor:
-        # Serialization of np.ndarray is much more efficient than a tensor,
-        # so we receive it in that format.
         grammar_bitmask = scheduler_output.grammar_bitmask
         if grammar_bitmask is None:
             return

-        # We receive the structured output bitmask from the scheduler, but the
-        # indices of the requests in the batch may not match the indices of
-        # the bitmask since the scheduler doesn't know how the gpu runner is
-        # ordering the requests in the batch. We need to sort the bitmask to
-        # match the order of the requests used here.
+        # We receive the structured output bitmask from the scheduler,
+        # compacted to contain bitmasks only for structured output requests.
+        # The order of the requests in the bitmask is not guaranteed to be the
+        # same as the order of the requests in the gpu runner's batch. We need
+        # to sort the bitmask to match the order of the requests used here.
+
+        # Get the batch indices of the structured output requests.
+        # Keep track of the number of speculative tokens scheduled for every
+        # request in the batch, as the logit indices are offset by this amount.
         struct_out_req_batch_indices: dict[str, int] = {}
-        indices_match = True
-        for req_id in self.input_batch.req_ids:
-            mask_index = scheduler_output.structured_output_request_ids.get(
-                req_id)
-            if mask_index is None:
-                # not a structured output request
-                continue
-            batch_index = self.input_batch.req_id_to_index[req_id]
-            if batch_index != mask_index:
-                indices_match = False
-            struct_out_req_batch_indices[req_id] = batch_index
-
-        if not indices_match:
-            # Sort the bitmask to match the order of the requests
-            sorted_bitmask = np.zeros_like(grammar_bitmask)
-            for req_id, batch_index in struct_out_req_batch_indices.items():
-                orig_index = scheduler_output.structured_output_request_ids[
-                    req_id]
-                sorted_bitmask[batch_index] = grammar_bitmask[orig_index]
-            grammar_bitmask = sorted_bitmask
+        cumulative_offset = 0
+        seq = sorted(self.input_batch.req_id_to_index.items(),
+                     key=lambda x: x[1])
+        for req_id, batch_index in seq:
+            logit_index = batch_index + cumulative_offset
+            cumulative_offset += len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+            if req_id in scheduler_output.structured_output_request_ids:
+                struct_out_req_batch_indices[req_id] = logit_index
+
+        out_indices = []
+
+        # Reorder the bitmask to match the order of the requests in the batch.
+        sorted_bitmask = np.zeros_like(grammar_bitmask,
+                                       shape=(logits.shape[0],
+                                              grammar_bitmask.shape[1]))
+        cumulative_index = 0
+        seq = sorted(scheduler_output.structured_output_request_ids.items(),
+                     key=lambda x: x[1])
+        for req_id, _ in seq:
+            logit_index = struct_out_req_batch_indices[req_id]
+            num_spec_tokens = len(
+                scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
+            for i in range(1 + num_spec_tokens):
+                sorted_bitmask[logit_index + i] = \
+                    grammar_bitmask[cumulative_index + i]
+                out_indices.append(logit_index + i)
+            cumulative_index += 1 + num_spec_tokens
+        grammar_bitmask = sorted_bitmask
+
+        # Serialization of np.ndarray is much more efficient than a tensor,
+        # so we receive it in that format.
         grammar_bitmask = torch.from_numpy(grammar_bitmask)

-        # TODO: compatibility with spec decode.
         # NOTE:
-        # 1. XGrammar bitmask applying only supports CPU and GPU.
+        # 1. XGrammar bitmask applying only supports CPU device.
         # 2. The logits and bitmask should be on the same device.
         # 3. XGrammar logits on CPU only supports float32 dtype.
         logits_dtype = logits.dtype
         logits = logits.to("cpu").float()
-        xgr.apply_token_bitmask_inplace(
-            logits,
-            grammar_bitmask,
-            indices=list(struct_out_req_batch_indices.values()),
-        )
+        xgr.apply_token_bitmask_inplace(logits, grammar_bitmask,
+                                        indices=out_indices)
         return logits.to(self.device).to(logits_dtype)

     @torch.inference_mode()

From 32ae41112665c3118586d0689ead7e56974cf311 Mon Sep 17 00:00:00 2001
From: shen-shanshan <467638484@qq.com>
Date: Thu, 8 May 2025 07:42:45 +0000
Subject: [PATCH 2/2] format

Signed-off-by: shen-shanshan <467638484@qq.com>
---
 vllm_ascend/worker/model_runner_v1.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 95593647f6..79a2189031 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -633,7 +633,8 @@ def apply_grammar_bitmask(
         # 3. XGrammar logits on CPU only supports float32 dtype.
         logits_dtype = logits.dtype
         logits = logits.to("cpu").float()
-        xgr.apply_token_bitmask_inplace(logits, grammar_bitmask,
+        xgr.apply_token_bitmask_inplace(logits,
+                                        grammar_bitmask,
                                         indices=out_indices)
         return logits.to(self.device).to(logits_dtype)
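
Reviewer note (illustration only, not part of either patch): the index
bookkeeping in PATCH 1/2 is subtle because, with speculative decoding, each
request contributes 1 + num_spec_tokens rows of logits, so a request's first
logit row is its batch index plus the draft tokens of every request ordered
before it in the batch. The following self-contained Python sketch replays
that two-pass logic with toy stand-ins; req_id_to_index,
spec_decode_tokens, structured_request_ids, and mask_width are hypothetical
values, not real vllm state:

import numpy as np

# Toy stand-ins (hypothetical values, not real vllm state).
req_id_to_index = {"req_a": 0, "req_b": 1, "req_c": 2}  # runner batch order
spec_decode_tokens = {"req_a": [11, 12], "req_c": [7]}  # scheduled draft tokens
structured_request_ids = {"req_b": 0, "req_c": 1}       # bitmask row order
mask_width = 4                                          # int32 words per row

# First pass: map each request to its first logit row. The offset grows by
# the speculative tokens of every request that precedes it in the batch.
logit_index_of = {}
offset = 0
for req_id, batch_index in sorted(req_id_to_index.items(), key=lambda x: x[1]):
    logit_index_of[req_id] = batch_index + offset
    offset += len(spec_decode_tokens.get(req_id, []))
num_logit_rows = len(req_id_to_index) + offset  # 3 requests + 3 drafts = 6

# The scheduler's compact bitmask: one row per (structured request, position).
num_mask_rows = sum(1 + len(spec_decode_tokens.get(r, []))
                    for r in structured_request_ids)
grammar_bitmask = np.arange(num_mask_rows * mask_width,
                            dtype=np.int32).reshape(num_mask_rows, mask_width)

# Second pass: scatter the compact rows to the logit rows they constrain,
# recording those rows so only structured requests get masked later.
sorted_bitmask = np.zeros((num_logit_rows, mask_width), dtype=np.int32)
out_indices = []
cumulative = 0
for req_id, _ in sorted(structured_request_ids.items(), key=lambda x: x[1]):
    base = logit_index_of[req_id]
    rows = 1 + len(spec_decode_tokens.get(req_id, []))
    for i in range(rows):
        sorted_bitmask[base + i] = grammar_bitmask[cumulative + i]
        out_indices.append(base + i)
    cumulative += rows

print(logit_index_of)  # {'req_a': 0, 'req_b': 3, 'req_c': 4}
print(out_indices)     # [3, 4, 5]: req_b's row, then req_c's row + its draft

Running it shows req_b's bitmask row landing at logit row 3 rather than its
batch index 1, which is exactly the offset the deleted indices_match code
could not express and the new cumulative_offset pass compensates for.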
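The xgr.apply_token_bitmask_inplace call at the end of the hunk is what
actually enforces the grammar: tokens the grammar disallows at a position
have their logits pushed to -inf before sampling. A rough functional
equivalent is sketched below; note the real xgrammar mask is a bit-packed
int32 tensor, which this sketch replaces with a plain boolean mask purely
for clarity, so it is not the xgrammar wire format:

import torch

def apply_token_mask_inplace(logits: torch.Tensor,
                             allowed: torch.Tensor,
                             indices: list[int]) -> None:
    """Set disallowed-token logits to -inf for the selected rows.

    logits:  (num_logit_rows, vocab_size) float32, on CPU
    allowed: (num_logit_rows, vocab_size) bool, True = token permitted
    indices: logit rows that belong to structured output requests
    """
    rows = torch.tensor(indices, dtype=torch.long)
    logits[rows] = logits[rows].masked_fill(~allowed[rows], float("-inf"))

logits = torch.zeros(6, 8)                  # 6 logit rows, toy vocab of 8
allowed = torch.ones(6, 8, dtype=torch.bool)
allowed[3, 4:] = False                      # row 3: only tokens 0-3 permitted
apply_token_mask_inplace(logits, allowed, indices=[3, 4, 5])
assert torch.isinf(logits[3, 4:]).all() and not torch.isinf(logits[0]).any()

This also shows why out_indices must cover the speculative rows: each draft
token gets its own logit row, and leaving a draft row unmasked could let an
off-grammar token survive verification.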