Commit e1a7fe4

[BugFix] fix: aot passes kvcache dtype information (#19750)
Signed-off-by: Mickael Seznec <mickael@mistral.ai>
1 parent 82de9b9 commit e1a7fe4

File tree

1 file changed: 21 additions, 4 deletions

vllm/v1/attention/backends/flash_attn.py

Lines changed: 21 additions & 4 deletions
@@ -99,6 +99,13 @@ def get_kv_cache_stride_order() -> tuple[int, ...]:
             raise ValueError(f"Unknown cache layout format {cache_layout}.")
         return stride_order
 
+    @staticmethod
+    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
+        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
+            return torch.float8_e4m3fn
+        else:
+            raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")
+
 
 @dataclass
 class FlashAttentionMetadata:
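
A minimal standalone sketch of the new helper's behavior, assuming a PyTorch build that exposes torch.float8_e4m3fn (the FlashAttentionBackend class wrapper from the diff is omitted here for brevity):

    import torch

    def get_fp8_dtype_for_flashattn(kv_cache_dtype: str) -> torch.dtype:
        # Both "fp8" and "fp8_e4m3" select the e4m3 storage format; any other
        # string (e.g. "fp8_e5m2") is rejected rather than silently mapped.
        if kv_cache_dtype in ("fp8", "fp8_e4m3"):
            return torch.float8_e4m3fn
        raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

    assert get_fp8_dtype_for_flashattn("fp8") is torch.float8_e4m3fn
    assert get_fp8_dtype_for_flashattn("fp8_e4m3") is torch.float8_e4m3fn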
@@ -161,6 +168,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
             self.parallel_config)
         self.num_heads_kv = self.model_config.get_num_kv_heads(
             self.parallel_config)
+        self.kv_cache_dtype = kv_cache_spec.dtype
         self.headdim = self.model_config.get_head_size()
         self.block_size = kv_cache_spec.block_size
 
@@ -239,17 +247,24 @@ def build(self,
 
         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
+            cache_dtype = self.cache_config.cache_dtype
+            if cache_dtype.startswith("fp8"):
+                qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
+                    cache_dtype)
+            else:
+                qkv_dtype = self.kv_cache_dtype
             if aot_schedule:
                 return get_scheduler_metadata(
                     batch_size=batch_size,
                     max_seqlen_q=max_query_len,
                     max_seqlen_k=max_seq_len,
-                    cache_seqlens=seqlens,
                     num_heads_q=self.num_heads_q,
                     num_heads_kv=self.num_heads_kv,
                     headdim=self.headdim,
-                    page_size=self.block_size,
+                    cache_seqlens=seqlens,
+                    qkv_dtype=qkv_dtype,
                     cu_seqlens_q=cu_query_lens,
+                    page_size=self.block_size,
                     causal=causal,
                     window_size=self.aot_sliding_window,
                     num_splits=self.max_num_splits,
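
The point of the new branch is that get_scheduler_metadata now plans the ahead-of-time attention schedule for the dtype actually stored in the cache instead of assuming the model dtype. A sketch of the selection logic as a free function (select_qkv_dtype is a hypothetical name used for illustration; in the diff this branch lives inline in schedule()):

    import torch

    def select_qkv_dtype(cache_dtype: str,
                         kv_cache_dtype: torch.dtype) -> torch.dtype:
        # fp8 cache configs arrive as strings ("fp8", "fp8_e4m3") and must be
        # translated to a concrete torch dtype before reaching the scheduler;
        # other configs fall back to the dtype recorded from the KV cache spec.
        if cache_dtype.startswith("fp8"):
            if cache_dtype in ("fp8", "fp8_e4m3"):
                return torch.float8_e4m3fn
            raise ValueError(f"Unrecognized FP8 dtype: {cache_dtype}")
        return kv_cache_dtype

    assert select_qkv_dtype("fp8", torch.bfloat16) is torch.float8_e4m3fn
    assert select_qkv_dtype("auto", torch.bfloat16) is torch.bfloat16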
@@ -474,8 +489,10 @@ def forward(
             )
 
         if self.kv_cache_dtype.startswith("fp8"):
-            key_cache = key_cache.view(torch.float8_e4m3fn)
-            value_cache = value_cache.view(torch.float8_e4m3fn)
+            dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
+                self.kv_cache_dtype)
+            key_cache = key_cache.view(dtype)
+            value_cache = value_cache.view(dtype)
             num_tokens, num_heads, head_size = query.shape
             query, _ = ops.scaled_fp8_quant(
                 query.reshape(
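
The forward() change routes the cache reinterpretation through the same helper so both FP8 spellings resolve to a single dtype. A small sketch of what Tensor.view(dtype) does in this context, assuming the cache buffers were allocated as uint8 (a common choice for FP8 KV caches): it reinterprets the existing bytes in place, without copying.

    import torch

    # Reinterpret an 8-bit buffer as float8: same storage, new dtype, no copy.
    raw_cache = torch.zeros(4, 16, dtype=torch.uint8)
    fp8_cache = raw_cache.view(torch.float8_e4m3fn)

    assert fp8_cache.dtype is torch.float8_e4m3fn
    assert fp8_cache.data_ptr() == raw_cache.data_ptr()  # shared storage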
