Skip to content

Commit 2dda1d9

Browse files
committed
fix local attention metadata
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
1 parent f4de631 commit 2dda1d9

File tree

1 file changed: +15 additions, -1 deletion

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -253,13 +253,25 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
253253
max_seq_len=local_max_seq_len,
254254
causal=True)
255255

256+
local_cu_seq_lens = torch.zeros(virt_k_seqlens_np.shape[0] + 1,
257+
dtype=torch.int32,
258+
device=self.runner.device)
259+
local_cu_seq_lens[1:] = torch.cumsum(
260+
torch.from_numpy(virt_k_seqlens_np).to(
261+
device=self.runner.device,
262+
dtype=torch.int32,
263+
non_blocking=True),
264+
dim=0)
265+
266+
256267
local_attn_metadata = \
257268
AiterFlashAttentionMetadata.LocalAttentionMetadata(
258269
local_query_start_loc=local_query_start_loc,
259270
local_seqused_k=local_seqused_k,
260271
local_block_table=virt_block_table_tensor,
261272
local_max_query_len=local_max_query_len,
262273
local_max_seq_len=local_max_seq_len,
274+
local_cu_seq_lens=local_cu_seq_lens,
263275
local_scheduler_metadata=local_scheduler_metadata,
264276
)
265277

@@ -368,6 +380,7 @@ class LocalAttentionMetadata:
368380
local_block_table: torch.Tensor
369381
local_max_query_len: int
370382
local_max_seq_len: int
383+
local_cu_seq_lens: torch.Tensor
371384
local_scheduler_metadata: Optional[torch.Tensor]
372385

373386
local_attn_metadata: Optional[LocalAttentionMetadata] = None
@@ -546,7 +559,8 @@ def forward(
546559
alibi_slopes=self.alibi_slopes,
547560
window_size=self.sliding_window,
548561
block_table=block_table,
549-
cu_seqlens_k=cu_seq_lens,
562+
cu_seqlens_k=(cu_seq_lens if not use_local_attn else
563+
local_metadata.local_cu_seq_lens),
550564
)
551565

552566
_, num_heads, head_size = query.shape

0 commit comments

Comments (0)