
Commit 764481f

fix full graph capture
Signed-off-by: fsx950223 <fsx950223@outlook.com>
1 parent 6f56012 commit 764481f

File tree

1 file changed: +14, -11 lines changed

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 14 additions & 11 deletions
@@ -48,17 +48,19 @@ def _vllm_layout_trans_kernel(
 ):
     batch_idx = tl.program_id(0)
     block_idx = tl.program_id(1)
-    batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx +
-                                  tl.arange(0, 2))
-    batch_token_start, batch_token_end = tl.split(batch_token_indexes)
-    seq_len = batch_token_end - batch_token_start
 
     batch_query_indexes = tl.load(b_query_lens_loc + batch_idx +
                                   tl.arange(0, 2))
     batch_query_start, batch_query_end = tl.split(batch_query_indexes)
     query_len = batch_query_end - batch_query_start
     if query_len <= 1:
         return
+
+    batch_token_indexes = tl.load(b_seq_lens_loc + batch_idx +
+                                  tl.arange(0, 2))
+    batch_token_start, batch_token_end = tl.split(batch_token_indexes)
+    seq_len = batch_token_end - batch_token_start
+
     if block_idx * BLOCK_SIZE < seq_len:
         block_mask = (block_idx * BLOCK_SIZE +
                       tl.arange(0, BLOCK_SIZE)[:, None]) < seq_len
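
This hunk only reorders work: the decode-only early return (query_len <= 1) now runs before the sequence-length bounds are loaded, so kernel programs for decode requests exit without touching b_seq_lens_loc. A minimal host-side sketch of that control flow, for illustration only (plain PyTorch, not the Triton kernel; the helper name prefill_batches and the Python loop standing in for the program grid are made up):

import torch

def prefill_batches(b_query_lens_loc: torch.Tensor,
                    b_seq_lens_loc: torch.Tensor) -> list[tuple[int, int]]:
    # Host-side mirror of the reordered kernel logic (illustrative only).
    selected = []
    for batch_idx in range(b_query_lens_loc.numel() - 1):
        # 1) Check the query length first and skip decode-only requests ...
        query_len = int(b_query_lens_loc[batch_idx + 1] -
                        b_query_lens_loc[batch_idx])
        if query_len <= 1:
            continue
        # 2) ... and only then read the cumulative sequence bounds.
        seq_len = int(b_seq_lens_loc[batch_idx + 1] -
                      b_seq_lens_loc[batch_idx])
        selected.append((batch_idx, seq_len))
    return selected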
@@ -269,12 +271,13 @@ def build(self, common_prefix_len: int,
         max_query_len = common_attn_metadata.max_query_len
 
         max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max())
-        total_tokens = int(self.runner.seq_lens_np[:num_reqs].sum())
         query_start_loc = common_attn_metadata.query_start_loc
         seq_lens = common_attn_metadata.seq_lens
         block_table = self.block_table
         block_table_tensor = block_table.get_device_tensor()[:num_reqs]
-
+        query_lens = query_start_loc[1:] - query_start_loc[:-1]
+        masked_seq_lens = torch.where(query_lens > 1, seq_lens,
+                                      torch.zeros_like(seq_lens))
         block_table.slot_mapping[:num_actual_tokens].copy_(
             block_table.slot_mapping_cpu[:num_actual_tokens],
             non_blocking=True)
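
In the metadata builder, the sequence length of every request with query_len <= 1 is now zeroed before the cumulative lengths are formed. A minimal sketch with made-up numbers (three requests, the middle one decoding a single token):

import torch

query_start_loc = torch.tensor([0, 5, 6, 10], dtype=torch.int32)  # cumulative query lengths
seq_lens = torch.tensor([5, 37, 10], dtype=torch.int32)           # full context lengths

query_lens = query_start_loc[1:] - query_start_loc[:-1]           # tensor([5, 1, 4])
masked_seq_lens = torch.where(query_lens > 1, seq_lens,
                              torch.zeros_like(seq_lens))          # tensor([5, 0, 10])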
@@ -284,10 +287,10 @@ def build(self, common_prefix_len: int,
 
         slot_mapping = block_table.slot_mapping[:num_actual_tokens]
 
-        cu_seq_lens = torch.zeros(seq_lens.shape[0] + 1,
+        cu_seq_lens = torch.zeros(masked_seq_lens.shape[0] + 1,
                                   dtype=torch.int32,
                                   device="cuda")
-        torch.cumsum(seq_lens,
+        torch.cumsum(masked_seq_lens,
                      dim=0,
                      dtype=cu_seq_lens.dtype,
                      out=cu_seq_lens[1:])
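
Continuing the toy values above (the device="cuda" argument is dropped so the sketch runs on CPU), the prefix sum is taken over the masked lengths: a decode-only request collapses to an empty [start, end) range, and the last entry is the masked token count rather than seq_lens.sum():

cu_seq_lens = torch.zeros(masked_seq_lens.shape[0] + 1, dtype=torch.int32)
torch.cumsum(masked_seq_lens, dim=0, dtype=cu_seq_lens.dtype,
             out=cu_seq_lens[1:])
# cu_seq_lens -> tensor([ 0,  5,  5, 15], dtype=torch.int32)
# Request 1 spans the empty range [5, 5); a cumsum over the unmasked
# seq_lens would instead have ended at 52.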
@@ -356,14 +359,14 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             dtype=torch.uint8,
             device=self.runner.device,
         )
-
+        masked_total_tokens = cu_seq_lens[-1].item()
         k_buffer = torch.empty(
-            (total_tokens, self.num_heads_kv, self.headdim),
+            (masked_total_tokens, self.num_heads_kv, self.headdim),
             dtype=self.runner.dtype,
             device=self.runner.device,
         )
         v_buffer = torch.empty(
-            (total_tokens, self.num_heads_kv, self.headdim),
+            (masked_total_tokens, self.num_heads_kv, self.headdim),
             dtype=self.runner.dtype,
             device=self.runner.device,
         )
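
Finally, the K/V staging buffers are sized from cu_seq_lens[-1] (the masked total) rather than the removed total_tokens, so they only cover tokens of requests that take the layout-transpose path (query_len > 1). A hedged sketch with the toy values from above and placeholder head count, head size, and dtype:

masked_total_tokens = cu_seq_lens[-1].item()   # 15 in the toy example, not 52

num_heads_kv, headdim = 8, 128                 # placeholder model dimensions
k_buffer = torch.empty((masked_total_tokens, num_heads_kv, headdim),
                       dtype=torch.float16)
v_buffer = torch.empty((masked_total_tokens, num_heads_kv, headdim),
                       dtype=torch.float16)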
