@@ -99,6 +99,13 @@ def get_kv_cache_stride_order() -> tuple[int, ...]:
             raise ValueError(f"Unknown cache layout format {cache_layout}.")
         return stride_order
 
+    @staticmethod
+    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
+        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
+            return torch.float8_e4m3fn
+        else:
+            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
+
 
 @dataclass
 class FlashAttentionMetadata:
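For reference, a standalone sketch of the new helper's contract (re-created here for illustration; the real method is a `@staticmethod` on `FlashAttentionBackend`): both accepted cache-dtype spellings resolve to `torch.float8_e4m3fn`, and any other string raises.

```python
# Illustrative re-creation of the helper added in this hunk.
import torch

def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
    # Both accepted spellings map to the e4m3 FP8 dtype.
    if kv_cache_dtype in ("fp8", "fp8_e4m3"):
        return torch.float8_e4m3fn
    raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

assert get_fp8_dtype_for_flashattn("fp8") is torch.float8_e4m3fn
assert get_fp8_dtype_for_flashattn("fp8_e4m3") is torch.float8_e4m3fn
```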
@@ -161,6 +168,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             self.parallel_config)
         self.num_heads_kv = self.model_config.get_num_kv_heads(
             self.parallel_config)
+        self.kv_cache_dtype = kv_cache_spec.dtype
         self.headdim = self.model_config.get_head_size()
         self.block_size = kv_cache_spec.block_size
 
@@ -239,17 +247,24 @@ def build(self,
 
         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
+            cache_dtype = self.cache_config.cache_dtype
+            if cache_dtype.startswith("fp8"):
+                qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
+                    cache_dtype)
+            else:
+                qkv_dtype = self.kv_cache_dtype
             if aot_schedule:
                 return get_scheduler_metadata(
                     batch_size=batch_size,
                     max_seqlen_q=max_query_len,
                     max_seqlen_k=max_seq_len,
-                    cache_seqlens=seqlens,
                     num_heads_q=self.num_heads_q,
                     num_heads_kv=self.num_heads_kv,
                     headdim=self.headdim,
-                    page_size=self.block_size,
+                    cache_seqlens=seqlens,
+                    qkv_dtype=qkv_dtype,
                     cu_seqlens_q=cu_query_lens,
+                    page_size=self.block_size,
                     causal=causal,
                     window_size=self.aot_sliding_window,
                     num_splits=self.max_num_splits,
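The dtype-selection branch added to `schedule()` boils down to a small pure function. A hedged sketch of that logic follows; `select_qkv_dtype`, `cache_dtype`, and `spec_dtype` are illustrative stand-ins for `self.cache_config.cache_dtype` and `self.kv_cache_dtype`, not names from the diff:

```python
import torch

def select_qkv_dtype(cache_dtype: str, spec_dtype: torch.dtype) -> torch.dtype:
    # FP8 cache strings are mapped to the torch FP8 dtype the AOT
    # scheduler should plan for; anything else passes the KV-cache
    # spec dtype through unchanged.
    if cache_dtype.startswith("fp8"):
        if cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        raise ValueError(f"Unrecognized FP8 dtype: {cache_dtype}")
    return spec_dtype

assert select_qkv_dtype("fp8_e4m3", torch.bfloat16) is torch.float8_e4m3fn
assert select_qkv_dtype("auto", torch.bfloat16) is torch.bfloat16
```

Threading the result into `get_scheduler_metadata` as `qkv_dtype` lets the ahead-of-time scheduler plan for FP8 element sizes instead of assuming the model dtype.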
@@ -474,8 +489,10 @@ def forward(
             )
 
         if self.kv_cache_dtype.startswith("fp8"):
-            key_cache = key_cache.view(torch.float8_e4m3fn)
-            value_cache = value_cache.view(torch.float8_e4m3fn)
+            dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
+                self.kv_cache_dtype)
+            key_cache = key_cache.view(dtype)
+            value_cache = value_cache.view(dtype)
             num_tokens, num_heads, head_size = query.shape
             query, _ = ops.scaled_fp8_quant(
                 query.reshape(
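A note on the `.view(dtype)` calls: assuming the FP8 KV cache is allocated as a byte-typed tensor (which the reinterpreting view implies), `Tensor.view(torch.float8_e4m3fn)` relabels the existing storage without copying or converting, since both element types are one byte wide. A toy illustration with made-up shapes:

```python
# Toy illustration: a byte-backed cache reinterpreted as FP8 in place.
import torch

key_cache = torch.zeros(4, 16, 8, 64, dtype=torch.uint8)
fp8_cache = key_cache.view(torch.float8_e4m3fn)

assert fp8_cache.dtype is torch.float8_e4m3fn
assert fp8_cache.shape == key_cache.shape            # 1-byte -> 1-byte view
assert fp8_cache.data_ptr() == key_cache.data_ptr()  # same storage, no copy
```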