
Commit aacd8f5

sighingnow authored and xuebwang-amd committed
[Bugfix][Qwen3-Next] fixes the varlen issue in qwen3-next's MTP implementation. (vllm-project#24957)
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent d478744 commit aacd8f5

File tree

3 files changed: +139 additions, -34 deletions

vllm/model_executor/layers/mamba/ops/causal_conv1d.py

Lines changed: 116 additions & 16 deletions
@@ -626,6 +626,7 @@ def _causal_conv1d_update_kernel(
     cache_seqlens_ptr,  # circular buffer
     conv_state_indices_ptr,
     num_accepted_tokens_ptr,
+    query_start_loc_ptr,  # (batch + 1)
     o_ptr,  # (batch, dim, seqlen)
     # Matrix dimensions
     batch: int,
@@ -652,6 +653,7 @@ def _causal_conv1d_update_kernel(
     HAS_BIAS: tl.constexpr,
     KERNEL_WIDTH: tl.constexpr,
     SILU_ACTIVATION: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
     IS_CONTINUOUS_BATCHING: tl.constexpr,
     IS_SPEC_DECODING: tl.constexpr,
     NP2_STATELEN: tl.constexpr,
@@ -678,6 +680,25 @@ def _causal_conv1d_update_kernel(
         # not processing as this is not the actual sequence
         return

+    if IS_VARLEN:
+        query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64)
+        query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(
+            tl.int64)
+        # revise state_len and seqlen
+        state_len = state_len - (seqlen -
+                                 (query_end_index - query_start_index))
+        seqlen = query_end_index - query_start_index
+        x_offset = query_start_index * stride_x_token
+        o_offset = query_start_index * stride_o_token
+    else:
+        query_start_index = idx_seq * seqlen
+        query_end_index = query_start_index + seqlen
+        x_offset = idx_seq * stride_x_seq
+        o_offset = idx_seq * stride_o_seq
+
+    if query_start_index == query_end_index:
+        return
+
     if IS_SPEC_DECODING:
         # The rolling of conv state:
         #
@@ -692,8 +713,8 @@ def _causal_conv1d_update_kernel(
         # - accept 1 tokens: [history2, ..., historyM, draft1]
         # - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
         # - and so on.
-        conv_state_token_offset = (tl.load(num_accepted_tokens_ptr + idx_seq) -
-                                   1)
+        conv_state_token_offset = (
+            tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1)
     else:
         conv_state_token_offset = 0

@@ -713,9 +734,12 @@ def _causal_conv1d_update_kernel(
     if KERNEL_WIDTH >= 4:
         conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok  # [BLOCK_N]
         col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
-    if KERNEL_WIDTH == 5:
+    if KERNEL_WIDTH >= 5:
         conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok  # [BLOCK_N]
         col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
+    if KERNEL_WIDTH >= 6:
+        conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok  # [BLOCK_N]
+        col4 = tl.load(conv_states_ptrs, mask_w, 0.0)

     # STEP 2: assume state_len > seqlen
     idx_tokens = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
@@ -735,8 +759,7 @@ def _causal_conv1d_update_kernel(
     conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)

     VAL = state_len - seqlen
-    x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim
-                                                 )  # [BLOCK_N]
+    x_base = x_ptr + x_offset + (idx_feats * stride_x_dim)  # [BLOCK_N]

     x_ptrs = x_base[None, :] + (
         (idx_tokens - VAL) * stride_x_token)[:, None]  # [BLOCK_M, BLOCK_N]
@@ -782,12 +805,18 @@ def _causal_conv1d_update_kernel(
     if KERNEL_WIDTH >= 4:
         w_ptrs = w_base + (3 * stride_w_width)  # [BLOCK_N] tensor
         w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
+    if KERNEL_WIDTH >= 5:
+        w_ptrs = w_base + (4 * stride_w_width)  # [BLOCK_N] tensor
+        w_col4 = tl.load(w_ptrs, mask_w, other=0.0)
+    if KERNEL_WIDTH >= 6:
+        w_ptrs = w_base + (5 * stride_w_width)  # [BLOCK_N] tensor
+        w_col5 = tl.load(w_ptrs, mask_w, other=0.0)

     x_base_1d = x_base  # starting of chunk [BLOCK_N]
     mask_x_1d = idx_feats < dim

     # STEP 5: compute each token
-    for idx_token in tl.static_range(seqlen):
+    for idx_token in tl.range(seqlen):
         acc = acc_preload

         matrix_w = w_col0
@@ -817,6 +846,37 @@ def _causal_conv1d_update_kernel(
                     matrix_w = w_col3
                     x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
                     matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+            elif KERNEL_WIDTH == 5:
+                if j == 1:
+                    matrix_w = w_col1
+                    matrix_x = col1
+                elif j == 2:
+                    matrix_w = w_col2
+                    matrix_x = col2
+                elif j == 3:
+                    matrix_w = w_col3
+                    matrix_x = col3
+                elif j == 4:
+                    matrix_w = w_col4
+                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+            elif KERNEL_WIDTH == 6:
+                if j == 1:
+                    matrix_w = w_col1
+                    matrix_x = col1
+                elif j == 2:
+                    matrix_w = w_col2
+                    matrix_x = col2
+                elif j == 3:
+                    matrix_w = w_col3
+                    matrix_x = col3
+                elif j == 4:
+                    matrix_w = w_col4
+                    matrix_x = col4
+                elif j == 5:
+                    matrix_w = w_col5
+                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)

             acc += matrix_x * matrix_w  # [BLOCK_N]

@@ -829,14 +889,24 @@ def _causal_conv1d_update_kernel(
             col0 = col1
             col1 = col2
             col2 = matrix_x
+        elif KERNEL_WIDTH == 5:
+            col0 = col1
+            col1 = col2
+            col2 = col3
+            col3 = matrix_x
+        elif KERNEL_WIDTH == 6:
+            col0 = col1
+            col1 = col2
+            col2 = col3
+            col3 = col4
+            col4 = matrix_x

         if SILU_ACTIVATION:
             acc = acc / (1 + tl.exp(-acc))
         mask_1d = (idx_token < seqlen) & (idx_feats < dim
                                           )  # token-index  # feature-index
-        o_ptrs = o_ptr + (
-            idx_seq) * stride_o_seq + idx_token * stride_o_token + (
-                idx_feats * stride_o_dim)
+        o_ptrs = o_ptr + o_offset + idx_token * stride_o_token + (idx_feats *
+                                                                  stride_o_dim)

         tl.store(o_ptrs, acc, mask=mask_1d)
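For reference, the indexing that the new IS_VARLEN branch performs can be expressed in plain Python. This is an illustrative sketch of the logic only, not the Triton kernel itself; the helper name and the example numbers are made up for the walkthrough.

import torch


def varlen_slice(query_start_loc: torch.Tensor, idx_seq: int,
                 max_query_len: int, state_len: int):
    # query_start_loc holds cumulative token offsets, e.g. [0, 3, 4, 7]
    # for three sequences with 3, 1 and 3 tokens in the flattened batch.
    start = int(query_start_loc[idx_seq])
    end = int(query_start_loc[idx_seq + 1])
    # The effective sequence length may be shorter than the launch-time
    # seqlen (max_query_len), so the conv-state window shrinks accordingly.
    seqlen = end - start
    effective_state_len = state_len - (max_query_len - seqlen)
    # x and out are flat (num_tokens, dim); this sequence owns rows
    # start:end, which is what x_offset / o_offset point at in the kernel.
    return start, seqlen, effective_state_len


qsl = torch.tensor([0, 3, 4, 7], dtype=torch.int32)
print(varlen_slice(qsl, idx_seq=1, max_query_len=3, state_len=4))
# (3, 1, 2): sequence 1 owns only token 3 and keeps 2 state slots.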

@@ -850,14 +920,18 @@ def causal_conv1d_update(
     cache_seqlens: Optional[torch.Tensor] = None,
     conv_state_indices: Optional[torch.Tensor] = None,
     num_accepted_tokens: Optional[torch.Tensor] = None,
+    query_start_loc: Optional[torch.Tensor] = None,
+    max_query_len: int = -1,
     pad_slot_id: int = PAD_SLOT_ID,
     metadata=None,
     validate_data=False,
 ):
     """
-    x: (batch, dim) or (batch, dim, seqlen)
+    x: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim)
         [shape=2: single token prediction]
         [shape=3: single or multiple tokens prediction]
+        [shape=2 with num_tokens: continuous batching, where num_tokens is the
+         total tokens of all sequences in that batch]
     conv_state: (..., dim, state_len), where state_len >= width - 1
     weight: (dim, width)
     bias: (dim,)
@@ -870,13 +944,24 @@ def causal_conv1d_update(
         If not None, the conv_state is a larger tensor along the batch dim,
         and we are selecting the batch coords specified by conv_state_indices.
         Useful for a continuous batching scenario.
+    num_accepted_tokens: (batch,), dtype int32
+        If not None, it indicates the number of accepted tokens for each
+        sequence in the batch.
+        This is used in speculative decoding, where the conv_state is updated
+        in a sliding window manner.
+    query_start_loc: (batch + 1,) int32
+        If not None, the inputs is given in a varlen fashion and this indicates
+        the starting index of each sequence in the batch.
+    max_query_len: int
+        If query_start_loc is not None, this indicates the maximum query
+        length in the batch.
     pad_slot_id: int
         if cache_indices is passed, lets the kernel identify padded
         entries that will not be processed,
         for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
         in this case, the kernel will not process entries at
         indices 0 and 3
-    out: (batch, dim) or (batch, dim, seqlen)
+    out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
     """
     if validate_data:
         assert cache_seqlens is None  # not implemented yet - ok for vLLM
@@ -886,11 +971,17 @@ def causal_conv1d_update(
         activation = "silu" if activation is True else None
     elif activation is not None:
         assert activation in ["silu", "swish"]
-    unsqueeze = x.dim() == 2
+    unsqueeze = query_start_loc is None and x.dim() == 2
     if unsqueeze:
         # make it (batch, dim, seqlen) with seqlen == 1
         x = x.unsqueeze(-1)
-    batch, dim, seqlen = x.shape
+    if query_start_loc is None:
+        batch, dim, seqlen = x.shape
+    else:
+        assert conv_state_indices is not None
+        batch = conv_state_indices.size(0)
+        dim = x.size(1)
+        seqlen = max_query_len
     _, width = weight.shape
     # conv_state: (..., dim, state_len), where state_len >= width - 1
     num_cache_lines, _, state_len = conv_state.size()
@@ -916,10 +1007,17 @@ def causal_conv1d_update(
     out = x
     stride_w_dim, stride_w_width = weight.stride()

-    stride_x_seq, stride_x_dim, stride_x_token = x.stride(
-    )  # X (batch, dim, seqlen)
+    if query_start_loc is None:
+        # X (batch, dim, seqlen)
+        stride_x_seq, stride_x_dim, stride_x_token = x.stride()
+        stride_o_seq, stride_o_dim, stride_o_token = out.stride()
+    else:
+        # X (dim, cu_seqlen)
+        stride_x_token, stride_x_dim = x.stride()
+        stride_x_seq = 0
+        stride_o_token, stride_o_dim = out.stride()
+        stride_o_seq = 0

-    stride_o_seq, stride_o_dim, stride_o_token = out.stride()
     stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
     )
     stride_state_indices = conv_state_indices.stride(
@@ -945,6 +1043,7 @@ def grid(META):
         cache_seqlens,
         conv_state_indices,
         num_accepted_tokens,
+        query_start_loc,
         out,
         # Matrix dimensions
         batch,
@@ -971,6 +1070,7 @@ def grid(META):
         HAS_BIAS=bias is not None,
         KERNEL_WIDTH=width,
         SILU_ACTIVATION=activation in ["silu", "swish"],
+        IS_VARLEN=query_start_loc is not None,
         IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
         IS_SPEC_DECODING=num_accepted_tokens is not None,
         NP2_STATELEN=np2_statelen,
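A hedged usage sketch of the new varlen entry point follows. The leading (x, conv_state, weight, bias, activation) parameters and every tensor size below are assumptions for illustration; only the query_start_loc / max_query_len keywords are what this commit adds, and a CUDA device is required to actually launch the Triton kernel.

import torch
from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update

dim, width = 64, 4                      # illustrative sizes
state_len = 6                           # >= width - 1, with room for spec rolling
num_cache_lines = 8

# Varlen input: tokens of all sequences flattened to (num_tokens, dim).
query_start_loc = torch.tensor([0, 3, 4, 7], dtype=torch.int32, device="cuda")
x = torch.randn(int(query_start_loc[-1]), dim, device="cuda")

conv_state = torch.zeros(num_cache_lines, dim, state_len, device="cuda")
weight = torch.randn(dim, width, device="cuda")
bias = torch.randn(dim, device="cuda")

out = causal_conv1d_update(
    x,
    conv_state,
    weight,
    bias=bias,
    activation="silu",
    # one conv_state cache line per sequence in this batch
    conv_state_indices=torch.tensor([0, 1, 2], dtype=torch.int32, device="cuda"),
    # accepted tokens per sequence (speculative-decoding window rolling)
    num_accepted_tokens=torch.tensor([3, 1, 3], dtype=torch.int32, device="cuda"),
    query_start_loc=query_start_loc,    # enables the IS_VARLEN kernel path
    max_query_len=3,                    # longest per-sequence query in the batch
)
# out keeps the same flat (num_tokens, dim) shape as x.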

vllm/model_executor/models/qwen3_next.py

Lines changed: 3 additions & 7 deletions
@@ -417,9 +417,7 @@ def _forward(
         self_kv_cache = self.kv_cache[forward_context.virtual_engine]
         conv_state = self_kv_cache[0].transpose(-1, -2)
         ssm_state = self_kv_cache[1]
-        num_actual_tokens = (attn_metadata.num_prefill_tokens +
-                             attn_metadata.num_decode_tokens +
-                             attn_metadata.num_spec_decode_tokens)
+        num_actual_tokens = attn_metadata.num_actual_tokens
         num_accepted_tokens = attn_metadata.num_accepted_tokens

         # 1. Set up dimensions for reshapes later
@@ -458,9 +456,6 @@ def _forward(

         # 2.1: process the mutli-query part
         if spec_sequence_masks is not None:
-            mixed_qkv_spec = mixed_qkv_spec.view(
-                attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
-            mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
             mixed_qkv_spec = causal_conv1d_update(
                 mixed_qkv_spec,
                 conv_state,
@@ -470,9 +465,10 @@ def _forward(
                 conv_state_indices=spec_state_indices_tensor[:, 0]
                 [:attn_metadata.num_spec_decodes],
                 num_accepted_tokens=num_accepted_tokens,
+                query_start_loc=spec_query_start_loc,
+                max_query_len=spec_state_indices_tensor.size(-1),
                 validate_data=False,
             )
-            mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d')

         # 2.2: process the remaining part
         if attn_metadata.num_prefills > 0:
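For context, the deleted view/rearrange pair packed the spec tokens into a dense (num_spec_decodes, num_spec + 1, dim) block, which silently assumed every sequence carries exactly num_spec + 1 tokens. A small shape-level sketch of the before/after layouts, with all sizes chosen purely for illustration:

import torch
from einops import rearrange

num_spec_decodes, num_spec, d = 3, 2, 8

# Old path: only valid when every sequence really has num_spec + 1 tokens.
dense_tokens = torch.randn(num_spec_decodes * (num_spec + 1), d)
dense = rearrange(dense_tokens.view(num_spec_decodes, num_spec + 1, d),
                  'b l d -> b d l')                      # (b, d, l)

# New varlen path: sequences may have fewer tokens (e.g. 3, 1 and 3 after a
# rollback), so the flat (num_tokens, d) tensor is passed straight through
# and spec_query_start_loc tells the kernel where each sequence starts.
varlen_tokens = torch.randn(7, d)
spec_query_start_loc = torch.tensor([0, 3, 4, 7], dtype=torch.int32)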

vllm/v1/attention/backends/gdn_attn.py

Lines changed: 20 additions & 11 deletions
@@ -31,6 +31,7 @@ class GDNAttentionMetadata:
     num_decode_tokens: int
     num_spec_decodes: int
     num_spec_decode_tokens: int
+    num_actual_tokens: int

     has_initial_state: Optional[torch.Tensor] = None

@@ -74,8 +75,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         self.use_full_cuda_graph = \
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         self.decode_cudagraph_max_bs = min(
-            self.vllm_config.scheduler_config.max_num_seqs,
-            self.compilation_config.max_capture_size)
+            self.vllm_config.scheduler_config.max_num_seqs *
+            (self.num_spec + 1), self.compilation_config.max_capture_size)

         self.spec_state_indices_tensor = torch.empty(
             (self.decode_cudagraph_max_bs, self.num_spec + 1),
@@ -194,9 +195,8 @@ def build( # type: ignore[override]
                          dim=0,
                          out=non_spec_query_start_loc[1:])

-            num_spec_decode_tokens = min(
-                num_spec_decodes * (self.num_spec + 1),
-                spec_token_masks.size(0))
+            num_spec_decode_tokens = (query_lens.sum().item() -
+                                      num_prefill_tokens - num_decode_tokens)
             assert num_accepted_tokens is not None
             num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]

@@ -206,14 +206,22 @@ def build( # type: ignore[override]
                 has_initial_state = has_initial_state[~spec_sequence_masks]
             else:
                 has_initial_state = None
+        num_actual_tokens = num_prefill_tokens + num_decode_tokens + \
+            num_spec_decode_tokens

         # prepare tensors for cudagraph
+        #
+        # With speculative decoding, the xgrammar backend may rollback tokens
+        # and causing some sequences has less draft tokens than self.num_spec.
+        #
+        # In above cases, the max possible batch size for n tokens, can be
+        # min(n, cudagraph_max_bs).
         if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0
                 and num_spec_decodes <= self.decode_cudagraph_max_bs
-                and m.num_actual_tokens <= self.decode_cudagraph_max_bs):
-            num_total_tokens = self.vllm_config.pad_for_cudagraph(
+                and num_spec_decode_tokens <= self.decode_cudagraph_max_bs):
+            num_actual_tokens = self.vllm_config.pad_for_cudagraph(
                 m.num_actual_tokens)
-            batch_size = num_total_tokens // (self.num_spec + 1)
+            batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)

             self.spec_state_indices_tensor[:num_spec_decodes].copy_(
                 spec_state_indices_tensor, non_blocking=True)
@@ -229,7 +237,7 @@ def build( # type: ignore[override]
             assert spec_token_masks is not None
             self.spec_token_masks[:spec_token_masks.size(0)].copy_(
                 spec_token_masks, non_blocking=True)
-            spec_token_masks = self.spec_token_masks[:m.num_actual_tokens]
+            spec_token_masks = self.spec_token_masks[:num_actual_tokens]
             spec_token_masks[spec_token_masks.size(0):].fill_(False)

             self.spec_query_start_loc[:num_spec_decodes + 1].copy_(
@@ -248,9 +256,9 @@ def build( # type: ignore[override]
         if (self.use_full_cuda_graph and num_prefills == 0
                 and num_spec_decodes == 0
                 and num_decodes <= self.decode_cudagraph_max_bs):
-            num_total_tokens = self.vllm_config.pad_for_cudagraph(
+            num_actual_tokens = self.vllm_config.pad_for_cudagraph(
                 m.num_actual_tokens)
-            batch_size = num_total_tokens
+            batch_size = num_actual_tokens

             self.non_spec_state_indices_tensor[:num_decodes].copy_(
                 non_spec_state_indices_tensor, non_blocking=True)
@@ -274,6 +282,7 @@ def build( # type: ignore[override]
             num_decode_tokens=num_decode_tokens,
             num_spec_decodes=num_spec_decodes,
             num_spec_decode_tokens=num_spec_decode_tokens,
+            num_actual_tokens=num_actual_tokens,
             has_initial_state=has_initial_state,
             spec_query_start_loc=spec_query_start_loc,
             non_spec_query_start_loc=non_spec_query_start_loc,
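A small worked example of the new token accounting and cudagraph sizing. All numbers are assumptions chosen to show the rollback case the added comment describes, and the pad_for_cudagraph result is likewise assumed.

# Say num_spec = 2, so each spec-decoding sequence may carry up to
# num_spec + 1 = 3 tokens, but a rollback leaves one with only 1 token.
num_spec = 2
num_prefill_tokens = num_decode_tokens = 0
query_lens = [3, 1, 3]                     # per-sequence token counts
num_spec_decodes = len(query_lens)

# New: derived from the real query lengths instead of assuming each spec
# sequence contributes num_spec + 1 tokens (which would give 9 here).
num_spec_decode_tokens = sum(query_lens) - num_prefill_tokens - num_decode_tokens
num_actual_tokens = (num_prefill_tokens + num_decode_tokens +
                     num_spec_decode_tokens)

# Cudagraph sizing: decode_cudagraph_max_bs now scales with (num_spec + 1),
# and batch_size is bounded by the padded token count rather than divided
# by (num_spec + 1), which breaks when sequences have unequal lengths.
max_num_seqs, max_capture_size = 256, 512  # assumed config values
decode_cudagraph_max_bs = min(max_num_seqs * (num_spec + 1), max_capture_size)
padded_tokens = 8                          # assumed pad_for_cudagraph(7)
batch_size = min(decode_cudagraph_max_bs, padded_tokens)
print(num_spec_decode_tokens, num_actual_tokens, batch_size)   # 7 7 8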
