 
 import torch
 
-from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
-                                              AttentionMetadata, AttentionType,
-                                              is_quantized_kv_cache)
+                                              AttentionMetadata, AttentionType)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.v1.attention.backends.flash_attn import (
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.block_table import BlockTable
 
+_PARTITION_SIZE_ROCM = 256
+
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.worker.gpu_input_batch import InputBatch
@@ -38,6 +38,9 @@ def _vllm_layout_trans_kernel(
         b_seq_lens_loc,
         block_table,
         block_table_stride_0,
+        k_scale,
+        v_scale,
+        output_dtype: tl.constexpr,
         E_DIM: tl.constexpr,
         BLOCK_SIZE: tl.constexpr,
     ):
@@ -59,16 +62,27 @@ def _vllm_layout_trans_kernel(
                           tl.arange(0, BLOCK_SIZE)[:, None]) < seq_len
 
             kv_idx = tl.load(block_table + batch_idx * block_table_stride_0 +
-                             block_idx)
+                             block_idx).to(tl.int64)
 
             kv_buffer_off = kv_idx * BLOCK_SIZE * E_DIM + tl.arange(
                 0, BLOCK_SIZE)[:, None] * E_DIM + tl.arange(0, E_DIM)[None, :]
             k_vals = tl.load(k_buffer_ptr + kv_buffer_off,
                              mask=block_mask,
                              other=0.0)
+            if k_vals.dtype.is_fp8():
+                k_vals = (k_vals.to(tl.float32) *
+                          tl.load(k_scale)).to(output_dtype)
+            else:
+                k_vals = k_vals.to(output_dtype)
+
             v_vals = tl.load(v_buffer_ptr + kv_buffer_off,
                              mask=block_mask,
                              other=0.0)
+            if v_vals.dtype.is_fp8():
+                v_vals = (v_vals.to(tl.float32) *
+                          tl.load(v_scale)).to(output_dtype)
+            else:
+                v_vals = v_vals.to(output_dtype)
 
             kv_values_off = batch_token_start * E_DIM + \
                 block_idx * BLOCK_SIZE * E_DIM + \
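The two new branches above perform per-tensor fp8 dequantization inside the layout-transform kernel: fp8 cache values are upcast to float32, multiplied by the layer's k/v scale, and cast to the attention compute dtype. For reference only, a minimal eager-PyTorch sketch of the same math (the helper name and shapes below are illustrative, not part of this patch):

import torch

def dequant_kv_tile(vals_fp8: torch.Tensor, scale: torch.Tensor,
                    out_dtype: torch.dtype) -> torch.Tensor:
    # Mirror of the kernel branch: upcast, apply the per-tensor scale,
    # then cast to the compute dtype (fp16/bf16).
    return (vals_fp8.to(torch.float32) * scale.to(torch.float32)).to(out_dtype)

# Example with a fake [BLOCK_SIZE, E_DIM] tile; needs a PyTorch build with
# float8 dtypes (float8_e4m3fnuz is the ROCm flavor used elsewhere in this file).
tile = torch.randn(16, 512).to(torch.float8_e4m3fnuz)
k_vals = dequant_kv_tile(tile, torch.tensor(0.025), torch.bfloat16)  # [16, 512] bf16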
@@ -78,21 +92,28 @@ def _vllm_layout_trans_kernel(
             tl.store(v_values_ptr + kv_values_off, v_vals, mask=block_mask)
 
     def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table,
-                          k_buffer, v_buffer, max_seq_len, total_tokens):
+                          k_buffer, v_buffer, max_seq_len, total_tokens,
+                          k_scale, v_scale, output_dtype):
         H_KV = v_buffer.shape[2]
         D = v_buffer.shape[3]
         BLOCK_SIZE = v_buffer.shape[1]
-        dtype = k_buffer.dtype
         k_values = torch.empty((total_tokens, H_KV, D),
-                               dtype=dtype,
+                               dtype=output_dtype,
                                device="cuda")
         v_values = torch.empty((total_tokens, H_KV, D),
-                               dtype=dtype,
+                               dtype=output_dtype,
                                device="cuda")
 
         grid = (block_table.shape[0],
                 (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
 
+        if output_dtype == torch.float16:
+            output_dtype = tl.float16
+        elif output_dtype == torch.bfloat16:
+            output_dtype = tl.bfloat16
+        else:
+            raise ValueError(f"Unsupported output dtype: {output_dtype}")
+
         _vllm_layout_trans_kernel[grid](k_buffer,
                                         v_buffer,
                                         k_values,
@@ -101,6 +122,9 @@ def vllm_layout_trans(b_query_lens_loc, b_seq_lens_loc, block_table,
                                         b_seq_lens_loc,
                                         block_table,
                                         block_table.stride(0),
+                                        k_scale,
+                                        v_scale,
+                                        output_dtype=output_dtype,
                                         E_DIM=H_KV * D,
                                         BLOCK_SIZE=BLOCK_SIZE)
 
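For readers unfamiliar with the paged KV layout, a rough eager-mode equivalent of what this launch computes is sketched below. It is illustrative only (not part of the PR, and much slower): gather each request's KV blocks into contiguous [total_tokens, H_KV, D] tensors and apply the same per-tensor dequantization when the cache is fp8.

import torch

def layout_trans_reference(b_seq_lens_loc, block_table, k_buffer, v_buffer,
                           k_scale, v_scale, output_dtype):
    # k_buffer / v_buffer: [num_blocks, BLOCK_SIZE, H_KV, D] paged caches.
    # b_seq_lens_loc: cumulative KV lengths per request, shape [num_reqs + 1].
    block_size = k_buffer.shape[1]
    fp8_dtypes = (torch.float8_e4m3fnuz, torch.float8_e4m3fn)
    ks, vs = [], []
    for i in range(b_seq_lens_loc.numel() - 1):
        seq_len = int(b_seq_lens_loc[i + 1] - b_seq_lens_loc[i])
        n_blocks = (seq_len + block_size - 1) // block_size
        blocks = block_table[i, :n_blocks].long()
        k = k_buffer[blocks].flatten(0, 1)[:seq_len]
        v = v_buffer[blocks].flatten(0, 1)[:seq_len]
        if k.dtype in fp8_dtypes:
            k = k.to(torch.float32) * k_scale
            v = v.to(torch.float32) * v_scale
        ks.append(k.to(output_dtype))
        vs.append(v.to(output_dtype))
    return torch.cat(ks), torch.cat(vs)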
@@ -120,9 +144,12 @@ def flash_attn_varlen_func_impl(
         window_size: Optional[list[int]],  # -1 means infinite context window
         alibi_slopes: Optional[list[float]],
         block_table: torch.Tensor,
+        k_scale: torch.Tensor,
+        v_scale: torch.Tensor,
     ) -> torch.Tensor:
         k, v = vllm_layout_trans(cu_seqlens_q, cu_seqlens_k, block_table,
-                                 k_cache, v_cache, max_seqlen_k, total_tokens)
+                                 k_cache, v_cache, max_seqlen_k, total_tokens,
+                                 k_scale, v_scale, q.dtype)
         output = aiter.flash_attn_varlen_func(
             q=q,
             k=k,
@@ -154,6 +181,8 @@ def flash_attn_varlen_func_fake(
         window_size: Optional[list[int]],  # -1 means infinite context window
         alibi_slopes: Optional[list[float]],
         block_table: torch.Tensor,
+        k_scale: torch.Tensor,
+        v_scale: torch.Tensor,
     ) -> torch.Tensor:
         return torch.empty(q.shape[0],
                            q.shape[1],
@@ -184,7 +213,6 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec,
         self.block_size = kv_cache_spec.block_size
         self.kv_cache_spec = kv_cache_spec
         self.block_table = block_table
-
         # Sliding window size to be used with the AOT scheduler will be
         # populated on first build() call.
         self.aot_sliding_window: Optional[tuple[int, int]] = None
@@ -281,6 +309,18 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
         prefix_kv_lens = None
         suffix_kv_lens = None
 
+        nbyes_per_qo_elem = torch.finfo(self.runner.dtype).bits // 8
+        max_num_partitions = (max_seq_len + _PARTITION_SIZE_ROCM -
+                              1) // _PARTITION_SIZE_ROCM
+
+        workspace_buffer = torch.empty(
+            (num_reqs * self.num_heads_q * max_num_partitions * self.headdim) *
+            nbyes_per_qo_elem + 2 *
+            (num_reqs * self.num_heads_q * max_num_partitions) * 4,
+            dtype=torch.uint8,
+            device=self.runner.device,
+        )
+
         attn_metadata = AiterFlashAttentionMetadata(
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
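Moving this allocation into build() sizes the split-KV scratch space once per batch instead of on every forward() call (see the removed block further down). The first term holds one partial output of headdim elements per (request, query head, partition); the second term appears to reserve two float32 reduction values (running max and exp-sum) per partition. A worked example with illustrative numbers, not taken from the PR:

# Illustrative sizing only; real values come from the runner and model config.
num_reqs, num_heads_q, headdim = 8, 32, 128
max_seq_len = 4096
partition = 256                                    # _PARTITION_SIZE_ROCM
nbytes_per_qo_elem = 2                             # fp16/bf16 partial outputs
max_num_partitions = (max_seq_len + partition - 1) // partition        # 16

partial_outputs = (num_reqs * num_heads_q * max_num_partitions *
                   headdim * nbytes_per_qo_elem)                       # 1_048_576
reductions = 2 * (num_reqs * num_heads_q * max_num_partitions) * 4     # 32_768
workspace_bytes = partial_outputs + reductions                         # 1_081_344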
@@ -292,6 +332,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             block_table=block_table_tensor,
             slot_mapping=slot_mapping,
             use_cascade=use_cascade,
+            workspace_buffer=workspace_buffer,
             common_prefix_len=common_prefix_len,
             cu_prefix_query_lens=cu_prefix_query_lens,
             prefix_kv_lens=prefix_kv_lens,
@@ -315,7 +356,7 @@ class AiterFlashAttentionBackend(AttentionBackend):
 
     @staticmethod
     def get_supported_head_sizes() -> list[int]:
-        return [32, 64, 96, 128, 160, 192, 224, 256]
+        return [64, 128, 256]
 
     @staticmethod
     def get_name() -> str:
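The supported head-size list is narrowed, presumably to the sizes the AITER kernels actually cover. A caller-side guard could look like the sketch below; only get_supported_head_sizes() comes from this file, the surrounding call site is hypothetical.

head_size = 96  # e.g. taken from the model config
supported = AiterFlashAttentionBackend.get_supported_head_sizes()  # [64, 128, 256]
if head_size not in supported:
    raise ValueError(
        f"Head size {head_size} is not supported by AiterFlashAttentionBackend; "
        f"supported sizes: {supported}")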
@@ -364,6 +405,7 @@ class AiterFlashAttentionMetadata:
     total_tokens: int
     block_table: torch.Tensor
     slot_mapping: torch.Tensor
+    workspace_buffer: torch.Tensor
 
     # For cascade attention.
     use_cascade: bool
@@ -442,10 +484,6 @@ def __init__(
                 "are not implemented for "
                 "FlashAttentionImpl")
         self.use_irope = use_irope
-        if is_quantized_kv_cache(self.kv_cache_dtype):
-            raise NotImplementedError(
-                "AiterFlashAttention does not support fp8 kv-cache on this "
-                "device.")
 
     def forward(
         self,
@@ -516,12 +554,6 @@ def forward(
             if self.kv_cache_dtype.startswith("fp8"):
                 key_cache = key_cache.view(torch.float8_e4m3fnuz)
                 value_cache = value_cache.view(torch.float8_e4m3fnuz)
-                num_tokens, num_heads, head_size = query.shape
-                query, _ = ops.scaled_fp8_quant(
-                    query.reshape(
-                        (num_tokens, num_heads * head_size)).contiguous(),
-                    layer._q_scale)
-                query = query.reshape((num_tokens, num_heads, head_size))
 
         # Compute attention and update output up to `num_actual_tokens`.
         use_local_attn = \
@@ -559,28 +591,14 @@ def forward(
                 alibi_slopes=self.alibi_slopes,
                 window_size=self.sliding_window,
                 block_table=block_table,
-                cu_seqlens_k=(cu_seq_lens if not use_local_attn else
-                              local_metadata.local_cu_seq_lens),
+                cu_seqlens_k=cu_seq_lens,
+                k_scale=layer._k_scale,
+                v_scale=layer._v_scale,
             )
 
-            _, num_heads, head_size = query.shape
-            _PARTITION_SIZE_ROCM = 256
-            num_seqs = seqused_k.shape[0]
-            nbyes_per_qo_elem = torch.finfo(output.dtype).bits // 8
-            max_num_partitions = (max_seqlen_k + _PARTITION_SIZE_ROCM -
-                                  1) // _PARTITION_SIZE_ROCM
-
-            workspace_buffer = torch.empty(
-                (num_seqs * num_heads * max_num_partitions * head_size) *
-                nbyes_per_qo_elem + 2 *
-                (num_seqs * num_heads * max_num_partitions) * 4,
-                dtype=torch.uint8,
-                device=output.device,
-            )
-
-            aiter.paged_attention_v1(
+            torch.ops.aiter.paged_attention_v1(
                 output[:num_actual_tokens],
-                workspace_buffer,
+                attn_metadata.workspace_buffer,
                 query[:num_actual_tokens],
                 key_cache,
                 value_cache,