@@ -68,14 +68,19 @@ class Mamba2AttentionMetadata:
     query_start_loc: torch.Tensor
     seq_lens: torch.Tensor
 
-    has_initial_states: torch.Tensor
     prep_initial_states: bool
     chunk_size: int
-    seq_idx: torch.Tensor
-    chunk_indices: torch.Tensor
-    chunk_offsets: torch.Tensor
+
+    # The following tensors only contain prefill requests and will be None if
+    # the batch has no prefill request.
+    has_initial_states_p: Optional[torch.Tensor]
+    seq_idx_p: Optional[torch.Tensor]
+    chunk_indices_p: Optional[torch.Tensor]
+    chunk_offsets_p: Optional[torch.Tensor]
 
     state_indices_tensor: torch.Tensor  # shape: [batch,]
+
+    # The following attributes are for triton implementation of causal_conv1d
     nums_dict: Optional[dict] = None
     cu_seqlen: Optional[int] = None
     batch_ptr: Optional[torch.tensor] = None
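
The first hunk introduces the `_p` naming convention: tensors that describe only the prefill portion of the batch are typed `Optional` and left as `None` for decode-only batches. A minimal standalone sketch of that convention (toy dataclass for illustration only, not the real `Mamba2AttentionMetadata`):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ToyMetadata:
    # Always populated: covers the whole batch (decode + prefill requests).
    query_start_loc: torch.Tensor
    # Prefill-only (hence the "_p" suffix): None when the batch has no
    # prefill request, so consumers must guard before using it.
    seq_idx_p: Optional[torch.Tensor] = None


decode_only_batch = ToyMetadata(query_start_loc=torch.tensor([0, 1, 2]))
assert decode_only_batch.seq_idx_p is None
```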
@@ -115,11 +120,11 @@ def build(self,
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
 
-        seq_idx = None
-        chunk_indices, chunk_offsets = None, None
+        seq_idx_p = None
+        chunk_indices_p, chunk_offsets_p = None, None
         # Need flags to indicate if there are initial states
         # currently we really only support the FlashAttention backend
-        has_initial_states = None
+        has_initial_states_p = None
         prep_initial_states = False
 
         state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
@@ -135,25 +140,25 @@ def build(self,
                 common_attn_metadata.
                 num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
             prep_initial_states = torch.any(has_initial_states_cpu).item()
-            has_initial_states = has_initial_states_cpu.to(
+            has_initial_states_p = has_initial_states_cpu.to(
                 query_start_loc.device)
 
             query_start_loc_p = common_attn_metadata.query_start_loc[
                 -num_prefills - 1:] - num_decode_tokens
 
-            seq_idx = torch.repeat_interleave(torch.arange(
+            seq_idx_p = torch.repeat_interleave(torch.arange(
                 num_prefills,
                 dtype=torch.int32,
                 device=query_start_loc_p.device),
-                                              query_start_loc_p.diff(),
-                                              output_size=num_prefill_tokens)
-            seq_idx.unsqueeze_(0)
+                                                query_start_loc_p.diff(),
+                                                output_size=num_prefill_tokens)
+            seq_idx_p.unsqueeze_(0)
 
             # We compute metadata for chunked prefill once at the top level
             # model forward and reuse them in mamba layers. If not needed,
             # they will be ignored inside mamba kernels.
             if prep_initial_states:
-                chunk_indices, chunk_offsets = (
+                chunk_indices_p, chunk_offsets_p = (
                     _query_start_loc_to_chunk_indices_offsets(
                         query_start_loc_p, self.chunk_size,
                         num_prefill_tokens))
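
For reference, a self-contained sketch of how the prefill-only `seq_idx_p` comes out of `torch.repeat_interleave`, using a made-up batch of 2 decode requests (1 token each) followed by 2 prefill requests with 3 and 4 tokens; the slice `[-num_prefills - 1:] - num_decode_tokens` assumes decode tokens precede prefill tokens in `query_start_loc`, as in the hunk above:

```python
import torch

# Illustrative numbers only: 2 decode requests (1 token each) followed by
# 2 prefill requests with 3 and 4 tokens.
num_prefills = 2
num_decode_tokens = 2
num_prefill_tokens = 7
query_start_loc = torch.tensor([0, 1, 2, 5, 9], dtype=torch.int32)

# Same slicing as in build(): keep the prefill boundaries, rebased to 0.
query_start_loc_p = query_start_loc[-num_prefills - 1:] - num_decode_tokens
assert query_start_loc_p.tolist() == [0, 3, 7]

# One sequence index per prefill token, shaped (1, num_prefill_tokens).
seq_idx_p = torch.repeat_interleave(
    torch.arange(num_prefills, dtype=torch.int32),
    query_start_loc_p.diff(),
    output_size=num_prefill_tokens)
seq_idx_p.unsqueeze_(0)
assert seq_idx_p.tolist() == [[0, 0, 0, 1, 1, 1, 1]]
```

`chunk_indices_p` and `chunk_offsets_p` are additionally derived from the same `query_start_loc_p` via `_query_start_loc_to_chunk_indices_offsets`, but only when some prefill request carries an initial state (`prep_initial_states`).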
@@ -173,12 +178,12 @@ def build(self,
             num_decode_tokens=num_decode_tokens,
             query_start_loc=query_start_loc,
             seq_lens=seq_lens,
-            has_initial_states=has_initial_states,
             prep_initial_states=prep_initial_states,
             chunk_size=self.chunk_size,
-            seq_idx=seq_idx,
-            chunk_indices=chunk_indices,
-            chunk_offsets=chunk_offsets,
+            has_initial_states_p=has_initial_states_p,
+            seq_idx_p=seq_idx_p,
+            chunk_indices_p=chunk_indices_p,
+            chunk_offsets_p=chunk_offsets_p,
             state_indices_tensor=state_indices_tensor,
         )
         return attn_metadata
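
Since the `_p` fields are `Optional`, downstream code should check for `None` before dereferencing them. A hedged sketch of such a guard, where `prefill_chunk_metadata` is a hypothetical helper and not part of this change:

```python
from typing import Optional

import torch


def prefill_chunk_metadata(
        seq_idx_p: Optional[torch.Tensor],
        chunk_indices_p: Optional[torch.Tensor],
        chunk_offsets_p: Optional[torch.Tensor]):
    """Hypothetical consumer-side guard for the prefill-only metadata."""
    if seq_idx_p is None:
        # Decode-only batch: there is nothing for the chunked-prefill kernels.
        return None
    # chunk_indices_p / chunk_offsets_p can still be None when no prefill
    # request carries an initial state (prep_initial_states is False);
    # the mamba kernels are expected to ignore them in that case.
    return seq_idx_p, chunk_indices_p, chunk_offsets_p
```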