 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.v1.attention.backends.flash_attn import (
-    make_local_attention_virtual_batches)
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.kv_cache_interface import AttentionSpec
 
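For orientation: make_local_attention_virtual_batches is the helper the FlashAttention backend uses to rewrite chunked ("local", iRoPE-style) attention into per-chunk virtual batches so a plain varlen kernel can run it, so dropping this import removes that capability from the AITER backend. Below is a rough sketch of the underlying idea only, with a simplified interface; it is not the real helper, which also remaps the block table and returns cumulative offsets.

# Rough NumPy sketch of the idea behind make_local_attention_virtual_batches.
# Illustrative only; the real helper has a different signature and outputs.
import numpy as np

def local_virtual_batches_sketch(chunk_size, query_start_loc, seq_lens):
    """Split each request into one "virtual batch" per attention chunk."""
    q_lens_local, k_lens_local = [], []
    for q_len, s_len in zip(np.diff(query_start_loc), seq_lens):
        # Absolute positions of this request's query tokens.
        positions = np.arange(s_len - q_len, s_len)
        for chunk in np.unique(positions // chunk_size):
            in_chunk = positions[positions // chunk_size == chunk]
            q_lens_local.append(len(in_chunk))
            # Keys run from the chunk start up to the last query token in it.
            k_lens_local.append(int(in_chunk[-1] - chunk * chunk_size + 1))
    return np.asarray(q_lens_local), np.asarray(k_lens_local)

# Example: one request with 6 cached tokens and 3 new ones (seq_len=9) and
# chunk_size=4. Query positions 6-8 straddle chunks 1 and 2, so we get two
# virtual batches with query lengths [2, 1] and key lengths [4, 1].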
@@ -201,9 +199,7 @@ def build(self,
         max_seq_len = int(common_attn_metadata.seq_lens_cpu.max())
         total_tokens = int(common_attn_metadata.seq_lens_cpu.sum())
         query_start_loc = common_attn_metadata.query_start_loc
-        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
         seq_lens = common_attn_metadata.seq_lens
-        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
         block_table_tensor = common_attn_metadata.block_table_tensor
         slot_mapping = common_attn_metadata.slot_mapping
 
@@ -215,56 +211,6 @@ def build(self,
                      dtype=cu_seq_lens.dtype,
                      out=cu_seq_lens[1:])
 
-        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
-                     max_seq_len, causal):
-            return None
-
-        # for local attention
-        local_attn_metadata = None
-        if self.model_config.attention_chunk_size is not None:
-            seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \
-                virt_block_table_tensor = make_local_attention_virtual_batches(
-                    self.model_config.attention_chunk_size,
-                    query_start_loc_cpu.numpy(),
-                    seq_lens_cpu.numpy(),
-                    block_table_tensor,
-                    self.block_size,
-                )
-            local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to(
-                self.device, non_blocking=True)
-            local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to(
-                self.device, non_blocking=True)
-            local_max_query_len = seqlens_q_local_np.max().item()
-            local_max_seq_len = virt_k_seqlens_np.max().item()
-            local_scheduler_metadata = schedule(
-                batch_size=local_query_start_loc.shape[0] - 1,
-                cu_query_lens=local_query_start_loc,
-                max_query_len=local_max_query_len,
-                seqlens=local_seqused_k,
-                max_seq_len=local_max_seq_len,
-                causal=True)
-
-            local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
-                                            dtype=torch.int32,
-                                            device=self.device)
-            local_cu_seq_lens[1:] = torch.cumsum(
-                torch.from_numpy(virt_k_seqlens_np).to(device=self.device,
-                                                       dtype=torch.int32,
-                                                       non_blocking=True),
-                dim=0)
-
-
-            local_attn_metadata = \
-                AiterFlashAttentionMetadata.LocalAttentionMetadata(
-                    local_query_start_loc=local_query_start_loc,
-                    local_seqused_k=local_seqused_k,
-                    local_block_table=virt_block_table_tensor,
-                    local_max_query_len=local_max_query_len,
-                    local_max_seq_len=local_max_seq_len,
-                    local_cu_seq_lens=local_cu_seq_lens,
-                    local_scheduler_metadata=local_scheduler_metadata,
-                )
-
         use_cascade = common_prefix_len > 0
 
         cu_prefix_query_lens = None
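The deleted block above moved the virtual-batch lengths onto the device and finished with the same lengths-to-cumulative-lengths pattern that the surviving cu_seq_lens code at the top of this hunk still uses. A minimal sketch of that pattern, with an illustrative helper name that is not part of the backend:

# Sketch of the lengths -> cumulative-lengths pattern; to_cu_seqlens is an
# illustrative name, not a backend function.
import torch

def to_cu_seqlens(lengths, device="cpu"):
    lengths = torch.as_tensor(lengths, dtype=torch.int32, device=device)
    # Preallocate [0, ...] and write the running sums into the tail slice.
    cu = torch.zeros(lengths.numel() + 1, dtype=torch.int32, device=device)
    torch.cumsum(lengths, dim=0, dtype=cu.dtype, out=cu[1:])
    return cu  # e.g. [3, 2, 4] -> [0, 3, 5, 9]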
@@ -286,7 +232,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             cu_prefix_query_lens=cu_prefix_query_lens,
             prefix_kv_lens=prefix_kv_lens,
             suffix_kv_lens=suffix_kv_lens,
-            local_attn_metadata=local_attn_metadata,
         )
         return attn_metadata
 
@@ -377,19 +322,6 @@ class AiterFlashAttentionMetadata:
     prefix_kv_lens: Optional[torch.Tensor]
     suffix_kv_lens: Optional[torch.Tensor]
 
-    # for local attention
-    @dataclass
-    class LocalAttentionMetadata:
-        local_query_start_loc: torch.Tensor
-        local_seqused_k: torch.Tensor
-        local_block_table: torch.Tensor
-        local_max_query_len: int
-        local_max_seq_len: int
-        local_cu_seq_lens: torch.Tensor
-        local_scheduler_metadata: Optional[torch.Tensor]
-
-    local_attn_metadata: Optional[LocalAttentionMetadata] = None
-
 
 class AiterFlashAttentionImpl(AttentionImpl):
 
@@ -521,25 +453,12 @@ def forward(
                 layer._q_scale)
             query = query.reshape((num_tokens, num_heads, head_size))
 
-        # Compute attention and update output up to `num_actual_tokens`.
-        use_local_attn = \
-            (self.use_irope and attn_metadata.local_attn_metadata is not None)
-
-        if not attn_metadata.use_cascade or use_local_attn:
-            if use_local_attn:
-                assert attn_metadata.local_attn_metadata is not None
-                local_metadata = attn_metadata.local_attn_metadata
-                cu_seqlens_q = local_metadata.local_query_start_loc
-                seqused_k = local_metadata.local_seqused_k
-                max_seqlen_q = local_metadata.local_max_query_len
-                max_seqlen_k = local_metadata.local_max_seq_len
-                block_table = local_metadata.local_block_table
-            else:
-                cu_seqlens_q = attn_metadata.query_start_loc
-                seqused_k = attn_metadata.seq_lens
-                max_seqlen_q = attn_metadata.max_query_len
-                max_seqlen_k = attn_metadata.max_seq_len
-                block_table = attn_metadata.block_table
+        if not attn_metadata.use_cascade:
+            cu_seqlens_q = attn_metadata.query_start_loc
+            seqused_k = attn_metadata.seq_lens
+            max_seqlen_q = attn_metadata.max_query_len
+            max_seqlen_k = attn_metadata.max_seq_len
+            block_table = attn_metadata.block_table
 
             if max_seqlen_q > 1:
                 cu_seq_lens = attn_metadata.cu_seq_lens
@@ -557,9 +476,7 @@ def forward(
                     alibi_slopes=self.alibi_slopes,
                     window_size=self.sliding_window,
                     block_table=block_table,
-                    cu_seqlens_k=(cu_seq_lens if not use_local_attn else
-                                  local_metadata.local_cu_seq_lens),
-                )
+                    cu_seqlens_k=cu_seq_lens)
 
         _, num_heads, head_size = query.shape
         _PARTITION_SIZE_ROCM = 256
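With the local-attention branch gone, the prefill path always passes the full cu_seq_lens as cu_seqlens_k. The decode path below defines _PARTITION_SIZE_ROCM = 256, which constants like this are typically used to split each sequence's KV cache into fixed-size partitions. A hedged sketch of the usual partition-count arithmetic behind such a constant; the helper name is illustrative, not the backend's:

# Hedged sketch of the standard ceil-divide used with a fixed partition size.
def num_kv_partitions(max_seq_len: int, partition_size: int = 256) -> int:
    # Number of fixed-size KV partitions needed to cover the longest sequence.
    return (max_seq_len + partition_size - 1) // partition_size

# num_kv_partitions(1000) == 4, i.e. four 256-token partitions.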