
Commit b667c7e

Fixes cuda graph of MTP verify under unaligned sps tokens.

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
(cherry picked from commit 8b83d23259ac24ec1f3e5e012da0c997a90031d8)

1 parent: 61bb8b8

3 files changed: +32 -32 lines changed

vllm/model_executor/layers/mamba/ops/causal_conv1d.py
Lines changed: 19 additions & 16 deletions

@@ -680,6 +680,25 @@ def _causal_conv1d_update_kernel(
         # not processing as this is not the actual sequence
         return
 
+    if IS_VARLEN:
+        query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64)
+        query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(
+            tl.int64)
+        # revise state_len and seqlen
+        state_len = state_len - (seqlen -
+                                 (query_end_index - query_start_index))
+        seqlen = query_end_index - query_start_index
+        x_offset = query_start_index * stride_x_token
+        o_offset = query_start_index * stride_o_token
+    else:
+        query_start_index = idx_seq * seqlen
+        query_end_index = query_start_index + seqlen
+        x_offset = idx_seq * stride_x_seq
+        o_offset = idx_seq * stride_o_seq
+
+    if query_start_index == query_end_index:
+        return
+
     if IS_SPEC_DECODING:
         # The rolling of conv state:
         #
@@ -722,22 +741,6 @@ def _causal_conv1d_update_kernel(
         conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok  # [BLOCK_N]
         col4 = tl.load(conv_states_ptrs, mask_w, 0.0)
 
-    if IS_VARLEN:
-        query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64)
-        query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(
-            tl.int64)
-        # revise state_len and seqlen
-        state_len = state_len - (seqlen -
-                                 (query_end_index - query_start_index))
-        seqlen = query_end_index - query_start_index
-        x_offset = query_start_index * stride_x_token
-        o_offset = query_start_index * stride_o_token
-    else:
-        query_start_index = idx_seq * seqlen
-        query_end_index = query_start_index + seqlen
-        x_offset = idx_seq * stride_x_seq
-        o_offset = idx_seq * stride_o_seq
-
     # STEP 2: assume state_len > seqlen
     idx_tokens = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
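Read together, the two hunks hoist the IS_VARLEN offset computation above the spec-decoding conv-state rolling and add an early return for empty sequences, so a sequence whose draft tokens were all rolled back exits before any conv-state work. A minimal sketch of the hoisted logic, using a hypothetical host-side helper varlen_bounds (the real code runs inside the Triton kernel, not on the host):

def varlen_bounds(query_start_loc, idx_seq, state_len, padded_seqlen):
    """Hypothetical host-side rendering of the hoisted IS_VARLEN block.

    Returns (start, end, seqlen, state_len) for one sequence, or None
    when the sequence is empty (all draft tokens rolled back) and the
    kernel should return before touching conv state.
    """
    start = query_start_loc[idx_seq]
    end = query_start_loc[idx_seq + 1]
    actual = end - start
    if actual == 0:                 # the newly hoisted early-return path
        return None
    # revise state_len and seqlen, as the kernel comment says
    revised_state_len = state_len - (padded_seqlen - actual)
    return start, end, actual, revised_state_len

# Three sequences; the middle one lost all of its draft tokens.
qsl = [0, 4, 4, 7]
assert varlen_bounds(qsl, 1, state_len=8, padded_seqlen=4) is None
assert varlen_bounds(qsl, 2, state_len=8, padded_seqlen=4) == (4, 7, 3, 7)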

vllm/model_executor/models/qwen3_next.py
Lines changed: 1 addition & 3 deletions

@@ -417,9 +417,7 @@ def _forward(
         self_kv_cache = self.kv_cache[forward_context.virtual_engine]
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
-        num_actual_tokens = (attn_metadata.num_prefill_tokens +
-                             attn_metadata.num_decode_tokens +
-                             attn_metadata.num_spec_decode_tokens)
+        num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens
 
         # 1. Set up dimensions for reshapes later
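Why this one-liner matters: under full cuda graph capture the builder pads the token count (see gdn_attn.py below), so re-summing the raw counters in the model can come up short of the count the builder sized its tensors for. A sketch with made-up numbers to illustrate the mismatch:

# Hypothetical numbers: 2 spec-decode sequences, num_spec = 3, and one
# draft token rolled back by the grammar backend.
num_prefill_tokens, num_decode_tokens, num_spec_decode_tokens = 0, 0, 7

recomputed = (num_prefill_tokens + num_decode_tokens +
              num_spec_decode_tokens)    # 7, the old model-side value

padded = 8    # stand-in for what pad_for_cudagraph(7) could return

# The builder stores the padded value in metadata; the model must read
# it so both sides agree on how many token slots are live.
num_actual_tokens = padded
assert recomputed <= num_actual_tokens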

vllm/v1/attention/backends/gdn_attn.py
Lines changed: 12 additions & 13 deletions

@@ -31,6 +31,7 @@ class GDNAttentionMetadata:
     num_decode_tokens: int
     num_spec_decodes: int
     num_spec_decode_tokens: int
+    num_actual_tokens: int
 
     has_initial_state: Optional[torch.Tensor] = None
 
@@ -205,25 +206,22 @@ def build(  # type: ignore[override]
             has_initial_state = has_initial_state[~spec_sequence_masks]
         else:
             has_initial_state = None
+        num_actual_tokens = num_prefill_tokens + num_decode_tokens + \
+            num_spec_decode_tokens
 
         # prepare tensors for cudagraph
         #
         # With speculative decoding, the xgrammar backend may rollback tokens
         # and causing some sequences has less draft tokens than self.num_spec.
         #
-        # During cudagraph capture, the GDN backends requires an assumption
-        # that num_spec_decode_tokens == num_spec_decodes * (self.num_spec + 1).
-        #
-        # More than one such sequences may break the assumption (less tokens),
-        # causing incompatible inputs for cuda graph replay.
+        # In above cases, the max possible batch size for n tokens, can be
+        # min(n, cudagraph_max_bs).
         if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0
                 and num_spec_decodes <= self.decode_cudagraph_max_bs
-                and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
-                and num_spec_decode_tokens == num_spec_decodes *
-                (self.num_spec + 1)):
-            num_total_tokens = self.vllm_config.pad_for_cudagraph(
+                and num_spec_decode_tokens <= self.decode_cudagraph_max_bs):
+            num_actual_tokens = self.vllm_config.pad_for_cudagraph(
                 m.num_actual_tokens)
-            batch_size = num_total_tokens // (self.num_spec + 1)
+            batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)
 
             self.spec_state_indices_tensor[:num_spec_decodes].copy_(
                 spec_state_indices_tensor, non_blocking=True)
@@ -239,7 +237,7 @@ def build(  # type: ignore[override]
             assert spec_token_masks is not None
             self.spec_token_masks[:spec_token_masks.size(0)].copy_(
                 spec_token_masks, non_blocking=True)
-            spec_token_masks = self.spec_token_masks[:m.num_actual_tokens]
+            spec_token_masks = self.spec_token_masks[:num_actual_tokens]
             spec_token_masks[spec_token_masks.size(0):].fill_(False)
 
             self.spec_query_start_loc[:num_spec_decodes + 1].copy_(
@@ -258,9 +256,9 @@ def build(  # type: ignore[override]
         if (self.use_full_cuda_graph and num_prefills == 0
                 and num_spec_decodes == 0
                 and num_decodes <= self.decode_cudagraph_max_bs):
-            num_total_tokens = self.vllm_config.pad_for_cudagraph(
+            num_actual_tokens = self.vllm_config.pad_for_cudagraph(
                 m.num_actual_tokens)
-            batch_size = num_total_tokens
+            batch_size = num_actual_tokens
 
             self.non_spec_state_indices_tensor[:num_decodes].copy_(
                 non_spec_state_indices_tensor, non_blocking=True)
@@ -284,6 +282,7 @@ def build(  # type: ignore[override]
             num_decode_tokens=num_decode_tokens,
             num_spec_decodes=num_spec_decodes,
             num_spec_decode_tokens=num_spec_decode_tokens,
+            num_actual_tokens=num_actual_tokens,
             has_initial_state=has_initial_state,
             spec_query_start_loc=spec_query_start_loc,
             non_spec_query_start_loc=non_spec_query_start_loc,
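The sizing change in the spec-decode cuda graph branch, restated as a standalone sketch. Here pad_for_cudagraph is mocked (the real one lives on the vllm config): the old path divided the padded token count by num_spec + 1, which assumed every sequence carried exactly that many tokens; the new path only needs the bound that n tokens can come from at most n sequences, capped at the captured maximum batch size.

def pad_for_cudagraph(n: int) -> int:
    """Mock of the config's pad_for_cudagraph: round up to a captured size."""
    sizes = [1, 2, 4, 8, 16, 32]
    return next(s for s in sizes if s >= n)

def spec_decode_batch_size(num_spec_decodes, num_spec_decode_tokens,
                           num_spec, decode_cudagraph_max_bs):
    """Sketch of the revised sizing; returns None when cudagraph is skipped."""
    # Old extra condition (now removed): every sequence had to carry exactly
    # num_spec + 1 tokens, or the batch fell back to eager execution:
    #   num_spec_decode_tokens == num_spec_decodes * (num_spec + 1)
    if num_spec_decode_tokens > decode_cudagraph_max_bs:
        return None
    padded = pad_for_cudagraph(num_spec_decode_tokens)
    # New bound: n padded tokens can belong to at most n sequences.
    return min(decode_cudagraph_max_bs, padded)

# 2 sequences, num_spec = 3, one draft rolled back -> 7 tokens, not 2 * 4.
# Old: the alignment check fails and cuda graph replay is skipped.
# New: the batch still replays, sized min(16, pad_for_cudagraph(7)) == 8.
assert spec_decode_batch_size(2, 7, 3, decode_cudagraph_max_bs=16) == 8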
