@@ -94,17 +94,11 @@ def _vllm_layout_trans_kernel(
         tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask)
 
 def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table,
-                      k_buffer, v_buffer, max_seq_len, total_tokens,
+                      k_cache, v_cache, k_values, v_values, max_seq_len,
                       k_scale, v_scale, output_dtype):
-    H_KV = v_buffer.shape[2]
-    D = v_buffer.shape[3]
-    BLOCK_SIZE = v_buffer.shape[1]
-    k_values = torch.empty((total_tokens, H_KV, D),
-                           dtype=output_dtype,
-                           device="cuda")
-    v_values = torch.empty((total_tokens, H_KV, D),
-                           dtype=output_dtype,
-                           device="cuda")
+    H_KV = v_cache.shape[2]
+    D = v_cache.shape[3]
+    BLOCK_SIZE = v_cache.shape[1]
 
     grid = (block_table.shape[0],
             (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
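The hunk above moves the `k_values`/`v_values` allocation out of `vllm_layout_trans`: the caller now supplies pre-allocated output buffers, and the function only derives `H_KV`, `D`, and `BLOCK_SIZE` from the paged cache shape. A minimal CPU sketch of the gather the Triton kernel performs — illustrative names, the `(num_blocks, BLOCK_SIZE, H_KV, D)` cache layout implied by the shapes above, and fp8 dequantization via `k_scale`/`v_scale` omitted:

import torch

def layout_trans_ref(block_table, seq_lens, k_cache, v_cache,
                     k_values, v_values):
    # k_cache/v_cache: (num_blocks, BLOCK_SIZE, H_KV, D) paged caches.
    # k_values/v_values: (total_tokens, H_KV, D), pre-allocated by the caller.
    BLOCK_SIZE = v_cache.shape[1]
    out = 0
    for b, seq_len in enumerate(seq_lens.tolist()):
        # Same ceil-division as the kernel grid's second dimension.
        for i in range((seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE):
            block_id = int(block_table[b, i])
            n = min(BLOCK_SIZE, seq_len - i * BLOCK_SIZE)
            k_values[out:out + n] = k_cache[block_id, :n]
            v_values[out:out + n] = v_cache[block_id, :n]
            out += n
    return k_values, v_values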
@@ -116,8 +110,8 @@ def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table,
     else:
         raise ValueError(f"Unsupported output dtype: {output_dtype}")
 
-    _vllm_layout_trans_kernel[grid](k_buffer,
-                                    v_buffer,
+    _vllm_layout_trans_kernel[grid](k_cache,
+                                    v_cache,
                                     k_values,
                                     v_values,
                                     b_query_lens_loc,
@@ -136,10 +130,11 @@ def flash_attn_varlen_func_impl(
         q: torch.Tensor,
         k_cache: torch.Tensor,
         v_cache: torch.Tensor,
+        k_values: torch.Tensor,
+        v_values: torch.Tensor,
         out: torch.Tensor,
         cu_seqlens_q: torch.Tensor,
         cu_seqlens_k: torch.Tensor,
-        total_tokens: int,
         max_seqlen_q: int,
         max_seqlen_k: int,
         softmax_scale: float,
@@ -150,8 +145,8 @@ def flash_attn_varlen_func_impl(
         v_scale: torch.Tensor,
 ) -> torch.Tensor:
     k, v = vllm_layout_trans(cu_seqlens_q, cu_seqlens_k, block_table,
-                             k_cache, v_cache, max_seqlen_k, total_tokens,
-                             k_scale, v_scale, q.dtype)
+                             k_cache, v_cache, k_values, v_values,
+                             max_seqlen_k, k_scale, v_scale, q.dtype)
     output = aiter.flash_attn_varlen_func(
         q=q,
         k=k,
@@ -173,10 +168,11 @@ def flash_attn_varlen_func_fake(
         q: torch.Tensor,
         k_cache: torch.Tensor,
         v_cache: torch.Tensor,
+        k_values: torch.Tensor,
+        v_values: torch.Tensor,
         out: torch.Tensor,
         cu_seqlens_q: torch.Tensor,
         cu_seqlens_k: torch.Tensor,
-        total_tokens: int,
         max_seqlen_q: int,
         max_seqlen_k: int,
         softmax_scale: float,
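`flash_attn_varlen_func_fake` keeps the same signature as the real implementation, including the new `k_values`/`v_values` parameters, because the fake variant is what PyTorch consults for shape and dtype propagation when the custom op is traced rather than executed. A stand-alone sketch of that impl/fake pairing using stock `torch.library` (illustrative only: the op name and body below are made up, and vLLM registers its op through its own helper):

import torch

# Real implementation: does the work and mutates `out` in place.
@torch.library.custom_op("demo::copy_scaled", mutates_args=("out",))
def copy_scaled(q: torch.Tensor, out: torch.Tensor, scale: float) -> None:
    out.copy_(q * scale)

# Fake (meta) variant: identical signature, no computation; tracing uses it
# purely for shape/dtype bookkeeping.
@copy_scaled.register_fake
def _(q: torch.Tensor, out: torch.Tensor, scale: float) -> None:
    pass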
@@ -216,17 +212,15 @@ class AiterFlashAttentionMetadata:
     max_seq_len: int
     seq_lens: torch.Tensor
     cu_seq_lens: torch.Tensor
-    total_tokens: int
     block_table: torch.Tensor
     slot_mapping: torch.Tensor
     workspace_buffer: torch.Tensor
 
     # For cascade attention.
     use_cascade: bool
     common_prefix_len: int
-    cu_prefix_query_lens: Optional[torch.Tensor]
-    prefix_kv_lens: Optional[torch.Tensor]
-    suffix_kv_lens: Optional[torch.Tensor]
+    k_buffer: torch.Tensor
+    v_buffer: torch.Tensor
 
     # for local attention
     @dataclass
@@ -351,10 +345,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
 
         use_cascade = common_prefix_len > 0
 
-        cu_prefix_query_lens = None
-        prefix_kv_lens = None
-        suffix_kv_lens = None
-
         nbytes_per_qo_elem = torch.finfo(self.runner.dtype).bits // 8
         max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM -
                               1) // _PARTITION_SIZE_ROCM
@@ -367,23 +357,31 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             device=self.runner.device,
         )
 
+        k_buffer = torch.empty(
+            (total_tokens, self.num_heads_kv, self.headdim),
+            dtype=self.runner.dtype,
+            device=self.runner.device,
+        )
+        v_buffer = torch.empty(
+            (total_tokens, self.num_heads_kv, self.headdim),
+            dtype=self.runner.dtype,
+            device=self.runner.device,
+        )
         attn_metadata = AiterFlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
             query_start_loc=query_start_loc,
             max_seq_len=max_seq_len,
             seq_lens=seq_lens,
             cu_seq_lens=cu_seq_lens,
-            total_tokens=total_tokens,
             block_table=block_table_tensor,
             slot_mapping=slot_mapping,
             use_cascade=use_cascade,
             workspace_buffer=workspace_buffer,
             common_prefix_len=common_prefix_len,
-            cu_prefix_query_lens=cu_prefix_query_lens,
-            prefix_kv_lens=prefix_kv_lens,
-            suffix_kv_lens=suffix_kv_lens,
             local_attn_metadata=local_attn_metadata,
+            k_buffer=k_buffer,
+            v_buffer=v_buffer,
         )
         return attn_metadata
 
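With this hunk the metadata builder owns the transposed-KV scratch space: `k_buffer` and `v_buffer` are allocated once per metadata build, sized by `total_tokens`, rather than re-allocated inside every attention call. A compact sketch of the pattern under hypothetical names (not vLLM's classes):

import torch
from dataclasses import dataclass

@dataclass
class KVScratch:
    k_buffer: torch.Tensor
    v_buffer: torch.Tensor

def build_kv_scratch(total_tokens: int, num_heads_kv: int, headdim: int,
                     dtype: torch.dtype, device: str) -> KVScratch:
    # One allocation per scheduled batch; the attention op then writes
    # into these buffers instead of calling torch.empty on the hot path.
    shape = (total_tokens, num_heads_kv, headdim)
    return KVScratch(
        k_buffer=torch.empty(shape, dtype=dtype, device=device),
        v_buffer=torch.empty(shape, dtype=dtype, device=device),
    )

For example, `build_kv_scratch(12, 4, 64, torch.float16, "cpu")` sizes the scratch for a 12-token batch with 4 KV heads of dimension 64.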
@@ -585,16 +583,16 @@ def forward(
 
         if max_seqlen_q > 1:
             cu_seq_lens = attn_metadata.cu_seq_lens
-            total_tokens = attn_metadata.total_tokens
             torch.ops.vllm.flash_attn_varlen_func(
                 query[:num_actual_tokens],
                 key_cache,
                 value_cache,
+                attn_metadata.k_buffer,
+                attn_metadata.v_buffer,
                 out=output[:num_actual_tokens],
                 cu_seqlens_q=cu_seqlens_q,
                 max_seqlen_q=max_seqlen_q,
                 max_seqlen_k=max_seqlen_k,
-                total_tokens=total_tokens,
                 softmax_scale=self.scale,
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
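Dropping the explicit `total_tokens` argument works because the value is implicit in what already flows through the op: it is the leading dimension of the scratch buffers, and (assuming `total_tokens` is the sum of the per-sequence KV lengths, as the buffer shapes in the builder hunk suggest) it also equals the last entry of the cumulative sequence lengths. A quick illustration with made-up numbers:

import torch

cu_seq_lens = torch.tensor([0, 3, 8, 12])      # cumulative KV lengths
total_tokens = int(cu_seq_lens[-1])            # 12
k_buffer = torch.empty((total_tokens, 4, 64))  # (total_tokens, H_KV, D)
assert k_buffer.shape[0] == total_tokens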