@@ -626,6 +626,7 @@ def _causal_conv1d_update_kernel(
626626 cache_seqlens_ptr , # circular buffer
627627 conv_state_indices_ptr ,
628628 num_accepted_tokens_ptr ,
629+ query_start_loc_ptr , # (batch + 1)
629630 o_ptr , # (batch, dim, seqlen)
630631 # Matrix dimensions
631632 batch : int ,
@@ -652,6 +653,7 @@ def _causal_conv1d_update_kernel(
652653 HAS_BIAS : tl .constexpr ,
653654 KERNEL_WIDTH : tl .constexpr ,
654655 SILU_ACTIVATION : tl .constexpr ,
656+ IS_VARLEN : tl .constexpr ,
655657 IS_CONTINUOUS_BATCHING : tl .constexpr ,
656658 IS_SPEC_DECODING : tl .constexpr ,
657659 NP2_STATELEN : tl .constexpr ,
@@ -678,6 +680,25 @@ def _causal_conv1d_update_kernel(
678680 # not processing as this is not the actual sequence
679681 return
680682
683+ if IS_VARLEN :
684+ query_start_index = tl .load (query_start_loc_ptr + idx_seq ).to (tl .int64 )
685+ query_end_index = tl .load (query_start_loc_ptr + (idx_seq + 1 )).to (
686+ tl .int64 )
687+ # revise state_len and seqlen
688+ state_len = state_len - (seqlen -
689+ (query_end_index - query_start_index ))
690+ seqlen = query_end_index - query_start_index
691+ x_offset = query_start_index * stride_x_token
692+ o_offset = query_start_index * stride_o_token
693+ else :
694+ query_start_index = idx_seq * seqlen
695+ query_end_index = query_start_index + seqlen
696+ x_offset = idx_seq * stride_x_seq
697+ o_offset = idx_seq * stride_o_seq
698+
699+ if query_start_index == query_end_index :
700+ return
701+
681702 if IS_SPEC_DECODING :
682703 # The rolling of conv state:
683704 #
@@ -692,8 +713,8 @@ def _causal_conv1d_update_kernel(
692713 # - accept 1 tokens: [history2, ..., historyM, draft1]
693714 # - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
694715 # - and so on.
695- conv_state_token_offset = (tl . load ( num_accepted_tokens_ptr + idx_seq ) -
696- 1 )
716+ conv_state_token_offset = (
717+ tl . load ( num_accepted_tokens_ptr + idx_seq ). to ( tl . int64 ) - 1 )
697718 else :
698719 conv_state_token_offset = 0
699720
@@ -713,9 +734,12 @@ def _causal_conv1d_update_kernel(
713734 if KERNEL_WIDTH >= 4 :
714735 conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N]
715736 col2 = tl .load (conv_states_ptrs , mask_w , 0.0 )
716- if KERNEL_WIDTH = = 5 :
737+ if KERNEL_WIDTH > = 5 :
717738 conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N]
718739 col3 = tl .load (conv_states_ptrs , mask_w , 0.0 )
740+ if KERNEL_WIDTH >= 6 :
741+ conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok # [BLOCK_N]
742+ col4 = tl .load (conv_states_ptrs , mask_w , 0.0 )
719743
720744 # STEP 2: assume state_len > seqlen
721745 idx_tokens = tl .arange (0 , NP2_STATELEN ) # [BLOCK_M]
@@ -735,8 +759,7 @@ def _causal_conv1d_update_kernel(
735759 conv_state = tl .load (conv_state_ptrs_source , mask , other = 0.0 )
736760
737761 VAL = state_len - seqlen
738- x_base = x_ptr + (idx_seq * stride_x_seq ) + (idx_feats * stride_x_dim
739- ) # [BLOCK_N]
762+ x_base = x_ptr + x_offset + (idx_feats * stride_x_dim ) # [BLOCK_N]
740763
741764 x_ptrs = x_base [None , :] + (
742765 (idx_tokens - VAL ) * stride_x_token )[:, None ] # [BLOCK_M, BLOCK_N]
@@ -782,12 +805,18 @@ def _causal_conv1d_update_kernel(
782805 if KERNEL_WIDTH >= 4 :
783806 w_ptrs = w_base + (3 * stride_w_width ) # [BLOCK_N] tensor
784807 w_col3 = tl .load (w_ptrs , mask_w , other = 0.0 )
808+ if KERNEL_WIDTH >= 5 :
809+ w_ptrs = w_base + (4 * stride_w_width ) # [BLOCK_N] tensor
810+ w_col4 = tl .load (w_ptrs , mask_w , other = 0.0 )
811+ if KERNEL_WIDTH >= 6 :
812+ w_ptrs = w_base + (5 * stride_w_width ) # [BLOCK_N] tensor
813+ w_col5 = tl .load (w_ptrs , mask_w , other = 0.0 )
785814
786815 x_base_1d = x_base # starting of chunk [BLOCK_N]
787816 mask_x_1d = idx_feats < dim
788817
789818 # STEP 5: compute each token
790- for idx_token in tl .static_range (seqlen ):
819+ for idx_token in tl .range (seqlen ):
791820 acc = acc_preload
792821
793822 matrix_w = w_col0
@@ -817,6 +846,37 @@ def _causal_conv1d_update_kernel(
817846 matrix_w = w_col3
818847 x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
819848 matrix_x = tl .load (x_ptrs_1d , mask = mask_x_1d )
849+ elif KERNEL_WIDTH == 5 :
850+ if j == 1 :
851+ matrix_w = w_col1
852+ matrix_x = col1
853+ elif j == 2 :
854+ matrix_w = w_col2
855+ matrix_x = col2
856+ elif j == 3 :
857+ matrix_w = w_col3
858+ matrix_x = col3
859+ elif j == 4 :
860+ matrix_w = w_col4
861+ x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
862+ matrix_x = tl .load (x_ptrs_1d , mask = mask_x_1d )
863+ elif KERNEL_WIDTH == 6 :
864+ if j == 1 :
865+ matrix_w = w_col1
866+ matrix_x = col1
867+ elif j == 2 :
868+ matrix_w = w_col2
869+ matrix_x = col2
870+ elif j == 3 :
871+ matrix_w = w_col3
872+ matrix_x = col3
873+ elif j == 4 :
874+ matrix_w = w_col4
875+ matrix_x = col4
876+ elif j == 5 :
877+ matrix_w = w_col5
878+ x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
879+ matrix_x = tl .load (x_ptrs_1d , mask = mask_x_1d )
820880
821881 acc += matrix_x * matrix_w # [BLOCK_N]
822882
@@ -829,14 +889,24 @@ def _causal_conv1d_update_kernel(
829889 col0 = col1
830890 col1 = col2
831891 col2 = matrix_x
892+ elif KERNEL_WIDTH == 5 :
893+ col0 = col1
894+ col1 = col2
895+ col2 = col3
896+ col3 = matrix_x
897+ elif KERNEL_WIDTH == 6 :
898+ col0 = col1
899+ col1 = col2
900+ col2 = col3
901+ col3 = col4
902+ col4 = matrix_x
832903
833904 if SILU_ACTIVATION :
834905 acc = acc / (1 + tl .exp (- acc ))
835906 mask_1d = (idx_token < seqlen ) & (idx_feats < dim
836907 ) # token-index # feature-index
837- o_ptrs = o_ptr + (
838- idx_seq ) * stride_o_seq + idx_token * stride_o_token + (
839- idx_feats * stride_o_dim )
908+ o_ptrs = o_ptr + o_offset + idx_token * stride_o_token + (idx_feats *
909+ stride_o_dim )
840910
841911 tl .store (o_ptrs , acc , mask = mask_1d )
842912
@@ -850,14 +920,18 @@ def causal_conv1d_update(
850920 cache_seqlens : Optional [torch .Tensor ] = None ,
851921 conv_state_indices : Optional [torch .Tensor ] = None ,
852922 num_accepted_tokens : Optional [torch .Tensor ] = None ,
923+ query_start_loc : Optional [torch .Tensor ] = None ,
924+ max_query_len : int = - 1 ,
853925 pad_slot_id : int = PAD_SLOT_ID ,
854926 metadata = None ,
855927 validate_data = False ,
856928):
857929 """
858- x: (batch, dim) or (batch, dim, seqlen)
930+ x: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim)
859931 [shape=2: single token prediction]
860932 [shape=3: single or multiple tokens prediction]
933+ [shape=2 with num_tokens: continuous batching, where num_tokens is the
934+ total tokens of all sequences in that batch]
861935 conv_state: (..., dim, state_len), where state_len >= width - 1
862936 weight: (dim, width)
863937 bias: (dim,)
@@ -870,13 +944,24 @@ def causal_conv1d_update(
870944 If not None, the conv_state is a larger tensor along the batch dim,
871945 and we are selecting the batch coords specified by conv_state_indices.
872946 Useful for a continuous batching scenario.
947+ num_accepted_tokens: (batch,), dtype int32
948+ If not None, it indicates the number of accepted tokens for each
949+ sequence in the batch.
950+ This is used in speculative decoding, where the conv_state is updated
951+ in a sliding window manner.
952+ query_start_loc: (batch + 1,) int32
953+ If not None, the inputs is given in a varlen fashion and this indicates
954+ the starting index of each sequence in the batch.
955+ max_query_len: int
956+ If query_start_loc is not None, this indicates the maximum query
957+ length in the batch.
873958 pad_slot_id: int
874959 if cache_indices is passed, lets the kernel identify padded
875960 entries that will not be processed,
876961 for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
877962 in this case, the kernel will not process entries at
878963 indices 0 and 3
879- out: (batch, dim) or (batch, dim, seqlen)
964+ out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
880965 """
881966 if validate_data :
882967 assert cache_seqlens is None # not implemented yet - ok for vLLM
@@ -886,11 +971,17 @@ def causal_conv1d_update(
886971 activation = "silu" if activation is True else None
887972 elif activation is not None :
888973 assert activation in ["silu" , "swish" ]
889- unsqueeze = x .dim () == 2
974+ unsqueeze = query_start_loc is None and x .dim () == 2
890975 if unsqueeze :
891976 # make it (batch, dim, seqlen) with seqlen == 1
892977 x = x .unsqueeze (- 1 )
893- batch , dim , seqlen = x .shape
978+ if query_start_loc is None :
979+ batch , dim , seqlen = x .shape
980+ else :
981+ assert conv_state_indices is not None
982+ batch = conv_state_indices .size (0 )
983+ dim = x .size (1 )
984+ seqlen = max_query_len
894985 _ , width = weight .shape
895986 # conv_state: (..., dim, state_len), where state_len >= width - 1
896987 num_cache_lines , _ , state_len = conv_state .size ()
@@ -916,10 +1007,17 @@ def causal_conv1d_update(
9161007 out = x
9171008 stride_w_dim , stride_w_width = weight .stride ()
9181009
919- stride_x_seq , stride_x_dim , stride_x_token = x .stride (
920- ) # X (batch, dim, seqlen)
1010+ if query_start_loc is None :
1011+ # X (batch, dim, seqlen)
1012+ stride_x_seq , stride_x_dim , stride_x_token = x .stride ()
1013+ stride_o_seq , stride_o_dim , stride_o_token = out .stride ()
1014+ else :
1015+ # X (dim, cu_seqlen)
1016+ stride_x_token , stride_x_dim = x .stride ()
1017+ stride_x_seq = 0
1018+ stride_o_token , stride_o_dim = out .stride ()
1019+ stride_o_seq = 0
9211020
922- stride_o_seq , stride_o_dim , stride_o_token = out .stride ()
9231021 stride_istate_seq , stride_istate_dim , stride_istate_token = conv_state .stride (
9241022 )
9251023 stride_state_indices = conv_state_indices .stride (
@@ -945,6 +1043,7 @@ def grid(META):
9451043 cache_seqlens ,
9461044 conv_state_indices ,
9471045 num_accepted_tokens ,
1046+ query_start_loc ,
9481047 out ,
9491048 # Matrix dimensions
9501049 batch ,
@@ -971,6 +1070,7 @@ def grid(META):
9711070 HAS_BIAS = bias is not None ,
9721071 KERNEL_WIDTH = width ,
9731072 SILU_ACTIVATION = activation in ["silu" , "swish" ],
1073+ IS_VARLEN = query_start_loc is not None ,
9741074 IS_CONTINUOUS_BATCHING = conv_state_indices is not None ,
9751075 IS_SPEC_DECODING = num_accepted_tokens is not None ,
9761076 NP2_STATELEN = np2_statelen ,
0 commit comments