@@ -626,6 +626,7 @@ def _causal_conv1d_update_kernel(
626626 cache_seqlens_ptr , # circular buffer
627627 conv_state_indices_ptr ,
628628 num_accepted_tokens_ptr ,
629+ query_start_loc_ptr , # (batch + 1)
629630 o_ptr , # (batch, dim, seqlen)
630631 # Matrix dimensions
631632 batch : int ,
@@ -652,6 +653,7 @@ def _causal_conv1d_update_kernel(
652653 HAS_BIAS : tl .constexpr ,
653654 KERNEL_WIDTH : tl .constexpr ,
654655 SILU_ACTIVATION : tl .constexpr ,
656+ IS_VARLEN : tl .constexpr ,
655657 IS_CONTINUOUS_BATCHING : tl .constexpr ,
656658 IS_SPEC_DECODING : tl .constexpr ,
657659 NP2_STATELEN : tl .constexpr ,
@@ -692,8 +694,8 @@ def _causal_conv1d_update_kernel(
692694 # - accept 1 tokens: [history2, ..., historyM, draft1]
693695 # - accept 2 tokens: [history3, ..., historyM, draft1, draft2]
694696 # - and so on.
695- conv_state_token_offset = (tl . load ( num_accepted_tokens_ptr + idx_seq ) -
696- 1 )
697+ conv_state_token_offset = (
698+ tl . load ( num_accepted_tokens_ptr + idx_seq ). to ( tl . int64 ) - 1 )
697699 else :
698700 conv_state_token_offset = 0
699701
@@ -713,9 +715,28 @@ def _causal_conv1d_update_kernel(
713715 if KERNEL_WIDTH >= 4 :
714716 conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N]
715717 col2 = tl .load (conv_states_ptrs , mask_w , 0.0 )
716- if KERNEL_WIDTH = = 5 :
718+ if KERNEL_WIDTH > = 5 :
717719 conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N]
718720 col3 = tl .load (conv_states_ptrs , mask_w , 0.0 )
721+ if KERNEL_WIDTH >= 6 :
722+ conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok # [BLOCK_N]
723+ col4 = tl .load (conv_states_ptrs , mask_w , 0.0 )
724+
725+ if IS_VARLEN :
726+ query_start_index = tl .load (query_start_loc_ptr + idx_seq ).to (tl .int64 )
727+ query_end_index = tl .load (query_start_loc_ptr + (idx_seq + 1 )).to (
728+ tl .int64 )
729+ # revise state_len and seqlen
730+ state_len = state_len - (seqlen -
731+ (query_end_index - query_start_index ))
732+ seqlen = query_end_index - query_start_index
733+ x_offset = query_start_index * stride_x_token
734+ o_offset = query_start_index * stride_o_token
735+ else :
736+ query_start_index = idx_seq * seqlen
737+ query_end_index = query_start_index + seqlen
738+ x_offset = idx_seq * stride_x_seq
739+ o_offset = idx_seq * stride_o_seq
719740
720741 # STEP 2: assume state_len > seqlen
721742 idx_tokens = tl .arange (0 , NP2_STATELEN ) # [BLOCK_M]
@@ -735,8 +756,7 @@ def _causal_conv1d_update_kernel(
735756 conv_state = tl .load (conv_state_ptrs_source , mask , other = 0.0 )
736757
737758 VAL = state_len - seqlen
738- x_base = x_ptr + (idx_seq * stride_x_seq ) + (idx_feats * stride_x_dim
739- ) # [BLOCK_N]
759+ x_base = x_ptr + x_offset + (idx_feats * stride_x_dim ) # [BLOCK_N]
740760
741761 x_ptrs = x_base [None , :] + (
742762 (idx_tokens - VAL ) * stride_x_token )[:, None ] # [BLOCK_M, BLOCK_N]
@@ -782,12 +802,18 @@ def _causal_conv1d_update_kernel(
782802 if KERNEL_WIDTH >= 4 :
783803 w_ptrs = w_base + (3 * stride_w_width ) # [BLOCK_N] tensor
784804 w_col3 = tl .load (w_ptrs , mask_w , other = 0.0 )
805+ if KERNEL_WIDTH >= 5 :
806+ w_ptrs = w_base + (4 * stride_w_width ) # [BLOCK_N] tensor
807+ w_col4 = tl .load (w_ptrs , mask_w , other = 0.0 )
808+ if KERNEL_WIDTH >= 6 :
809+ w_ptrs = w_base + (5 * stride_w_width ) # [BLOCK_N] tensor
810+ w_col5 = tl .load (w_ptrs , mask_w , other = 0.0 )
785811
786812 x_base_1d = x_base # starting of chunk [BLOCK_N]
787813 mask_x_1d = idx_feats < dim
788814
789815 # STEP 5: compute each token
790- for idx_token in tl .static_range (seqlen ):
816+ for idx_token in tl .range (seqlen ):
791817 acc = acc_preload
792818
793819 matrix_w = w_col0
@@ -817,6 +843,37 @@ def _causal_conv1d_update_kernel(
817843 matrix_w = w_col3
818844 x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
819845 matrix_x = tl .load (x_ptrs_1d , mask = mask_x_1d )
846+ elif KERNEL_WIDTH == 5 :
847+ if j == 1 :
848+ matrix_w = w_col1
849+ matrix_x = col1
850+ elif j == 2 :
851+ matrix_w = w_col2
852+ matrix_x = col2
853+ elif j == 3 :
854+ matrix_w = w_col3
855+ matrix_x = col3
856+ elif j == 4 :
857+ matrix_w = w_col4
858+ x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
859+ matrix_x = tl .load (x_ptrs_1d , mask = mask_x_1d )
860+ elif KERNEL_WIDTH == 6 :
861+ if j == 1 :
862+ matrix_w = w_col1
863+ matrix_x = col1
864+ elif j == 2 :
865+ matrix_w = w_col2
866+ matrix_x = col2
867+ elif j == 3 :
868+ matrix_w = w_col3
869+ matrix_x = col3
870+ elif j == 4 :
871+ matrix_w = w_col4
872+ matrix_x = col4
873+ elif j == 5 :
874+ matrix_w = w_col5
875+ x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N]
876+ matrix_x = tl .load (x_ptrs_1d , mask = mask_x_1d )
820877
821878 acc += matrix_x * matrix_w # [BLOCK_N]
822879
@@ -829,14 +886,24 @@ def _causal_conv1d_update_kernel(
829886 col0 = col1
830887 col1 = col2
831888 col2 = matrix_x
889+ elif KERNEL_WIDTH == 5 :
890+ col0 = col1
891+ col1 = col2
892+ col2 = col3
893+ col3 = matrix_x
894+ elif KERNEL_WIDTH == 6 :
895+ col0 = col1
896+ col1 = col2
897+ col2 = col3
898+ col3 = col4
899+ col4 = matrix_x
832900
833901 if SILU_ACTIVATION :
834902 acc = acc / (1 + tl .exp (- acc ))
835903 mask_1d = (idx_token < seqlen ) & (idx_feats < dim
836904 ) # token-index # feature-index
837- o_ptrs = o_ptr + (
838- idx_seq ) * stride_o_seq + idx_token * stride_o_token + (
839- idx_feats * stride_o_dim )
905+ o_ptrs = o_ptr + o_offset + idx_token * stride_o_token + (idx_feats *
906+ stride_o_dim )
840907
841908 tl .store (o_ptrs , acc , mask = mask_1d )
842909
@@ -850,14 +917,18 @@ def causal_conv1d_update(
850917 cache_seqlens : Optional [torch .Tensor ] = None ,
851918 conv_state_indices : Optional [torch .Tensor ] = None ,
852919 num_accepted_tokens : Optional [torch .Tensor ] = None ,
920+ query_start_loc : Optional [torch .Tensor ] = None ,
921+ max_query_len : int = - 1 ,
853922 pad_slot_id : int = PAD_SLOT_ID ,
854923 metadata = None ,
855924 validate_data = False ,
856925):
857926 """
858- x: (batch, dim) or (batch, dim, seqlen)
927+ x: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim)
859928 [shape=2: single token prediction]
860929 [shape=3: single or multiple tokens prediction]
930+ [shape=2 with num_tokens: continuous batching, where num_tokens is the
931+ total number of tokens across all sequences in that batch]
861932 conv_state: (..., dim, state_len), where state_len >= width - 1
862933 weight: (dim, width)
863934 bias: (dim,)
@@ -870,13 +941,24 @@ def causal_conv1d_update(
870941 If not None, the conv_state is a larger tensor along the batch dim,
871942 and we are selecting the batch coords specified by conv_state_indices.
872943 Useful for a continuous batching scenario.
944+ num_accepted_tokens: (batch,), dtype int32
945+ If not None, it indicates the number of accepted tokens for each
946+ sequence in the batch.
947+ This is used in speculative decoding, where the conv_state is updated
948+ in a sliding window manner.
949+ query_start_loc: (batch + 1,) int32
950+ If not None, the input is given in a varlen fashion and this indicates
951+ the starting index of each sequence in the batch.
952+ max_query_len: int
953+ If query_start_loc is not None, this indicates the maximum query
954+ length in the batch.
873955 pad_slot_id: int
874956 if cache_indices is passed, lets the kernel identify padded
875957 entries that will not be processed,
876958 for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
877959 in this case, the kernel will not process entries at
878960 indices 0 and 3
879- out: (batch, dim) or (batch, dim, seqlen)
961+ out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`
880962 """
881963 if validate_data :
882964 assert cache_seqlens is None # not implemented yet - ok for vLLM
@@ -886,11 +968,17 @@ def causal_conv1d_update(
886968 activation = "silu" if activation is True else None
887969 elif activation is not None :
888970 assert activation in ["silu" , "swish" ]
889- unsqueeze = x .dim () == 2
971+ unsqueeze = query_start_loc is None and x .dim () == 2
890972 if unsqueeze :
891973 # make it (batch, dim, seqlen) with seqlen == 1
892974 x = x .unsqueeze (- 1 )
893- batch , dim , seqlen = x .shape
975+ if query_start_loc is None :
976+ batch , dim , seqlen = x .shape
977+ else :
978+ assert conv_state_indices is not None
979+ batch = conv_state_indices .size (0 )
980+ dim = x .size (1 )
981+ seqlen = max_query_len
894982 _ , width = weight .shape
895983 # conv_state: (..., dim, state_len), where state_len >= width - 1
896984 num_cache_lines , _ , state_len = conv_state .size ()
@@ -916,10 +1004,17 @@ def causal_conv1d_update(
9161004 out = x
9171005 stride_w_dim , stride_w_width = weight .stride ()
9181006
919- stride_x_seq , stride_x_dim , stride_x_token = x .stride (
920- ) # X (batch, dim, seqlen)
1007+ if query_start_loc is None :
1008+ # X (batch, dim, seqlen)
1009+ stride_x_seq , stride_x_dim , stride_x_token = x .stride ()
1010+ stride_o_seq , stride_o_dim , stride_o_token = out .stride ()
1011+ else :
1012+ # X (cu_seqlen, dim)
1013+ stride_x_token , stride_x_dim = x .stride ()
1014+ stride_x_seq = 0
1015+ stride_o_token , stride_o_dim = out .stride ()
1016+ stride_o_seq = 0
9211017
922- stride_o_seq , stride_o_dim , stride_o_token = out .stride ()
9231018 stride_istate_seq , stride_istate_dim , stride_istate_token = conv_state .stride (
9241019 )
9251020 stride_state_indices = conv_state_indices .stride (
@@ -945,6 +1040,7 @@ def grid(META):
9451040 cache_seqlens ,
9461041 conv_state_indices ,
9471042 num_accepted_tokens ,
1043+ query_start_loc ,
9481044 out ,
9491045 # Matrix dimensions
9501046 batch ,
@@ -971,6 +1067,7 @@ def grid(META):
9711067 HAS_BIAS = bias is not None ,
9721068 KERNEL_WIDTH = width ,
9731069 SILU_ACTIVATION = activation in ["silu" , "swish" ],
1070+ IS_VARLEN = query_start_loc is not None ,
9741071 IS_CONTINUOUS_BATCHING = conv_state_indices is not None ,
9751072 IS_SPEC_DECODING = num_accepted_tokens is not None ,
9761073 NP2_STATELEN = np2_statelen ,
0 commit comments