Commit 6f56012

reuse allocated memory
Signed-off-by: fsx950223 <fsx950223@outlook.com>
1 parent: bdc39b0

1 file changed

vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 28 additions & 30 deletions
@@ -94,17 +94,11 @@ def _vllm_layout_trans_kernel(
         tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask)
 
 def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table,
-                      k_buffer, v_buffer, max_seq_len, total_tokens,
+                      k_cache, v_cache, k_values, v_values, max_seq_len,
                       k_scale, v_scale, output_dtype):
-    H_KV = v_buffer.shape[2]
-    D = v_buffer.shape[3]
-    BLOCK_SIZE = v_buffer.shape[1]
-    k_values = torch.empty((total_tokens, H_KV, D),
-                           dtype=output_dtype,
-                           device="cuda")
-    v_values = torch.empty((total_tokens, H_KV, D),
-                           dtype=output_dtype,
-                           device="cuda")
+    H_KV = v_cache.shape[2]
+    D = v_cache.shape[3]
+    BLOCK_SIZE = v_cache.shape[1]
 
     grid = (block_table.shape[0],
             (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
@@ -116,8 +110,8 @@ def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table,
     else:
         raise ValueError(f"Unsupported output dtype: {output_dtype}")
 
-    _vllm_layout_trans_kernel[grid](k_buffer,
-                                    v_buffer,
+    _vllm_layout_trans_kernel[grid](k_cache,
+                                    v_cache,
                                     k_values,
                                     v_values,
                                     b_query_lens_loc,
@@ -136,10 +130,11 @@ def flash_attn_varlen_func_impl(
     q: torch.Tensor,
     k_cache: torch.Tensor,
     v_cache: torch.Tensor,
+    k_values: torch.Tensor,
+    v_values: torch.Tensor,
     out: torch.Tensor,
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
-    total_tokens: int,
     max_seqlen_q: int,
     max_seqlen_k: int,
     softmax_scale: float,
@@ -150,8 +145,8 @@ def flash_attn_varlen_func_impl(
     v_scale: torch.Tensor,
 ) -> torch.Tensor:
     k, v = vllm_layout_trans(cu_seqlens_q, cu_seqlens_k, block_table,
-                             k_cache, v_cache, max_seqlen_k, total_tokens,
-                             k_scale, v_scale, q.dtype)
+                             k_cache, v_cache, k_values, v_values,
+                             max_seqlen_k, k_scale, v_scale, q.dtype)
     output = aiter.flash_attn_varlen_func(
         q=q,
         k=k,
@@ -173,10 +168,11 @@ def flash_attn_varlen_func_fake(
     q: torch.Tensor,
     k_cache: torch.Tensor,
     v_cache: torch.Tensor,
+    k_values: torch.Tensor,
+    v_values: torch.Tensor,
     out: torch.Tensor,
     cu_seqlens_q: torch.Tensor,
     cu_seqlens_k: torch.Tensor,
-    total_tokens: int,
     max_seqlen_q: int,
     max_seqlen_k: int,
     softmax_scale: float,
@@ -216,17 +212,15 @@ class AiterFlashAttentionMetadata:
     max_seq_len: int
     seq_lens: torch.Tensor
     cu_seq_lens: torch.Tensor
-    total_tokens: int
     block_table: torch.Tensor
     slot_mapping: torch.Tensor
     workspace_buffer: torch.Tensor
 
     # For cascade attention.
     use_cascade: bool
     common_prefix_len: int
-    cu_prefix_query_lens: Optional[torch.Tensor]
-    prefix_kv_lens: Optional[torch.Tensor]
-    suffix_kv_lens: Optional[torch.Tensor]
+    k_buffer: torch.Tensor
+    v_buffer: torch.Tensor
 
     # for local attention
     @dataclass
@@ -351,10 +345,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
 
         use_cascade = common_prefix_len > 0
 
-        cu_prefix_query_lens = None
-        prefix_kv_lens = None
-        suffix_kv_lens = None
-
         nbytes_per_qo_elem = torch.finfo(self.runner.dtype).bits // 8
         max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM -
                               1) // _PARTITION_SIZE_ROCM
@@ -367,23 +357,31 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             device=self.runner.device,
         )
 
+        k_buffer = torch.empty(
+            (total_tokens, self.num_heads_kv, self.headdim),
+            dtype=self.runner.dtype,
+            device=self.runner.device,
+        )
+        v_buffer = torch.empty(
+            (total_tokens, self.num_heads_kv, self.headdim),
+            dtype=self.runner.dtype,
+            device=self.runner.device,
+        )
         attn_metadata = AiterFlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             cu_seq_lens=cu_seq_lens,
-            total_tokens=total_tokens,
             block_table=block_table_tensor,
             slot_mapping=slot_mapping,
             use_cascade=use_cascade,
             workspace_buffer=workspace_buffer,
             common_prefix_len=common_prefix_len,
-            cu_prefix_query_lens=cu_prefix_query_lens,
-            prefix_kv_lens=prefix_kv_lens,
-            suffix_kv_lens=suffix_kv_lens,
             local_attn_metadata=local_attn_metadata,
+            k_buffer=k_buffer,
+            v_buffer=v_buffer,
         )
         return attn_metadata
 

@@ -585,16 +583,16 @@ def forward(
 
         if max_seqlen_q > 1:
             cu_seq_lens = attn_metadata.cu_seq_lens
-            total_tokens = attn_metadata.total_tokens
             torch.ops.vllm.flash_attn_varlen_func(
                 query[:num_actual_tokens],
                 key_cache,
                 value_cache,
+                attn_metadata.k_buffer,
+                attn_metadata.v_buffer,
                 out=output[:num_actual_tokens],
                 cu_seqlens_q=cu_seqlens_q,
                 max_seqlen_q=max_seqlen_q,
                 max_seqlen_k=max_seqlen_k,
-                total_tokens=total_tokens,
                 softmax_scale=self.scale,
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
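
In short, the commit hoists the k_values/v_values allocations out of vllm_layout_trans: instead of calling torch.empty on every attention call, the metadata builder now allocates k_buffer and v_buffer once per batch and threads them through the custom op (and its fake, which the torch.ops.vllm registration requires to match the real signature) into the kernel launch. Below is a minimal, self-contained sketch of this allocate-once-and-reuse pattern; the names Metadata, build_metadata, and layout_trans are illustrative only, and the copy_ call stands in for the real Triton layout-transpose kernel.

import torch
from dataclasses import dataclass

@dataclass
class Metadata:
    # Scratch buffers allocated once per batch and reused on every call.
    k_buffer: torch.Tensor
    v_buffer: torch.Tensor

def build_metadata(total_tokens, num_heads_kv, headdim, dtype, device):
    # Mirrors the new code path: torch.empty() runs here, at metadata
    # build time, instead of inside the per-call transpose helper.
    shape = (total_tokens, num_heads_kv, headdim)
    return Metadata(
        k_buffer=torch.empty(shape, dtype=dtype, device=device),
        v_buffer=torch.empty(shape, dtype=dtype, device=device),
    )

def layout_trans(k_cache, v_cache, k_values, v_values):
    # Mirrors the new vllm_layout_trans contract: fill the caller-provided
    # buffers in place (the real code does this with a Triton kernel).
    k_values.copy_(k_cache.reshape(k_values.shape))
    v_values.copy_(v_cache.reshape(v_values.shape))
    return k_values, v_values

meta = build_metadata(total_tokens=16, num_heads_kv=2, headdim=8,
                      dtype=torch.float32, device="cpu")
k_cache = torch.randn(16, 2, 8)
v_cache = torch.randn(16, 2, 8)
k, v = layout_trans(k_cache, v_cache, meta.k_buffer, meta.v_buffer)
assert k.data_ptr() == meta.k_buffer.data_ptr()  # reused, not re-allocated

A side effect visible in the diff: total_tokens no longer needs to travel through the op signature, since the preallocated buffers already carry that size in their first dimension, and the always-None cascade fields (cu_prefix_query_lens, prefix_kv_lens, suffix_kv_lens) are dropped from the metadata in the same pass.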
