@@ -253,13 +253,25 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
253253                max_seq_len = local_max_seq_len ,
254254                causal = True )
255255
256+             local_cu_seq_lens  =  torch .zeros (virt_k_seqlens_np .shape [0 ] +  1 ,
257+                                             dtype = torch .int32 ,
258+                                             device = self .runner .device )
259+             local_cu_seq_lens [1 :] =  torch .cumsum (
260+                 torch .from_numpy (virt_k_seqlens_np ).to (
261+                     device = self .runner .device ,
262+                     dtype = torch .int32 ,
263+                     non_blocking = True ),
264+                 dim = 0 )
265+ 
266+ 
256267            local_attn_metadata  =  \
257268            AiterFlashAttentionMetadata .LocalAttentionMetadata (
258269                local_query_start_loc = local_query_start_loc ,
259270                local_seqused_k = local_seqused_k ,
260271                local_block_table = virt_block_table_tensor ,
261272                local_max_query_len = local_max_query_len ,
262273                local_max_seq_len = local_max_seq_len ,
274+                 local_cu_seq_lens = local_cu_seq_lens ,
263275                local_scheduler_metadata = local_scheduler_metadata ,
264276            )
265277
@@ -368,6 +380,7 @@ class LocalAttentionMetadata:
368380        local_block_table : torch .Tensor 
369381        local_max_query_len : int 
370382        local_max_seq_len : int 
383+         local_cu_seq_lens : torch .Tensor 
371384        local_scheduler_metadata : Optional [torch .Tensor ]
372385
373386    local_attn_metadata : Optional [LocalAttentionMetadata ] =  None 
@@ -546,7 +559,8 @@ def forward(
546559                    alibi_slopes = self .alibi_slopes ,
547560                    window_size = self .sliding_window ,
548561                    block_table = block_table ,
549-                     cu_seqlens_k = cu_seq_lens ,
562+                     cu_seqlens_k = (cu_seq_lens  if  not  use_local_attn  else 
563+                                   local_metadata .local_cu_seq_lens ),
550564                )
551565
552566            _ , num_heads , head_size  =  query .shape 
0 commit comments