@@ -626,6 +626,7 @@ def _causal_conv1d_update_kernel(
626626    cache_seqlens_ptr ,  # circular buffer 
627627    conv_state_indices_ptr ,
628628    num_accepted_tokens_ptr ,
629+     query_start_loc_ptr ,  # (batch + 1) 
629630    o_ptr ,  # (batch, dim, seqlen) 
630631    # Matrix dimensions 
631632    batch : int ,
@@ -652,6 +653,7 @@ def _causal_conv1d_update_kernel(
652653    HAS_BIAS : tl .constexpr ,
653654    KERNEL_WIDTH : tl .constexpr ,
654655    SILU_ACTIVATION : tl .constexpr ,
656+     IS_VARLEN : tl .constexpr ,
655657    IS_CONTINUOUS_BATCHING : tl .constexpr ,
656658    IS_SPEC_DECODING : tl .constexpr ,
657659    NP2_STATELEN : tl .constexpr ,
@@ -678,6 +680,25 @@ def _causal_conv1d_update_kernel(
678680            # not processing as this is not the actual sequence 
679681            return 
680682
683+     if  IS_VARLEN :
684+         query_start_index  =  tl .load (query_start_loc_ptr  +  idx_seq ).to (tl .int64 )
685+         query_end_index  =  tl .load (query_start_loc_ptr  +  (idx_seq  +  1 )).to (
686+             tl .int64 )
687+         # revise state_len and seqlen 
688+         state_len  =  state_len  -  (seqlen  - 
689+                                  (query_end_index  -  query_start_index ))
690+         seqlen  =  query_end_index  -  query_start_index 
691+         x_offset  =  query_start_index  *  stride_x_token 
692+         o_offset  =  query_start_index  *  stride_o_token 
693+     else :
694+         query_start_index  =  idx_seq  *  seqlen 
695+         query_end_index  =  query_start_index  +  seqlen 
696+         x_offset  =  idx_seq  *  stride_x_seq 
697+         o_offset  =  idx_seq  *  stride_o_seq 
698+ 
699+     if  query_start_index  ==  query_end_index :
700+         return 
701+ 
681702    if  IS_SPEC_DECODING :
682703        # The rolling of conv state: 
683704        # 
@@ -692,8 +713,8 @@ def _causal_conv1d_update_kernel(
692713        # - accept 1 tokens: [history2, ..., historyM, draft1] 
693714        # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] 
694715        # - and so on. 
695-         conv_state_token_offset  =  (tl . load ( num_accepted_tokens_ptr   +   idx_seq )  - 
696-                                     1 )
716+         conv_state_token_offset  =  (
717+             tl . load ( num_accepted_tokens_ptr   +   idx_seq ). to ( tl . int64 )  -  1 )
697718    else :
698719        conv_state_token_offset  =  0 
699720
@@ -713,9 +734,12 @@ def _causal_conv1d_update_kernel(
713734    if  KERNEL_WIDTH  >=  4 :
714735        conv_states_ptrs  =  prior_tokens  +  2  *  stride_conv_state_tok   # [BLOCK_N] 
715736        col2  =  tl .load (conv_states_ptrs , mask_w , 0.0 )
716-     if  KERNEL_WIDTH  ==  5 :
737     if  KERNEL_WIDTH  >=  5 :
717738        conv_states_ptrs  =  prior_tokens  +  3  *  stride_conv_state_tok   # [BLOCK_N] 
718739        col3  =  tl .load (conv_states_ptrs , mask_w , 0.0 )
740+     if  KERNEL_WIDTH  >=  6 :
741+         conv_states_ptrs  =  prior_tokens  +  4  *  stride_conv_state_tok   # [BLOCK_N] 
742+         col4  =  tl .load (conv_states_ptrs , mask_w , 0.0 )
719743
720744    # STEP 2: assume state_len > seqlen 
721745    idx_tokens  =  tl .arange (0 , NP2_STATELEN )  # [BLOCK_M] 
@@ -735,8 +759,7 @@ def _causal_conv1d_update_kernel(
735759    conv_state  =  tl .load (conv_state_ptrs_source , mask , other = 0.0 )
736760
737761    VAL  =  state_len  -  seqlen 
738-     x_base  =  x_ptr  +  (idx_seq  *  stride_x_seq ) +  (idx_feats  *  stride_x_dim 
739-                                                  )  # [BLOCK_N] 
762+     x_base  =  x_ptr  +  x_offset  +  (idx_feats  *  stride_x_dim )  # [BLOCK_N] 
740763
741764    x_ptrs  =  x_base [None , :] +  (
742765        (idx_tokens  -  VAL ) *  stride_x_token )[:, None ]  # [BLOCK_M, BLOCK_N] 
@@ -782,12 +805,18 @@ def _causal_conv1d_update_kernel(
782805    if  KERNEL_WIDTH  >=  4 :
783806        w_ptrs  =  w_base  +  (3  *  stride_w_width )  # [BLOCK_N] tensor 
784807        w_col3  =  tl .load (w_ptrs , mask_w , other = 0.0 )
808+     if  KERNEL_WIDTH  >=  5 :
809+         w_ptrs  =  w_base  +  (4  *  stride_w_width )  # [BLOCK_N] tensor 
810+         w_col4  =  tl .load (w_ptrs , mask_w , other = 0.0 )
811+     if  KERNEL_WIDTH  >=  6 :
812+         w_ptrs  =  w_base  +  (5  *  stride_w_width )  # [BLOCK_N] tensor 
813+         w_col5  =  tl .load (w_ptrs , mask_w , other = 0.0 )
785814
786815    x_base_1d  =  x_base   # starting of chunk [BLOCK_N] 
787816    mask_x_1d  =  idx_feats  <  dim 
788817
789818    # STEP 5: compute each token 
790-     for  idx_token  in  tl .static_range (seqlen ):
819+     for  idx_token  in  tl .range (seqlen ):
791820        acc  =  acc_preload 
792821
793822        matrix_w  =  w_col0 
@@ -817,6 +846,37 @@ def _causal_conv1d_update_kernel(
817846                    matrix_w  =  w_col3 
818847                    x_ptrs_1d  =  x_base_1d  +  idx_token  *  stride_x_token   # [BLOCK_N] 
819848                    matrix_x  =  tl .load (x_ptrs_1d , mask = mask_x_1d )
849+             elif  KERNEL_WIDTH  ==  5 :
850+                 if  j  ==  1 :
851+                     matrix_w  =  w_col1 
852+                     matrix_x  =  col1 
853+                 elif  j  ==  2 :
854+                     matrix_w  =  w_col2 
855+                     matrix_x  =  col2 
856+                 elif  j  ==  3 :
857+                     matrix_w  =  w_col3 
858+                     matrix_x  =  col3 
859+                 elif  j  ==  4 :
860+                     matrix_w  =  w_col4 
861+                     x_ptrs_1d  =  x_base_1d  +  idx_token  *  stride_x_token   # [BLOCK_N] 
862+                     matrix_x  =  tl .load (x_ptrs_1d , mask = mask_x_1d )
863+             elif  KERNEL_WIDTH  ==  6 :
864+                 if  j  ==  1 :
865+                     matrix_w  =  w_col1 
866+                     matrix_x  =  col1 
867+                 elif  j  ==  2 :
868+                     matrix_w  =  w_col2 
869+                     matrix_x  =  col2 
870+                 elif  j  ==  3 :
871+                     matrix_w  =  w_col3 
872+                     matrix_x  =  col3 
873+                 elif  j  ==  4 :
874+                     matrix_w  =  w_col4 
875+                     matrix_x  =  col4 
876+                 elif  j  ==  5 :
877+                     matrix_w  =  w_col5 
878+                     x_ptrs_1d  =  x_base_1d  +  idx_token  *  stride_x_token   # [BLOCK_N] 
879+                     matrix_x  =  tl .load (x_ptrs_1d , mask = mask_x_1d )
820880
821881            acc  +=  matrix_x  *  matrix_w   # [BLOCK_N] 
822882
@@ -829,14 +889,24 @@ def _causal_conv1d_update_kernel(
829889            col0  =  col1 
830890            col1  =  col2 
831891            col2  =  matrix_x 
892+         elif  KERNEL_WIDTH  ==  5 :
893+             col0  =  col1 
894+             col1  =  col2 
895+             col2  =  col3 
896+             col3  =  matrix_x 
897+         elif  KERNEL_WIDTH  ==  6 :
898+             col0  =  col1 
899+             col1  =  col2 
900+             col2  =  col3 
901+             col3  =  col4 
902+             col4  =  matrix_x 
832903
833904        if  SILU_ACTIVATION :
834905            acc  =  acc  /  (1  +  tl .exp (- acc ))
835906        mask_1d  =  (idx_token  <  seqlen ) &  (idx_feats  <  dim 
836907                                          )  # token-index  # feature-index 
837-         o_ptrs  =  o_ptr  +  (
838-             idx_seq ) *  stride_o_seq  +  idx_token  *  stride_o_token  +  (
839-                 idx_feats  *  stride_o_dim )
908+         o_ptrs  =  o_ptr  +  o_offset  +  idx_token  *  stride_o_token  +  (idx_feats  * 
909+                                                                   stride_o_dim )
840910
841911        tl .store (o_ptrs , acc , mask = mask_1d )
842912
@@ -850,14 +920,18 @@ def causal_conv1d_update(
850920    cache_seqlens : Optional [torch .Tensor ] =  None ,
851921    conv_state_indices : Optional [torch .Tensor ] =  None ,
852922    num_accepted_tokens : Optional [torch .Tensor ] =  None ,
923+     query_start_loc : Optional [torch .Tensor ] =  None ,
924+     max_query_len : int  =  - 1 ,
853925    pad_slot_id : int  =  PAD_SLOT_ID ,
854926    metadata = None ,
855927    validate_data = False ,
856928):
857929    """ 
858-     x: (batch, dim) or (batch, dim, seqlen) 
930+     x: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim)  
859931        [shape=2: single token prediction] 
860932        [shape=3: single or multiple tokens prediction] 
933+         [shape=2 with num_tokens: continuous batching, where num_tokens is the 
934+                                   total tokens of all sequences in that batch] 
861935    conv_state: (..., dim, state_len), where state_len >= width - 1 
862936    weight: (dim, width) 
863937    bias: (dim,) 
@@ -870,13 +944,24 @@ def causal_conv1d_update(
870944        If not None, the conv_state is a larger tensor along the batch dim, 
871945        and we are selecting the batch coords specified by conv_state_indices. 
872946        Useful for a continuous batching scenario. 
947+     num_accepted_tokens: (batch,), dtype int32 
948+         If not None, it indicates the number of accepted tokens for each 
949+         sequence in the batch. 
950+         This is used in speculative decoding, where the conv_state is updated 
951+         in a sliding window manner. 
952+     query_start_loc: (batch + 1,) int32 
953+         If not None, the input is given in a varlen fashion and this indicates 
954+         the starting index of each sequence in the batch. 
955+     max_query_len: int 
956+         If query_start_loc is not None, this indicates the maximum query 
957+         length in the batch. 
873958    pad_slot_id: int 
874959            if cache_indices is passed, lets the kernel identify padded 
875960            entries that will not be processed, 
876961            for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] 
877962            in this case, the kernel will not process entries at 
878963            indices 0 and 3 
879-     out: (batch, dim) or (batch, dim, seqlen) 
964+     out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x`  
880965    """ 
881966    if  validate_data :
882967        assert  cache_seqlens  is  None   # not implemented yet - ok for vLLM 
@@ -886,11 +971,17 @@ def causal_conv1d_update(
886971        activation  =  "silu"  if  activation  is  True  else  None 
887972    elif  activation  is  not None :
888973        assert  activation  in  ["silu" , "swish" ]
889-     unsqueeze  =  x .dim () ==  2 
974+     unsqueeze  =  query_start_loc   is   None   and   x .dim () ==  2 
890975    if  unsqueeze :
891976        # make it (batch, dim, seqlen) with seqlen == 1 
892977        x  =  x .unsqueeze (- 1 )
893-     batch , dim , seqlen  =  x .shape 
978+     if  query_start_loc  is  None :
979+         batch , dim , seqlen  =  x .shape 
980+     else :
981+         assert  conv_state_indices  is  not None 
982+         batch  =  conv_state_indices .size (0 )
983+         dim  =  x .size (1 )
984+         seqlen  =  max_query_len 
894985    _ , width  =  weight .shape 
895986    # conv_state: (..., dim, state_len), where state_len >= width - 1 
896987    num_cache_lines , _ , state_len  =  conv_state .size ()
@@ -916,10 +1007,17 @@ def causal_conv1d_update(
9161007    out  =  x 
9171008    stride_w_dim , stride_w_width  =  weight .stride ()
9181009
919-     stride_x_seq , stride_x_dim , stride_x_token  =  x .stride (
920-     )  # X (batch, dim, seqlen) 
1010+     if  query_start_loc  is  None :
1011+         # X (batch, dim, seqlen) 
1012+         stride_x_seq , stride_x_dim , stride_x_token  =  x .stride ()
1013+         stride_o_seq , stride_o_dim , stride_o_token  =  out .stride ()
1014+     else :
1015+         # X (dim, cu_seqlen) 
1016+         stride_x_token , stride_x_dim  =  x .stride ()
1017+         stride_x_seq  =  0 
1018+         stride_o_token , stride_o_dim  =  out .stride ()
1019+         stride_o_seq  =  0 
9211020
922-     stride_o_seq , stride_o_dim , stride_o_token  =  out .stride ()
9231021    stride_istate_seq , stride_istate_dim , stride_istate_token  =  conv_state .stride (
9241022    )
9251023    stride_state_indices  =  conv_state_indices .stride (
@@ -945,6 +1043,7 @@ def grid(META):
9451043        cache_seqlens ,
9461044        conv_state_indices ,
9471045        num_accepted_tokens ,
1046+         query_start_loc ,
9481047        out ,
9491048        # Matrix dimensions 
9501049        batch ,
@@ -971,6 +1070,7 @@ def grid(META):
9711070        HAS_BIAS = bias  is  not None ,
9721071        KERNEL_WIDTH = width ,
9731072        SILU_ACTIVATION = activation  in  ["silu" , "swish" ],
1073+         IS_VARLEN = query_start_loc  is  not None ,
9741074        IS_CONTINUOUS_BATCHING = conv_state_indices  is  not None ,
9751075        IS_SPEC_DECODING = num_accepted_tokens  is  not None ,
9761076        NP2_STATELEN = np2_statelen ,