
Commit 61bb8b8

[Bugfix][Qwen3-Next] fixes the varlen issue in qwen3-next's MTP implementation.
Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
1 parent a0d8b97 commit 61bb8b8

File tree

3 files changed: +131 -26 lines changed

vllm/model_executor/layers/mamba/ops/causal_conv1d.py

Lines changed: 113 additions & 16 deletions
@@ -626,6 +626,7 @@ def _causal_conv1d_update_kernel(
     cache_seqlens_ptr,  # circular buffer
     conv_state_indices_ptr,
     num_accepted_tokens_ptr,
+    query_start_loc_ptr,  # (batch + 1)
     o_ptr,  # (batch, dim, seqlen)
     # Matrix dimensions
     batch: int,
@@ -652,6 +653,7 @@ def _causal_conv1d_update_kernel(
     HAS_BIAS: tl.constexpr,
     KERNEL_WIDTH: tl.constexpr,
     SILU_ACTIVATION: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
     IS_CONTINUOUS_BATCHING: tl.constexpr,
     IS_SPEC_DECODING: tl.constexpr,
     NP2_STATELEN: tl.constexpr,
@@ -692,8 +694,8 @@
         # - accept 1 tokens: [history2, ..., historyM, draft1]
         # - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
         # - and so on.
-        conv_state_token_offset = (tl.load(num_accepted_tokens_ptr + idx_seq) -
-                                   1)
+        conv_state_token_offset = (
+            tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1)
     else:
         conv_state_token_offset = 0

@@ -713,9 +715,28 @@
     if KERNEL_WIDTH >= 4:
         conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok  # [BLOCK_N]
         col2 = tl.load(conv_states_ptrs, mask_w, 0.0)
-    if KERNEL_WIDTH == 5:
+    if KERNEL_WIDTH >= 5:
         conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok  # [BLOCK_N]
         col3 = tl.load(conv_states_ptrs, mask_w, 0.0)
+    if KERNEL_WIDTH >= 6:
+        conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok  # [BLOCK_N]
+        col4 = tl.load(conv_states_ptrs, mask_w, 0.0)
+
+    if IS_VARLEN:
+        query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64)
+        query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(
+            tl.int64)
+        # revise state_len and seqlen
+        state_len = state_len - (seqlen -
+                                 (query_end_index - query_start_index))
+        seqlen = query_end_index - query_start_index
+        x_offset = query_start_index * stride_x_token
+        o_offset = query_start_index * stride_o_token
+    else:
+        query_start_index = idx_seq * seqlen
+        query_end_index = query_start_index + seqlen
+        x_offset = idx_seq * stride_x_seq
+        o_offset = idx_seq * stride_o_seq

     # STEP 2: assume state_len > seqlen
     idx_tokens = tl.arange(0, NP2_STATELEN)  # [BLOCK_M]
@@ -735,8 +756,7 @@
     conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0)

     VAL = state_len - seqlen
-    x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim
-                                                 )  # [BLOCK_N]
+    x_base = x_ptr + x_offset + (idx_feats * stride_x_dim)  # [BLOCK_N]

     x_ptrs = x_base[None, :] + (
         (idx_tokens - VAL) * stride_x_token)[:, None]  # [BLOCK_M, BLOCK_N]
@@ -782,12 +802,18 @@
     if KERNEL_WIDTH >= 4:
         w_ptrs = w_base + (3 * stride_w_width)  # [BLOCK_N] tensor
         w_col3 = tl.load(w_ptrs, mask_w, other=0.0)
+    if KERNEL_WIDTH >= 5:
+        w_ptrs = w_base + (4 * stride_w_width)  # [BLOCK_N] tensor
+        w_col4 = tl.load(w_ptrs, mask_w, other=0.0)
+    if KERNEL_WIDTH >= 6:
+        w_ptrs = w_base + (5 * stride_w_width)  # [BLOCK_N] tensor
+        w_col5 = tl.load(w_ptrs, mask_w, other=0.0)

     x_base_1d = x_base  # starting of chunk [BLOCK_N]
     mask_x_1d = idx_feats < dim

     # STEP 5: compute each token
-    for idx_token in tl.static_range(seqlen):
+    for idx_token in tl.range(seqlen):
         acc = acc_preload

         matrix_w = w_col0
@@ -817,6 +843,37 @@
                     matrix_w = w_col3
                     x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
                     matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+            elif KERNEL_WIDTH == 5:
+                if j == 1:
+                    matrix_w = w_col1
+                    matrix_x = col1
+                elif j == 2:
+                    matrix_w = w_col2
+                    matrix_x = col2
+                elif j == 3:
+                    matrix_w = w_col3
+                    matrix_x = col3
+                elif j == 4:
+                    matrix_w = w_col4
+                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)
+            elif KERNEL_WIDTH == 6:
+                if j == 1:
+                    matrix_w = w_col1
+                    matrix_x = col1
+                elif j == 2:
+                    matrix_w = w_col2
+                    matrix_x = col2
+                elif j == 3:
+                    matrix_w = w_col3
+                    matrix_x = col3
+                elif j == 4:
+                    matrix_w = w_col4
+                    matrix_x = col4
+                elif j == 5:
+                    matrix_w = w_col5
+                    x_ptrs_1d = x_base_1d + idx_token * stride_x_token  # [BLOCK_N]
+                    matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d)

             acc += matrix_x * matrix_w  # [BLOCK_N]

@@ -829,14 +886,24 @@
             col0 = col1
             col1 = col2
             col2 = matrix_x
+        elif KERNEL_WIDTH == 5:
+            col0 = col1
+            col1 = col2
+            col2 = col3
+            col3 = matrix_x
+        elif KERNEL_WIDTH == 6:
+            col0 = col1
+            col1 = col2
+            col2 = col3
+            col3 = col4
+            col4 = matrix_x

         if SILU_ACTIVATION:
             acc = acc / (1 + tl.exp(-acc))
         mask_1d = (idx_token < seqlen) & (idx_feats < dim
                                           )  # token-index # feature-index
-        o_ptrs = o_ptr + (
-            idx_seq) * stride_o_seq + idx_token * stride_o_token + (
-                idx_feats * stride_o_dim)
+        o_ptrs = o_ptr + o_offset + idx_token * stride_o_token + (idx_feats *
+                                                                  stride_o_dim)

         tl.store(o_ptrs, acc, mask=mask_1d)

@@ -850,14 +917,18 @@ def causal_conv1d_update(
     cache_seqlens: Optional[torch.Tensor] = None,
     conv_state_indices: Optional[torch.Tensor] = None,
     num_accepted_tokens: Optional[torch.Tensor] = None,
+    query_start_loc: Optional[torch.Tensor] = None,
+    max_query_len: int = -1,
     pad_slot_id: int = PAD_SLOT_ID,
     metadata=None,
     validate_data=False,
 ):
     """
-    x: (batch, dim) or (batch, dim, seqlen)
+    x: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim)
         [shape=2: single token prediction]
         [shape=3: single or multiple tokens prediction]
+        [shape=2 with num_tokens: continuous batching, where num_tokens is the
+         total tokens of all sequences in that batch]
     conv_state: (..., dim, state_len), where state_len >= width - 1
     weight: (dim, width)
     bias: (dim,)
@@ -870,13 +941,24 @@ def causal_conv1d_update(
         If not None, the conv_state is a larger tensor along the batch dim,
         and we are selecting the batch coords specified by conv_state_indices.
         Useful for a continuous batching scenario.
+    num_accepted_tokens: (batch,), dtype int32
+        If not None, it indicates the number of accepted tokens for each
+        sequence in the batch.
+        This is used in speculative decoding, where the conv_state is updated
+        in a sliding window manner.
+    query_start_loc: (batch + 1,) int32
+        If not None, the inputs is given in a varlen fashion and this indicates
+        the starting index of each sequence in the batch.
+    max_query_len: int
+        If query_start_loc is not None, this indicates the maximum query
+        length in the batch.
     pad_slot_id: int
         if cache_indices is passed, lets the kernel identify padded
         entries that will not be processed,
         for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
         in this case, the kernel will not process entries at
         indices 0 and 3
-    out: (batch, dim) or (batch, dim, seqlen)
+    out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
     """
     if validate_data:
         assert cache_seqlens is None  # not implemented yet - ok for vLLM
@@ -886,11 +968,17 @@ def causal_conv1d_update(
         activation = "silu" if activation is True else None
     elif activation is not None:
         assert activation in ["silu", "swish"]
-    unsqueeze = x.dim() == 2
+    unsqueeze = query_start_loc is None and x.dim() == 2
     if unsqueeze:
         # make it (batch, dim, seqlen) with seqlen == 1
         x = x.unsqueeze(-1)
-    batch, dim, seqlen = x.shape
+    if query_start_loc is None:
+        batch, dim, seqlen = x.shape
+    else:
+        assert conv_state_indices is not None
+        batch = conv_state_indices.size(0)
+        dim = x.size(1)
+        seqlen = max_query_len
     _, width = weight.shape
     # conv_state: (..., dim, state_len), where state_len >= width - 1
     num_cache_lines, _, state_len = conv_state.size()
@@ -916,10 +1004,17 @@ def causal_conv1d_update(
     out = x
     stride_w_dim, stride_w_width = weight.stride()

-    stride_x_seq, stride_x_dim, stride_x_token = x.stride(
-    )  # X (batch, dim, seqlen)
+    if query_start_loc is None:
+        # X (batch, dim, seqlen)
+        stride_x_seq, stride_x_dim, stride_x_token = x.stride()
+        stride_o_seq, stride_o_dim, stride_o_token = out.stride()
+    else:
+        # X (dim, cu_seqlen)
+        stride_x_token, stride_x_dim = x.stride()
+        stride_x_seq = 0
+        stride_o_token, stride_o_dim = out.stride()
+        stride_o_seq = 0

-    stride_o_seq, stride_o_dim, stride_o_token = out.stride()
     stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride(
     )
     stride_state_indices = conv_state_indices.stride(
@@ -945,6 +1040,7 @@ def grid(META):
         cache_seqlens,
         conv_state_indices,
         num_accepted_tokens,
+        query_start_loc,
         out,
         # Matrix dimensions
         batch,
@@ -971,6 +1067,7 @@ def grid(META):
         HAS_BIAS=bias is not None,
         KERNEL_WIDTH=width,
         SILU_ACTIVATION=activation in ["silu", "swish"],
+        IS_VARLEN=query_start_loc is not None,
         IS_CONTINUOUS_BATCHING=conv_state_indices is not None,
         IS_SPEC_DECODING=num_accepted_tokens is not None,
         NP2_STATELEN=np2_statelen,
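
The docstring above defines the new varlen convention, so here is a minimal usage sketch under stated assumptions: it presumes the positional signature causal_conv1d_update(x, conv_state, weight, bias, ...), uses illustrative shapes and cache indices, needs a CUDA device with Triton, and is not part of this commit.

# Hedged sketch (illustrative values, not from the commit).
import torch

from vllm.model_executor.layers.mamba.ops.causal_conv1d import causal_conv1d_update

dim, width, state_len = 64, 4, 3        # state_len >= width - 1
num_cache_lines = 8                     # size of the conv_state pool
seq_lens = [3, 1, 2]                    # per-sequence token counts (varlen)
num_tokens = sum(seq_lens)

x = torch.randn(num_tokens, dim, device="cuda")        # packed (num_tokens, dim)
conv_state = torch.zeros(num_cache_lines, dim, state_len, device="cuda")
weight = torch.randn(dim, width, device="cuda")
bias = torch.randn(dim, device="cuda")

# (batch + 1,) cumulative token offsets, as described in the docstring,
# plus one conv_state cache line per sequence.
query_start_loc = torch.tensor([0, 3, 4, 6], dtype=torch.int32, device="cuda")
conv_state_indices = torch.tensor([0, 2, 5], dtype=torch.int32, device="cuda")

out = causal_conv1d_update(
    x,
    conv_state,
    weight,
    bias,
    activation="silu",
    conv_state_indices=conv_state_indices,
    query_start_loc=query_start_loc,
    max_query_len=max(seq_lens),
)
assert out.shape == x.shape             # (num_tokens, dim), same layout as x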

vllm/model_executor/models/qwen3_next.py

Lines changed: 2 additions & 4 deletions
@@ -458,9 +458,6 @@ def _forward(

         # 2.1: process the mutli-query part
         if spec_sequence_masks is not None:
-            mixed_qkv_spec = mixed_qkv_spec.view(
-                attn_metadata.num_spec_decodes, -1, mixed_qkv_spec.size(-1))
-            mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b l d -> b d l')
             mixed_qkv_spec = causal_conv1d_update(
                 mixed_qkv_spec,
                 conv_state,
@@ -470,9 +467,10 @@ def _forward(
                 conv_state_indices=spec_state_indices_tensor[:, 0]
                 [:attn_metadata.num_spec_decodes],
                 num_accepted_tokens=num_accepted_tokens,
+                query_start_loc=spec_query_start_loc,
+                max_query_len=spec_state_indices_tensor.size(-1),
                 validate_data=False,
             )
-            mixed_qkv_spec = rearrange(mixed_qkv_spec, 'b d l -> (b l) d')

         # 2.2: process the remaining part
         if attn_metadata.num_prefills > 0:
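
For context, a hedged illustration of why the view/rearrange pair could be dropped (draft_lens and hidden are hypothetical names, and this snippet is not part of the commit): the old path reshaped the packed spec-decode tokens to (num_spec_decodes, len, dim), which only works when every sequence contributes the same number of tokens, whereas the new path keeps the packed (num_tokens, dim) layout and marks per-sequence boundaries with a (batch + 1,) prefix sum such as spec_query_start_loc.

import torch

hidden = 16
draft_lens = torch.tensor([3, 1, 2])                # accepted + draft tokens per sequence
num_tokens = int(draft_lens.sum())                  # packed varlen total
mixed_qkv_spec = torch.randn(num_tokens, hidden)    # stays packed, no view/rearrange

# (batch + 1,) boundaries in the packed layout, analogous to spec_query_start_loc.
query_start_loc = torch.zeros(len(draft_lens) + 1, dtype=torch.int32)
query_start_loc[1:] = torch.cumsum(draft_lens, dim=0)
print(query_start_loc)                              # tensor([0, 3, 4, 6], dtype=torch.int32)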

vllm/v1/attention/backends/gdn_attn.py

Lines changed: 16 additions & 6 deletions
@@ -74,8 +74,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         self.use_full_cuda_graph = \
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
         self.decode_cudagraph_max_bs = min(
-            self.vllm_config.scheduler_config.max_num_seqs,
-            self.compilation_config.max_capture_size)
+            self.vllm_config.scheduler_config.max_num_seqs *
+            (self.num_spec + 1), self.compilation_config.max_capture_size)

         self.spec_state_indices_tensor = torch.empty(
             (self.decode_cudagraph_max_bs, self.num_spec + 1),
@@ -194,9 +194,8 @@ def build(  # type: ignore[override]
                          dim=0,
                          out=non_spec_query_start_loc[1:])

-            num_spec_decode_tokens = min(
-                num_spec_decodes * (self.num_spec + 1),
-                spec_token_masks.size(0))
+            num_spec_decode_tokens = (query_lens.sum().item() -
+                                      num_prefill_tokens - num_decode_tokens)
             assert num_accepted_tokens is not None
             num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]

@@ -208,9 +207,20 @@ def build(  # type: ignore[override]
             has_initial_state = None

         # prepare tensors for cudagraph
+        #
+        # With speculative decoding, the xgrammar backend may rollback tokens
+        # and causing some sequences has less draft tokens than self.num_spec.
+        #
+        # During cudagraph capture, the GDN backends requires an assumption
+        # that num_spec_decode_tokens == num_spec_decodes * (self.num_spec + 1).
+        #
+        # More than one such sequences may break the assumption (less tokens),
+        # causing incompatible inputs for cuda graph replay.
         if (self.use_full_cuda_graph and num_prefills == 0 and num_decodes == 0
                 and num_spec_decodes <= self.decode_cudagraph_max_bs
-                and m.num_actual_tokens <= self.decode_cudagraph_max_bs):
+                and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
+                and num_spec_decode_tokens == num_spec_decodes *
+                (self.num_spec + 1)):
             num_total_tokens = self.vllm_config.pad_for_cudagraph(
                 m.num_actual_tokens)
             batch_size = num_total_tokens // (self.num_spec + 1)
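
To make the new replay condition concrete, here is a hedged sketch (a standalone function with illustrative numbers, not code from the commit) of the eligibility check that build() now applies before taking the full-cudagraph path.

def can_use_full_cudagraph(num_prefills: int, num_decodes: int,
                           num_spec_decodes: int, num_spec_decode_tokens: int,
                           num_spec: int, decode_cudagraph_max_bs: int) -> bool:
    # Every spec-decode sequence must carry exactly (num_spec + 1) tokens;
    # an xgrammar rollback can leave a sequence with fewer draft tokens,
    # which would make the captured graph's shapes incompatible at replay.
    return (num_prefills == 0 and num_decodes == 0
            and num_spec_decodes <= decode_cudagraph_max_bs
            and num_spec_decode_tokens <= decode_cudagraph_max_bs
            and num_spec_decode_tokens == num_spec_decodes * (num_spec + 1))

# 4 spec-decode sequences with num_spec = 2 should contribute 4 * 3 = 12 tokens;
# 11 tokens (one rolled-back draft) falls back to eager execution instead.
assert can_use_full_cudagraph(0, 0, 4, 12, 2, 64)
assert not can_use_full_cudagraph(0, 0, 4, 11, 2, 64)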
