@@ -132,6 +132,7 @@ def __init__(
132132 else :
133133 self .sliding_window = (sliding_window - 1 , 0 )
134134 self .kv_cache_dtype = kv_cache_dtype
135+ self .use_irope = use_irope
135136 if logits_soft_cap is None :
136137 # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
137138 logits_soft_cap = 0
@@ -204,19 +205,45 @@ def forward(
204205 layer ._k_scale_float ,
205206 layer ._v_scale_float ,
206207 )
208+ use_local_attn = \
209+ (self .use_irope and attn_metadata .local_attn_metadata is not None )
210+
211+ if use_local_attn :
212+ assert attn_metadata .local_attn_metadata is not None
213+ local_metadata = attn_metadata .local_attn_metadata
214+ cu_seqlens_q = local_metadata .local_query_start_loc
215+ sequesd_k = local_metadata .local_seqused_k
216+ max_seqlen_q = local_metadata .local_max_query_len
217+ max_seqlen_k = local_metadata .local_max_seq_len
218+ block_table = local_metadata .local_block_table
219+ else :
220+ cu_seqlens_q = attn_metadata .query_start_loc
221+ sequesd_k = attn_metadata .seq_lens
222+ max_seqlen_q = attn_metadata .max_query_len
223+ max_seqlen_k = attn_metadata .max_seq_len
224+ block_table = attn_metadata .block_table
225+
226+ if not hasattr (attn_metadata , "seq_start_loc" ):
227+ cumsum = torch .cumsum (sequesd_k , dim = 0 )
228+ cu_seqlens_k = torch .cat ([
229+ torch .tensor ([0 ], device = sequesd_k .device , dtype = torch .int32 ),
230+ cumsum
231+ ]).to (torch .int32 )
232+ else :
233+ cu_seqlens_k = attn_metadata .seq_start_loc
207234
208235 ipex_ops .flash_attn_varlen_func (
209236 output [:num_actual_tokens ],
210237 query [:num_actual_tokens ],
211238 key_cache ,
212239 value_cache ,
213- attn_metadata . query_start_loc ,
214- attn_metadata . seq_start_loc ,
215- attn_metadata . max_query_len ,
216- attn_metadata . max_seq_len ,
240+ cu_seqlens_q ,
241+ cu_seqlens_k ,
242+ max_seqlen_q ,
243+ max_seqlen_k ,
217244 self .scale ,
218245 is_casual = True ,
219- block_table = attn_metadata . block_table ,
246+ block_table = block_table ,
220247 alibi_slopes = self .alibi_slopes ,
221248 )
222249 return output
0 commit comments